class CSVDeserializer(BaseDeserializer):
    """Deserialize a stream of bytes into a list of lists."""

    ACCEPT = "text/csv"

    def __init__(self, encoding="utf-8"):
        """Initialize the string encoding.

        Args:
            encoding (str): The string encoding to use (default: "utf-8").
        """
        self.encoding = encoding

    def deserialize(self, data, content_type):
        """Deserialize data from an inference endpoint into a list of lists.

        The whole payload is read, decoded with ``self.encoding`` and parsed
        with ``csv.reader``. Fields are returned as strings; no numeric
        conversion is attempted.

        Args:
            data (botocore.response.StreamingBody): Data to be deserialized.
                The stream is closed before this method returns, including
                when reading, decoding or parsing raises.
            content_type (str): The MIME type of the data. It is not
                inspected by this method.

        Returns:
            list: The data deserialized into a list of lists representing the
                contents of a CSV file.
        """
        try:
            decoded_string = data.read().decode(self.encoding)
            # Hand the text to csv.reader with its line endings intact
            # (newline=""). Pre-splitting with str.splitlines() would strip
            # the terminators, so a newline inside a quoted field would be
            # dropped, and rows would also be broken on characters such as
            # form feed or U+2028 that are not CSV record separators.
            return list(csv.reader(io.StringIO(decoded_string, newline="")))
        finally:
            data.close()
class _CsvDeserializer(object):
    """Callable that turns a CSV byte stream into a list of rows."""

    def __init__(self, encoding="utf-8"):
        """Store the accepted content type and the text encoding.

        Args:
            encoding (str): Codec used to decode the bytes read from the
                stream (default: "utf-8").
        """
        self.accept = CONTENT_TYPE_CSV
        self.encoding = encoding

    def __call__(self, stream, content_type):
        """Read, decode and parse ``stream``, closing it afterwards.

        Args:
            stream: Object exposing ``read()`` (returning bytes) and
                ``close()``.
            content_type: The MIME type of the data; not inspected here.

        Returns:
            list: One list of string fields per line of the decoded text.
        """
        try:
            raw_bytes = stream.read()
            text_lines = raw_bytes.decode(self.encoding).splitlines()
            parsed_rows = csv.reader(text_lines)
            return list(parsed_rows)
        finally:
            # Release the stream whether or not parsing succeeded.
            stream.close()


# Shared module-level instance using the default encoding.
csv_deserializer = _CsvDeserializer()
""" super(XGBoostPredictor, self).__init__( - endpoint_name, sagemaker_session, npy_serializer, csv_deserializer + endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer() ) diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index d3a806f489..edd4deb474 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -20,6 +20,7 @@ from sagemaker.deserializers import ( StringDeserializer, BytesDeserializer, + CSVDeserializer, StreamDeserializer, NumpyDeserializer, ) @@ -41,6 +42,31 @@ def test_bytes_deserializer(): assert result == b"[1, 2, 3]" +@pytest.fixture +def csv_deserializer(): + return CSVDeserializer() + + +def test_csv_deserializer_single_element(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1"), "text/csv") + assert result == [["1"]] + + +def test_csv_deserializer_array(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3"), "text/csv") + assert result == [["1", "2", "3"]] + + +def test_csv_deserializer_2dimensional(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5"), "text/csv") + assert result == [["1", "2", "3"], ["3", "4", "5"]] + + +def test_csv_deserializer_posix_compliant(csv_deserializer): + result = csv_deserializer.deserialize(io.BytesIO(b"1,2,3\n3,4,5\n"), "text/csv") + assert result == [["1", "2", "3"], ["3", "4", "5"]] + + def test_stream_deserializer(): deserializer = StreamDeserializer() diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 0c760a6332..b385c3129d 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -25,7 +25,6 @@ json_serializer, json_deserializer, csv_serializer, - csv_deserializer, npy_serializer, ) from tests.unit import DATA_DIR @@ -146,21 +145,6 @@ def test_csv_serializer_csv_reader(): assert result == validation_data -def test_csv_deserializer_single_element(): - result = 
def test_csv_deserializer_single_element():
    """A lone value parses to one row containing one field."""
    payload = io.BytesIO(b"1")
    expected = [["1"]]
    assert csv_deserializer(payload, "text/csv") == expected


def test_csv_deserializer_array():
    """A single comma-separated line parses to one row of fields."""
    payload = io.BytesIO(b"1,2,3")
    expected = [["1", "2", "3"]]
    assert csv_deserializer(payload, "text/csv") == expected


def test_csv_deserializer_2dimensional():
    """Two lines parse to two rows of fields."""
    payload = io.BytesIO(b"1,2,3\n3,4,5")
    expected = [["1", "2", "3"], ["3", "4", "5"]]
    assert csv_deserializer(payload, "text/csv") == expected