27
27
28
28
PING_URL = "http://localhost:8080/ping"
29
29
INVOCATION_URL = "http://localhost:8080/models/{}/invoke"
30
- MODEL_NAME = "half_plus_three"
30
+ MODEL_NAMES = [ "half_plus_three" , "half_plus_two" ]
31
31
32
32
33
33
@pytest .fixture (scope = "session" , autouse = True )
@@ -74,13 +74,14 @@ def container(docker_base_name, tag, runtime_config):
74
74
75
75
76
76
@pytest .fixture
77
- def model ():
78
- model_data = {
79
- "model_name" : MODEL_NAME ,
80
- "url" : "/opt/ml/models/half_plus_three/model/half_plus_three"
81
- }
82
- make_load_model_request (json .dumps (model_data ))
83
- return MODEL_NAME
77
+ def models ():
78
+ for MODEL_NAME in MODEL_NAMES :
79
+ model_data = {
80
+ "model_name" : MODEL_NAME ,
81
+ "url" : "/opt/ml/models/{}/model/{}" .format (MODEL_NAME ,MODEL_NAME )
82
+ }
83
+ make_load_model_request (json .dumps (model_data ))
84
+ return MODEL_NAMES
84
85
85
86
86
87
@pytest .mark .skip_gpu
@@ -90,20 +91,25 @@ def test_ping_service():
90
91
91
92
92
93
@pytest .mark .skip_gpu
93
- def test_predict_json (model ):
94
+ def test_predict_json (models ):
94
95
headers = make_headers ()
95
96
data = "{\" instances\" : [1.0, 2.0, 5.0]}"
96
- response = requests .post (INVOCATION_URL .format (model ), data = data , headers = headers ).json ()
97
- assert response == {"predictions" : [3.5 , 4.0 , 5.5 ]}
97
+ responses = []
98
+ for model in models :
99
+ response = requests .post (INVOCATION_URL .format (model ), data = data , headers = headers ).json ()
100
+ responses .append (response )
101
+ assert responses [0 ] == {"predictions" : [3.5 , 4.0 , 5.5 ]}
102
+ assert responses [1 ] == {"predictions" : [2.5 , 3.0 , 4.5 ]}
98
103
99
104
100
105
@pytest .mark .skip_gpu
101
106
def test_zero_content ():
102
107
headers = make_headers ()
103
108
x = ""
104
- response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = x , headers = headers )
105
- assert 500 == response .status_code
106
- assert "document is empty" in response .text
109
+ for MODEL_NAME in MODEL_NAMES :
110
+ response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = x , headers = headers )
111
+ assert 500 == response .status_code
112
+ assert "document is empty" in response .text
107
113
108
114
109
115
@pytest .mark .skip_gpu
@@ -113,21 +119,26 @@ def test_large_input():
113
119
with open (data_file , "r" ) as file :
114
120
x = file .read ()
115
121
headers = make_headers (content_type = "text/csv" )
116
- response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = x , headers = headers ).json ()
117
- predictions = response ["predictions" ]
118
- assert len (predictions ) == 753936
122
+ for MODEL_NAME in MODEL_NAMES :
123
+ response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = x , headers = headers ).json ()
124
+ predictions = response ["predictions" ]
125
+ assert len (predictions ) == 753936
119
126
120
127
121
128
@pytest .mark .skip_gpu
122
129
def test_csv_input ():
123
130
headers = make_headers (content_type = "text/csv" )
124
131
data = "1.0,2.0,5.0"
125
- response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = data , headers = headers ).json ()
126
- assert response == {"predictions" : [3.5 , 4.0 , 5.5 ]}
127
-
132
+ responses = []
133
+ for MODEL_NAME in MODEL_NAMES :
134
+ response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = data , headers = headers ).json ()
135
+ responses .append (response )
136
+ assert responses [0 ] == {"predictions" : [3.5 , 4.0 , 5.5 ]}
137
+ assert responses [1 ] == {"predictions" : [2.5 , 3.0 , 4.5 ]}
128
138
129
139
@pytest .mark .skip_gpu
130
140
def test_specific_versions ():
141
+ MODEL_NAME = MODEL_NAMES [0 ]
131
142
for version in ("123" , "124" ):
132
143
headers = make_headers (content_type = "text/csv" , version = version )
133
144
data = "1.0,2.0,5.0"
@@ -141,6 +152,7 @@ def test_specific_versions():
141
152
def test_unsupported_content_type ():
142
153
headers = make_headers ("unsupported-type" , "predict" )
143
154
data = "aW1hZ2UgYnl0ZXM="
144
- response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = data , headers = headers )
145
- assert 500 == response .status_code
146
- assert "unsupported content type" in response .text
155
+ for MODEL_NAME in MODEL_NAMES :
156
+ response = requests .post (INVOCATION_URL .format (MODEL_NAME ), data = data , headers = headers )
157
+ assert 500 == response .status_code
158
+ assert "unsupported content type" in response .text
0 commit comments