Skip to content

breaking: Move StringDeserializer to sagemaker.deserializers #1677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 8, 2020
29 changes: 29 additions & 0 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,35 @@ def ACCEPT(self):
"""The content type that is expected from the inference endpoint."""


class StringDeserializer(BaseDeserializer):
    """Turn the response stream of an inference endpoint into a decoded string."""

    ACCEPT = "application/json"

    def __init__(self, encoding="UTF-8"):
        """Store the character encoding used when decoding responses.

        Args:
            encoding (str): Name of the encoding applied to the raw response
                (default: UTF-8).
        """
        self.encoding = encoding

    def deserialize(self, data, content_type):
        """Read the whole response stream and decode it into a string.

        The stream is closed before this method returns, including when
        reading or decoding raises.

        Args:
            data (object): Response stream to deserialize; it must provide
                ``read()`` and ``close()``.
            content_type (str): The MIME type of the data. It is accepted
                but not consulted by this method.

        Returns:
            str: The stream contents decoded with ``self.encoding``.
        """
        try:
            payload = data.read()
            return payload.decode(self.encoding)
        finally:
            # Always release the stream, whether or not decoding succeeded.
            data.close()


class BytesDeserializer(BaseDeserializer):
"""Deserialize a stream of bytes into a bytes object."""

Expand Down
29 changes: 0 additions & 29 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,35 +623,6 @@ def __call__(self, stream, content_type):
# Module-level _CsvDeserializer instance bound to the name ``csv_deserializer``.
csv_deserializer = _CsvDeserializer()


class StringDeserializer(object):
    """Callable that returns a response stream's contents as a decoded string.

    Args:
        encoding (str): The string encoding to use (default=utf-8).
        accept (str): The Accept header to send to the server (optional).
    """

    def __init__(self, encoding="utf-8", accept=None):
        """Record the decoding and Accept-header settings.

        Args:
            encoding (str): Encoding name used to decode the response.
            accept (str): Value stored on ``self.accept``; ``None`` when
                not provided.
        """
        self.accept = accept
        self.encoding = encoding

    def __call__(self, stream, content_type):
        """Read ``stream`` fully, decode it, and close the stream.

        Args:
            stream: Response stream providing ``read()`` and ``close()``.
            content_type: Content type of the response. It is accepted but
                not consulted by this method.

        Returns:
            The stream contents decoded with ``self.encoding``.
        """
        try:
            body = stream.read()
            return body.decode(self.encoding)
        finally:
            # Close the stream even if reading or decoding raised.
            stream.close()


class StreamDeserializer(object):
"""Returns the tuple of the response stream and the content-type of the response.
It is the receivers responsibility to close the stream when they're done
Expand Down
3 changes: 2 additions & 1 deletion tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@

from sagemaker import utils
from sagemaker.amazon.randomcutforest import RandomCutForest
from sagemaker.deserializers import StringDeserializer
from sagemaker.multidatamodel import MultiDataModel
from sagemaker.mxnet import MXNet
from sagemaker.predictor import Predictor, StringDeserializer, npy_serializer
from sagemaker.predictor import Predictor, npy_serializer
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
from tests.integ.retry import retries
Expand Down
10 changes: 9 additions & 1 deletion tests/unit/sagemaker/test_deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,15 @@

import io

from sagemaker.deserializers import BytesDeserializer
from sagemaker.deserializers import StringDeserializer, BytesDeserializer


def test_string_deserializer():
    """A default StringDeserializer decodes a byte stream into the matching str."""
    payload = io.BytesIO(b"[1, 2, 3]")

    decoded = StringDeserializer().deserialize(payload, "application/json")

    assert decoded == "[1, 2, 3]"


def test_bytes_deserializer():
Expand Down
7 changes: 0 additions & 7 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
json_deserializer,
csv_serializer,
csv_deserializer,
StringDeserializer,
StreamDeserializer,
numpy_deserializer,
npy_serializer,
Expand Down Expand Up @@ -183,12 +182,6 @@ def test_json_deserializer_invalid_data():
assert "column" in str(error)


def test_string_deserializer():
    """Calling a default StringDeserializer on a byte stream yields the decoded str."""
    deserializer = StringDeserializer()
    payload = io.BytesIO(b"[1, 2, 3]")

    decoded = deserializer(payload, "application/json")

    assert decoded == "[1, 2, 3]"


def test_stream_deserializer():
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
result = stream.read()
Expand Down