From 98ddbbecf61a49c658ce70a2029875e8e000b065 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 31 Jul 2020 14:07:18 -0500 Subject: [PATCH 1/8] Add multiple Accept types --- src/sagemaker/deserializers.py | 14 +++++++------- src/sagemaker/predictor.py | 2 +- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index be6467f374..6a9f645c53 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -58,7 +58,7 @@ def ACCEPT(self): class StringDeserializer(BaseDeserializer): """Deserialize data from an inference endpoint into a decoded string.""" - ACCEPT = "application/json" + ACCEPT = ["application/json", "text/csv"] def __init__(self, encoding="UTF-8"): """Initialize the string encoding. @@ -87,7 +87,7 @@ def deserialize(self, stream, content_type): class BytesDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a bytes object.""" - ACCEPT = "*/*" + ACCEPT = ["*/*"] def deserialize(self, stream, content_type): """Read a stream of bytes returned from an inference endpoint. @@ -108,7 +108,7 @@ def deserialize(self, stream, content_type): class CSVDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a list of lists.""" - ACCEPT = "text/csv" + ACCEPT = ["text/csv"] def __init__(self, encoding="utf-8"): """Initialize the string encoding. @@ -143,7 +143,7 @@ class StreamDeserializer(BaseDeserializer): reading it. """ - ACCEPT = "*/*" + ACCEPT = ["*/*"] def deserialize(self, stream, content_type): """Returns a stream of the response body and the MIME type of the data. @@ -161,7 +161,7 @@ def deserialize(self, stream, content_type): class NumpyDeserializer(BaseDeserializer): """Deserialize a stream of data in the .npy format.""" - ACCEPT = "application/x-npy" + ACCEPT = ["application/x-npy"] def __init__(self, dtype=None, allow_pickle=True): """Initialize the dtype and allow_pickle arguments. 
@@ -201,7 +201,7 @@ def deserialize(self, stream, content_type): class JSONDeserializer(BaseDeserializer): """Deserialize JSON data from an inference endpoint into a Python object.""" - ACCEPT = "application/json" + ACCEPT = ["application/json"] def deserialize(self, stream, content_type): """Deserialize JSON data from an inference endpoint into a Python object. @@ -222,7 +222,7 @@ def deserialize(self, stream, content_type): class PandasDeserializer(BaseDeserializer): """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.""" - ACCEPT = "text/csv" + ACCEPT = ["text/csv", "application/json"] def deserialize(self, stream, content_type): """Deserialize CSV or JSON data from an inference endpoint into a pandas diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index f5e2f29820..2c9c3cbbd3 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -131,7 +131,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe args["ContentType"] = self.content_type if self.accept and "Accept" not in args: - args["Accept"] = self.accept + args["Accept"] = ", ".join(self.accept) if target_model: args["TargetModel"] = target_model From 93bf8553824e50d632d6091288f13d8b4245e3f9 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 31 Jul 2020 14:23:10 -0500 Subject: [PATCH 2/8] Update deserializers.py --- src/sagemaker/deserializers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 24c068be93..48760e97fe 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -52,7 +52,7 @@ def deserialize(self, stream, content_type): @property @abc.abstractmethod def ACCEPT(self): - """The content type that is expected from the inference endpoint.""" + """The content types that are expected from the inference endpoint.""" class StringDeserializer(BaseDeserializer): @@ -161,7 +161,7 @@ 
def deserialize(self, stream, content_type): class NumpyDeserializer(BaseDeserializer): """Deserialize a stream of data in the .npy format.""" - ACCEPT = ["application/x-npy"] + ACCEPT = ["application/x-npy", "text/csv", "application/json"] def __init__(self, dtype=None, allow_pickle=True): """Initialize the dtype and allow_pickle arguments. @@ -250,7 +250,7 @@ def deserialize(self, stream, content_type): class JSONLinesDeserializer(BaseDeserializer): """Deserialize JSON lines data from an inference endpoint.""" - ACCEPT = "application/jsonlines" + ACCEPT = ["application/jsonlines"] def deserialize(self, stream, content_type): """Deserialize JSON lines data from an inference endpoint. From bc48fc71f0103002563d0d639d6374a8baf9587c Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 31 Jul 2020 15:12:33 -0500 Subject: [PATCH 3/8] Update test_predictor.py --- tests/unit/test_predictor.py | 32 +++++++++++++++++++++++++++----- 1 file changed, 27 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index b93568fd4a..17c3a7d624 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -17,6 +17,7 @@ import pytest from mock import Mock, call, patch +from sagemaker.deserializers import CSVDeserializer, StringDeserializer from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer, CSVSerializer @@ -132,7 +133,7 @@ def json_sagemaker_session(): response_body.close = Mock("close", return_value=None) ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", - return_value={"Body": response_body, "ContentType": DEFAULT_CONTENT_TYPE}, + return_value={"Body": response_body, "ContentType": "application/json"}, ) return ims @@ -169,7 +170,7 @@ def ret_csv_sagemaker_session(): ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) response_body = Mock("body") - response_body.read = Mock("read", 
return_value=CSV_RETURN_VALUE) + response_body.read = Mock("read", return_value=bytes(CSV_RETURN_VALUE, "utf-8")) response_body.close = Mock("close", return_value=None) ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", @@ -180,7 +181,7 @@ def ret_csv_sagemaker_session(): def test_predict_call_with_csv(): sagemaker_session = ret_csv_sagemaker_session() - predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer()) + predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer()) data = [1, 2] result = predictor.predict(data) @@ -188,7 +189,28 @@ def test_predict_call_with_csv(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - "Accept": DEFAULT_ACCEPT, + "Accept": CSV_CONTENT_TYPE, + "Body": "1,2", + "ContentType": CSV_CONTENT_TYPE, + "EndpointName": ENDPOINT, + } + call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args + assert kwargs == expected_request_args + + assert result == [["1", "2", "3"]] + + +def test_predict_call_with_multiple_accept_types(): + sagemaker_session = ret_csv_sagemaker_session() + predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer()) + + data = [1, 2] + result = predictor.predict(data) + + assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called + + expected_request_args = { + "Accept": "application/json, text/csv", "Body": "1,2", "ContentType": CSV_CONTENT_TYPE, "EndpointName": ENDPOINT, @@ -196,7 +218,7 @@ def test_predict_call_with_csv(): call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args - assert result == CSV_RETURN_VALUE + assert result == "1,2,3\r\n" @patch("sagemaker.predictor.name_from_base") From 58eb512426b7ef78ee287930c2707d85e0baf05f Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 31 Jul 2020 
15:16:31 -0500 Subject: [PATCH 4/8] Appease lint --- tests/unit/test_predictor.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 17c3a7d624..51bcbf9844 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -181,7 +181,9 @@ def ret_csv_sagemaker_session(): def test_predict_call_with_csv(): sagemaker_session = ret_csv_sagemaker_session() - predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer()) + predictor = Predictor( + ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer() + ) data = [1, 2] result = predictor.predict(data) @@ -202,7 +204,9 @@ def test_predict_call_with_csv(): def test_predict_call_with_multiple_accept_types(): sagemaker_session = ret_csv_sagemaker_session() - predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer()) + predictor = Predictor( + ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer() + ) data = [1, 2] result = predictor.predict(data) From 4bc264ed456c1059403f70105dca5e680eafc5f4 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 31 Jul 2020 15:42:14 -0500 Subject: [PATCH 5/8] Update deserializers.py --- src/sagemaker/deserializers.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 48760e97fe..c55fb048ee 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -58,7 +58,7 @@ def ACCEPT(self): class StringDeserializer(BaseDeserializer): """Deserialize data from an inference endpoint into a decoded string.""" - ACCEPT = ["application/json", "text/csv"] + ACCEPT = ("application/json", "text/csv") def __init__(self, encoding="UTF-8"): """Initialize the string encoding. 
@@ -87,7 +87,7 @@ def deserialize(self, stream, content_type): class BytesDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a bytes object.""" - ACCEPT = ["*/*"] + ACCEPT = ("*/*",) def deserialize(self, stream, content_type): """Read a stream of bytes returned from an inference endpoint. @@ -108,7 +108,7 @@ def deserialize(self, stream, content_type): class CSVDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a list of lists.""" - ACCEPT = ["text/csv"] + ACCEPT = ("text/csv",) def __init__(self, encoding="utf-8"): """Initialize the string encoding. @@ -143,7 +143,7 @@ class StreamDeserializer(BaseDeserializer): reading it. """ - ACCEPT = ["*/*"] + ACCEPT = ("*/*",) def deserialize(self, stream, content_type): """Returns a stream of the response body and the MIME type of the data. @@ -161,7 +161,7 @@ def deserialize(self, stream, content_type): class NumpyDeserializer(BaseDeserializer): """Deserialize a stream of data in the .npy format.""" - ACCEPT = ["application/x-npy", "text/csv", "application/json"] + ACCEPT = ("application/x-npy", "text/csv", "application/json") def __init__(self, dtype=None, allow_pickle=True): """Initialize the dtype and allow_pickle arguments. @@ -201,7 +201,7 @@ def deserialize(self, stream, content_type): class JSONDeserializer(BaseDeserializer): """Deserialize JSON data from an inference endpoint into a Python object.""" - ACCEPT = ["application/json"] + ACCEPT = ("application/json",) def deserialize(self, stream, content_type): """Deserialize JSON data from an inference endpoint into a Python object. 
@@ -222,7 +222,7 @@ def deserialize(self, stream, content_type): class PandasDeserializer(BaseDeserializer): """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.""" - ACCEPT = ["text/csv", "application/json"] + ACCEPT = ("text/csv", "application/json") def deserialize(self, stream, content_type): """Deserialize CSV or JSON data from an inference endpoint into a pandas @@ -250,7 +250,7 @@ def deserialize(self, stream, content_type): class JSONLinesDeserializer(BaseDeserializer): """Deserialize JSON lines data from an inference endpoint.""" - ACCEPT = ["application/jsonlines"] + ACCEPT = ("application/jsonlines",) def deserialize(self, stream, content_type): """Deserialize JSON lines data from an inference endpoint. From 1f1f70d1f22a211279a1d6c6c040d8adebf7bcad Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Sun, 2 Aug 2020 23:22:33 -0500 Subject: [PATCH 6/8] Update common.py --- src/sagemaker/amazon/common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index d69d8dacb4..d276356207 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -58,7 +58,7 @@ def serialize(self, data): class RecordDeserializer(BaseDeserializer): """Deserialize RecordIO Protobuf data from an inference endpoint.""" - ACCEPT = "application/x-recordio-protobuf" + ACCEPT = ("application/x-recordio-protobuf",) def deserialize(self, data, content_type): """Deserialize RecordIO Protobuf data from an inference endpoint. 
From a249739088c6515ea8fdf1e8671e6016d6de364e Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 3 Aug 2020 17:59:58 -0500 Subject: [PATCH 7/8] Update deserializers.py --- src/sagemaker/deserializers.py | 19 +++++++++++++++---- tests/unit/test_predictor.py | 2 +- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index c55fb048ee..bffdfa8bb7 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -58,7 +58,7 @@ def ACCEPT(self): class StringDeserializer(BaseDeserializer): """Deserialize data from an inference endpoint into a decoded string.""" - ACCEPT = ("application/json", "text/csv") + ACCEPT = ("application/json",) def __init__(self, encoding="UTF-8"): """Initialize the string encoding. @@ -161,16 +161,17 @@ def deserialize(self, stream, content_type): class NumpyDeserializer(BaseDeserializer): """Deserialize a stream of data in the .npy format.""" - ACCEPT = ("application/x-npy", "text/csv", "application/json") - - def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True): """Initialize the dtype and allow_pickle arguments. Args: dtype (str): The dtype of the data (default: None). + accept (str): The MIME type that is expected from the inference + endpoint (default: "application/x-npy"). allow_pickle (bool): Allow loading pickled object arrays (default: True). """ self.dtype = dtype + self.accept = accept self.allow_pickle = allow_pickle def deserialize(self, stream, content_type): @@ -197,6 +198,16 @@ def deserialize(self, stream, content_type): raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type)) + @property + def ACCEPT(self): + """The content types that are expected from the inference endpoint. + + To maintain backwards compatibility with legacy images, the + NumpyDeserializer supports sending only one content type in the Accept + header.
+ """ + return (self.accept,) + class JSONDeserializer(BaseDeserializer): """Deserialize JSON data from an inference endpoint into a Python object.""" diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 51bcbf9844..7628501873 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -214,7 +214,7 @@ def test_predict_call_with_multiple_accept_types(): assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - "Accept": "application/json, text/csv", + "Accept": "application/json", "Body": "1,2", "ContentType": CSV_CONTENT_TYPE, "EndpointName": ENDPOINT, From bfdd6aa89462c3d83651dc2570bfef0e082eef9b Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 3 Aug 2020 18:21:39 -0500 Subject: [PATCH 8/8] Update test_predictor.py --- tests/unit/test_predictor.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 7628501873..81762e1cb4 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -12,12 +12,13 @@ # language governing permissions and limitations under the License. 
from __future__ import absolute_import +import io import json import pytest from mock import Mock, call, patch -from sagemaker.deserializers import CSVDeserializer, StringDeserializer +from sagemaker.deserializers import CSVDeserializer, PandasDeserializer from sagemaker.predictor import Predictor from sagemaker.serializers import JSONSerializer, CSVSerializer @@ -169,9 +170,7 @@ def ret_csv_sagemaker_session(): ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC) ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC) - response_body = Mock("body") - response_body.read = Mock("read", return_value=bytes(CSV_RETURN_VALUE, "utf-8")) - response_body.close = Mock("close", return_value=None) + response_body = io.BytesIO(bytes(CSV_RETURN_VALUE, "utf-8")) ims.sagemaker_runtime_client.invoke_endpoint = Mock( name="invoke_endpoint", return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE}, @@ -205,16 +204,16 @@ def test_predict_call_with_csv(): def test_predict_call_with_multiple_accept_types(): sagemaker_session = ret_csv_sagemaker_session() predictor = Predictor( - ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=StringDeserializer() + ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=PandasDeserializer() ) data = [1, 2] - result = predictor.predict(data) + predictor.predict(data) assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called expected_request_args = { - "Accept": "application/json", + "Accept": "text/csv, application/json", "Body": "1,2", "ContentType": CSV_CONTENT_TYPE, "EndpointName": ENDPOINT, @@ -222,8 +221,6 @@ def test_predict_call_with_multiple_accept_types(): call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args assert kwargs == expected_request_args - assert result == "1,2,3\r\n" - @patch("sagemaker.predictor.name_from_base") def test_update_endpoint_no_args(name_from_base):