Skip to content

Commit bc48fc7

Browse files
author
Balaji Veeramani
committed
Update test_predictor.py
1 parent 93bf855 commit bc48fc7

File tree

1 file changed

+27
-5
lines changed

1 file changed

+27
-5
lines changed

tests/unit/test_predictor.py

+27-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import pytest
1818
from mock import Mock, call, patch
1919

20+
from sagemaker.deserializers import CSVDeserializer, StringDeserializer
2021
from sagemaker.predictor import Predictor
2122
from sagemaker.serializers import JSONSerializer, CSVSerializer
2223

@@ -132,7 +133,7 @@ def json_sagemaker_session():
132133
response_body.close = Mock("close", return_value=None)
133134
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
134135
name="invoke_endpoint",
135-
return_value={"Body": response_body, "ContentType": DEFAULT_CONTENT_TYPE},
136+
return_value={"Body": response_body, "ContentType": "application/json"},
136137
)
137138
return ims
138139

@@ -169,7 +170,7 @@ def ret_csv_sagemaker_session():
169170
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)
170171

171172
response_body = Mock("body")
172-
response_body.read = Mock("read", return_value=CSV_RETURN_VALUE)
173+
response_body.read = Mock("read", return_value=bytes(CSV_RETURN_VALUE, "utf-8"))
173174
response_body.close = Mock("close", return_value=None)
174175
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
175176
name="invoke_endpoint",
@@ -180,23 +181,44 @@ def ret_csv_sagemaker_session():
180181

181182
def test_predict_call_with_csv():
182183
sagemaker_session = ret_csv_sagemaker_session()
183-
predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer())
184+
predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer())
184185

185186
data = [1, 2]
186187
result = predictor.predict(data)
187188

188189
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
189190

190191
expected_request_args = {
191-
"Accept": DEFAULT_ACCEPT,
192+
"Accept": CSV_CONTENT_TYPE,
193+
"Body": "1,2",
194+
"ContentType": CSV_CONTENT_TYPE,
195+
"EndpointName": ENDPOINT,
196+
}
197+
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
198+
assert kwargs == expected_request_args
199+
200+
assert result == [["1", "2", "3"]]
201+
202+
203+
def test_predict_call_with_multiple_accept_types():
204+
sagemaker_session = ret_csv_sagemaker_session()
205+
predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer())
206+
207+
data = [1, 2]
208+
result = predictor.predict(data)
209+
210+
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
211+
212+
expected_request_args = {
213+
"Accept": "application/json, text/csv",
192214
"Body": "1,2",
193215
"ContentType": CSV_CONTENT_TYPE,
194216
"EndpointName": ENDPOINT,
195217
}
196218
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
197219
assert kwargs == expected_request_args
198220

199-
assert result == CSV_RETURN_VALUE
221+
assert result == "1,2,3\r\n"
200222

201223

202224
@patch("sagemaker.predictor.name_from_base")

0 commit comments

Comments
 (0)