Skip to content

breaking: Move _CsvDeserializer to sagemaker.deserializers and rename to CSVDeserializer #1682

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Jul 10, 2020
Merged
33 changes: 33 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
"""Implements methods for deserializing data returned from an inference endpoint."""
from __future__ import absolute_import

import csv

import abc


Expand Down Expand Up @@ -60,3 +62,34 @@ def deserialize(self, data, content_type):
return data.read()
finally:
data.close()


class CSVDeserializer(BaseDeserializer):
    """Deserialize a stream of bytes into a list of lists."""

    # MIME type for CSV data.
    ACCEPT = "text/csv"

    def __init__(self, encoding="UTF-8"):
        """Initialize the string encoding.

        Args:
            encoding (str): The string encoding to use (default: "UTF-8").
        """
        self.encoding = encoding

    def deserialize(self, data, content_type):
        """Deserialize data from an inference endpoint into a list of lists.

        The stream is read in full, decoded with ``self.encoding``, split
        into lines and parsed with ``csv.reader``. The stream is closed
        before returning, even if reading or decoding raises.

        Args:
            data (botocore.response.StreamingBody): Data to be deserialized.
            content_type (str): The MIME type of the data. Not inspected by
                this implementation.

        Returns:
            list: The data deserialized into a list of lists representing the
                contents of a CSV file. Every field is returned as a string.
        """
        try:
            decoded_string = data.read().decode(self.encoding)
            return list(csv.reader(decoded_string.splitlines()))
        finally:
            data.close()
26 changes: 0 additions & 26 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -597,32 +597,6 @@ def _row_to_csv(obj):
return ",".join(obj)


class _CsvDeserializer(object):
    """Callable that parses a CSV byte stream into a list of rows."""

    def __init__(self, encoding="utf-8"):
        """Record the accepted content type and the text encoding.

        Args:
            encoding (str): Encoding used to decode the raw bytes read from
                the stream (default: "utf-8").
        """
        self.accept = CONTENT_TYPE_CSV
        self.encoding = encoding

    def __call__(self, stream, content_type):
        """Read ``stream`` fully and parse its contents as CSV.

        Args:
            stream: Object exposing ``read()`` (returning bytes) and
                ``close()``. It is always closed before returning.
            content_type: Accepted for interface compatibility; not used here.

        Returns:
            list: One list of string fields per line of input.
        """
        try:
            raw_bytes = stream.read()
            text = raw_bytes.decode(self.encoding)
            rows = csv.reader(text.splitlines())
            return list(rows)
        finally:
            stream.close()


# Shared module-level instance built with the default "utf-8" encoding.
csv_deserializer = _CsvDeserializer()


class StringDeserializer(object):
"""Return the response as a decoded string.

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/xgboost/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import logging

import sagemaker
from sagemaker.deserializers import CSVDeserializer
from sagemaker.fw_utils import model_code_key_prefix
from sagemaker.fw_registry import default_framework_uri
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import Predictor, npy_serializer, csv_deserializer
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.xgboost.defaults import XGBOOST_NAME

logger = logging.getLogger("sagemaker")
Expand All @@ -42,7 +43,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
chain.
"""
super(XGBoostPredictor, self).__init__(
endpoint_name, sagemaker_session, npy_serializer, csv_deserializer
endpoint_name, sagemaker_session, npy_serializer, CSVDeserializer()
)


Expand Down
29 changes: 28 additions & 1 deletion tests/unit/sagemaker/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@

import io

from sagemaker.deserializers import BytesDeserializer
import pytest

from sagemaker.deserializers import BytesDeserializer, CSVDeserializer


def test_bytes_deserializer():
Expand All @@ -23,3 +25,28 @@ def test_bytes_deserializer():
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")

assert result == b"[1, 2, 3]"


@pytest.fixture
def csv_deserializer():
    """Provide a CSVDeserializer constructed with its default arguments."""
    deserializer = CSVDeserializer()
    return deserializer


def test_csv_deserializer_single_element(csv_deserializer):
    """A lone value deserializes to one row holding one field."""
    stream = io.BytesIO(b"1")
    rows = csv_deserializer.deserialize(stream, "text/csv")
    assert rows == [["1"]]


def test_csv_deserializer_array(csv_deserializer):
    """A single comma-separated line deserializes to one row of fields."""
    stream = io.BytesIO(b"1,2,3")
    rows = csv_deserializer.deserialize(stream, "text/csv")
    assert rows == [["1", "2", "3"]]


def test_csv_deserializer_2dimensional(csv_deserializer):
    """Two newline-separated lines deserialize to two rows."""
    stream = io.BytesIO(b"1,2,3\n3,4,5")
    rows = csv_deserializer.deserialize(stream, "text/csv")
    assert rows == [["1", "2", "3"], ["3", "4", "5"]]


def test_csv_deserializer_posix_compliant(csv_deserializer):
    """A trailing newline does not produce an extra empty row."""
    stream = io.BytesIO(b"1,2,3\n3,4,5\n")
    rows = csv_deserializer.deserialize(stream, "text/csv")
    assert rows == [["1", "2", "3"], ["3", "4", "5"]]
16 changes: 0 additions & 16 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
json_serializer,
json_deserializer,
csv_serializer,
csv_deserializer,
StringDeserializer,
StreamDeserializer,
numpy_deserializer,
Expand Down Expand Up @@ -150,21 +149,6 @@ def test_csv_serializer_csv_reader():
assert result == validation_data


def test_csv_deserializer_single_element():
    """A lone value deserializes to one row holding one field."""
    stream = io.BytesIO(b"1")
    rows = csv_deserializer(stream, "text/csv")
    assert rows == [["1"]]


def test_csv_deserializer_array():
    """A single comma-separated line deserializes to one row of fields."""
    stream = io.BytesIO(b"1,2,3")
    rows = csv_deserializer(stream, "text/csv")
    assert rows == [["1", "2", "3"]]


def test_csv_deserializer_2dimensional():
    """Two newline-separated lines deserialize to two rows."""
    stream = io.BytesIO(b"1,2,3\n3,4,5")
    rows = csv_deserializer(stream, "text/csv")
    assert rows == [["1", "2", "3"], ["3", "4", "5"]]


def test_json_deserializer_array():
result = json_deserializer(io.BytesIO(b"[1, 2, 3]"), "application/json")

Expand Down