17
17
import pytest
18
18
from mock import Mock , call , patch
19
19
20
+ from sagemaker .deserializers import CSVDeserializer , StringDeserializer
20
21
from sagemaker .predictor import Predictor
21
22
from sagemaker .serializers import JSONSerializer , CSVSerializer
22
23
@@ -132,7 +133,7 @@ def json_sagemaker_session():
132
133
response_body .close = Mock ("close" , return_value = None )
133
134
ims .sagemaker_runtime_client .invoke_endpoint = Mock (
134
135
name = "invoke_endpoint" ,
135
- return_value = {"Body" : response_body , "ContentType" : DEFAULT_CONTENT_TYPE },
136
+ return_value = {"Body" : response_body , "ContentType" : "application/json" },
136
137
)
137
138
return ims
138
139
@@ -169,7 +170,7 @@ def ret_csv_sagemaker_session():
169
170
ims .sagemaker_client .describe_endpoint_config = Mock (return_value = ENDPOINT_CONFIG_DESC )
170
171
171
172
response_body = Mock ("body" )
172
- response_body .read = Mock ("read" , return_value = CSV_RETURN_VALUE )
173
+ response_body .read = Mock ("read" , return_value = bytes ( CSV_RETURN_VALUE , "utf-8" ) )
173
174
response_body .close = Mock ("close" , return_value = None )
174
175
ims .sagemaker_runtime_client .invoke_endpoint = Mock (
175
176
name = "invoke_endpoint" ,
@@ -180,23 +181,44 @@ def ret_csv_sagemaker_session():
180
181
181
182
def test_predict_call_with_csv ():
182
183
sagemaker_session = ret_csv_sagemaker_session ()
183
- predictor = Predictor (ENDPOINT , sagemaker_session , serializer = CSVSerializer ())
184
+ predictor = Predictor (ENDPOINT , sagemaker_session , serializer = CSVSerializer (), deserializer = CSVDeserializer () )
184
185
185
186
data = [1 , 2 ]
186
187
result = predictor .predict (data )
187
188
188
189
assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint .called
189
190
190
191
expected_request_args = {
191
- "Accept" : DEFAULT_ACCEPT ,
192
+ "Accept" : CSV_CONTENT_TYPE ,
193
+ "Body" : "1,2" ,
194
+ "ContentType" : CSV_CONTENT_TYPE ,
195
+ "EndpointName" : ENDPOINT ,
196
+ }
197
+ call_args , kwargs = sagemaker_session .sagemaker_runtime_client .invoke_endpoint .call_args
198
+ assert kwargs == expected_request_args
199
+
200
+ assert result == [["1" , "2" , "3" ]]
201
+
202
+
203
+ def test_predict_call_with_multiple_accept_types ():
204
+ sagemaker_session = ret_csv_sagemaker_session ()
205
+ predictor = Predictor (ENDPOINT , sagemaker_session , serializer = CSVSerializer (), deserializer = StringDeserializer ())
206
+
207
+ data = [1 , 2 ]
208
+ result = predictor .predict (data )
209
+
210
+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint .called
211
+
212
+ expected_request_args = {
213
+ "Accept" : "application/json, text/csv" ,
192
214
"Body" : "1,2" ,
193
215
"ContentType" : CSV_CONTENT_TYPE ,
194
216
"EndpointName" : ENDPOINT ,
195
217
}
196
218
call_args , kwargs = sagemaker_session .sagemaker_runtime_client .invoke_endpoint .call_args
197
219
assert kwargs == expected_request_args
198
220
199
- assert result == CSV_RETURN_VALUE
221
+ assert result == "1,2,3 \r \n "
200
222
201
223
202
224
@patch ("sagemaker.predictor.name_from_base" )
0 commit comments