From 05c8e6190d134c5748942e85e07fedbe25dcf5a1 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 8 Jul 2020 12:56:23 -0500 Subject: [PATCH 1/7] Rename _CsvDeserializer to CSVDeserializer --- src/sagemaker/deserializers.py | 33 ++++++++++++++++++++++ src/sagemaker/predictor.py | 26 ----------------- src/sagemaker/xgboost/model.py | 5 ++-- tests/unit/sagemaker/test_deserializers.py | 26 ++++++++++++++++- tests/unit/test_predictor.py | 16 ----------- 5 files changed, 61 insertions(+), 45 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index d062ffeb6e..8953dcd64c 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -13,6 +13,8 @@ """Implements methods for deserializing data returned from an inference endpoint.""" from __future__ import absolute_import +import csv + import abc @@ -60,3 +62,34 @@ def deserialize(self, data, content_type): return data.read() finally: data.close() + + +class CSVDeserializer(BaseDeserializer): + """Deserialize a stream of bytes into a list of lists.""" + + ACCEPT = "test/csv" + + def __init__(self, encoding="UTF-8"): + """Initialize the string encoding. + + Args: + encoding (str): The string encoding to use (default: "UTF-8"). + """ + self.encoding = encoding + + def deserialize(self, data, content_type): + """Deserialize data from an inference endpoint into a list of lists. + + Args: + data (object): Data to be deserialized. + content_type (str): The MIME type of the data. + + Returns: + list: The data deserialized into a list of lists representing the + contents of a CSV file. 
+ """ + try: + string = data.read().decode(self.encoding) + return list(csv.reader(string.splitlines())) + finally: + data.close() diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 660f375d6c..142ddc9375 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -597,32 +597,6 @@ def _row_to_csv(obj): return ",".join(obj) -class _CsvDeserializer(object): - """Placeholder docstring""" - - def __init__(self, encoding="utf-8"): - """ - Args: - encoding: - """ - self.accept = CONTENT_TYPE_CSV - self.encoding = encoding - - def __call__(self, stream, content_type): - """ - Args: - stream: - content_type: - """ - try: - return list(csv.reader(stream.read().decode(self.encoding).splitlines())) - finally: - stream.close() - - -csv_deserializer = _CsvDeserializer() - - class StringDeserializer(object): """Return the response as a decoded string. diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index fd17abeec4..7fac3c5976 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -16,10 +16,11 @@ import logging import sagemaker +from sagemaker.deserializers import CSVDeserializer from sagemaker.fw_utils import model_code_key_prefix from sagemaker.fw_registry import default_framework_uri from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME -from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer +from sagemaker.predictor import Predictor, npy_serializer from sagemaker.xgboost.defaults import XGBOOST_NAME logger = logging.getLogger("sagemaker") @@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None): chain. 
""" super(XGBoostPredictor, self).__init__( - endpoint_name, sagemaker_session, npy_serializer, csv_deserializer + endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer() ) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index 7b3bbf6f40..365a87fd00 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -14,7 +14,7 @@ import io -from sagemaker.deserializers import BytesDeserializer +from sagemaker.deserializers import BytesDeserializer, CSVDeserializer def test_bytes_deserializer(): @@ -23,3 +23,27 @@ def test_bytes_deserializer(): result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") assert result == b"[1, 2, 3]" + + +def test_csv_deserializer_single_element(): + deserializer = CSVDeserializer() + + result = deserializer.deserialize(io.BytesIO(b"1"), "text/csv") + + assert result == [["1"]] + + +def test_csv_deserializer_array(): + deserializer = CSVDeserializer() + + result = deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv") + + assert result == [["1", "2", "3"]] + + +def test_csv_deserializer_2dimensional(): + deserializer = CSVDeserializer() + + result = deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") + + assert result == [["1", "2", "3"], ["3", "4", "5"]] diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 3417523af9..022b7f9b6c 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -25,7 +25,6 @@ json_serializer, json_deserializer, csv_serializer, - csv_deserializer, StringDeserializer, StreamDeserializer, numpy_deserializer, @@ -150,21 +149,6 @@ def test_csv_serializer_csv_reader(): assert result == validation_data -def test_csv_deserializer_single_element(): - result = csv_deserializer(io.BytesIO(b"1"), "text/csv") - assert result == [["1"]] - - -def test_csv_deserializer_array(): - result = csv_deserializer(io.BytesIO(b"1,2,3"), 
"text/csv") - assert result == [["1", "2", "3"]] - - -def test_csv_deserializer_2dimensional(): - result = csv_deserializer(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") - assert result == [["1", "2", "3"], ["3", "4", "5"]] - - def test_json_deserializer_array(): result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json") From 29581bbba6a8223eb0aefad46acdb058115ff3b6 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 8 Jul 2020 15:10:01 -0500 Subject: [PATCH 2/7] Address review comments --- src/sagemaker/deserializers.py | 6 +++--- tests/unit/sagemaker/test_deserializers.py | 25 ++++++++++++---------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index 8953dcd64c..e9d6fb4367 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -81,7 +81,7 @@ def deserialize(self, data, content_type): """Deserialize data from an inference endpoint into a list of lists. Args: - data (object): Data to be deserialized. + data (botocore.response.StreamingBody): Data to be deserialized. content_type (str): The MIME type of the data. Returns: @@ -89,7 +89,7 @@ def deserialize(self, data, content_type): contents of a CSV file. 
""" try: - string = data.read().decode(self.encoding) - return list(csv.reader(string.splitlines())) + decoded_string = data.read().decode(self.encoding) + return list(csv.reader(decoded_string.splitlines())) finally: data.close() diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index 365a87fd00..46ea30caf0 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -14,6 +14,8 @@ import io +import pytest + from sagemaker.deserializers import BytesDeserializer, CSVDeserializer @@ -25,25 +27,26 @@ def test_bytes_deserializer(): assert result == b"[1, 2, 3]" -def test_csv_deserializer_single_element(): - deserializer = CSVDeserializer() +@pytest.fixture +def csv_deserializer(): + return CSVDeserializer() - result = deserializer.deserialize(io.BytesIO(b"1"), "text/csv") +def test_csv_deserializer_single_element(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1"), "text/csv") assert result == [["1"]] -def test_csv_deserializer_array(): - deserializer = CSVDeserializer() - - result = deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv") - +def test_csv_deserializer_array(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv") assert result == [["1", "2", "3"]] -def test_csv_deserializer_2dimensional(): - deserializer = CSVDeserializer() +def test_csv_deserializer_2dimensional(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") + assert result == [["1", "2", "3"], ["3", "4", "5"]] - result = deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") +def test_csv_deserializer_posix_compliant(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5\n"), "text/csv") assert result == [["1", "2", "3"], ["3", "4", "5"]] From d62cc8f868104d82326fdd8eb7bc2d8f88dd11fc Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 8 Jul 
2020 16:32:13 -0500 Subject: [PATCH 3/7] Address review comments --- src/sagemaker/deserializers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index e9d6fb4367..fb4e669c94 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -69,11 +69,11 @@ class CSVDeserializer(BaseDeserializer): ACCEPT = "test/csv" - def __init__(self, encoding="UTF-8"): + def __init__(self, encoding="utf-8"): """Initialize the string encoding. Args: - encoding (str): The string encoding to use (default: "UTF-8"). + encoding (str): The string encoding to use (default: "utf-8"). """ self.encoding = encoding From 6bacb9ed2871970b0e3f7454aa8d522ca1287c67 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 8 Jul 2020 21:02:27 -0500 Subject: [PATCH 4/7] Appease black check --- tests/unit/sagemaker/test_deserializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index 3c735be0c8..4c2611590e 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -25,7 +25,7 @@ def test_string_deserializer(): result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") assert result == "[1, 2, 3]" - + def test_bytes_deserializer(): deserializer = BytesDeserializer() From 5befbe5b0332c6505ca5e88e0185f311b209ea0b Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 9 Jul 2020 13:03:27 -0500 Subject: [PATCH 5/7] Appease black format --- tests/unit/sagemaker/test_deserializers.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index cecf6fff9f..d52d8dd2f2 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -16,7 +16,12 @@ import pytest -from 
sagemaker.deserializers import StringDeserializer, BytesDeserializer, CSVDeserializer, StreamDeserializer +from sagemaker.deserializers import ( + StringDeserializer, + BytesDeserializer, + CSVDeserializer, + StreamDeserializer, +) def test_string_deserializer(): From 6aabe7745037dba2ef7668a7a65f3ed017aaeb81 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 9 Jul 2020 17:57:30 -0500 Subject: [PATCH 6/7] Appease flake8 --- tests/unit/sagemaker/test_deserializers.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index 5dc09e9dad..edd4deb474 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -22,7 +22,6 @@ BytesDeserializer, CSVDeserializer, StreamDeserializer, - StreamDeserializer, NumpyDeserializer, ) From 5309fa79396d7440a3ec09dff178c58232e21fa2 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Thu, 9 Jul 2020 18:43:50 -0500 Subject: [PATCH 7/7] Fix typo --- src/sagemaker/deserializers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index b7563693ae..b49b5aeabb 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -101,7 +101,7 @@ def deserialize(self, data, content_type): class CSVDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a list of lists.""" - ACCEPT = "test/csv" + ACCEPT = "text/csv" def __init__(self, encoding="utf-8"): """Initialize the string encoding.