diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index d062ffeb6e..4dc694ee66 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -41,6 +41,35 @@ def ACCEPT(self): """The content type that is expected from the inference endpoint.""" +class StringDeserializer(BaseDeserializer): + """Deserialize data from an inference endpoint into a decoded string.""" + + ACCEPT = "application/json" + + def __init__(self, encoding="UTF-8"): + """Initialize the string encoding. + + Args: + encoding (str): The string encoding to use (default: UTF-8). + """ + self.encoding = encoding + + def deserialize(self, data, content_type): + """Deserialize data from an inference endpoint into a decoded string. + + Args: + data (object): Data to be deserialized. + content_type (str): The MIME type of the data. + + Returns: + str: The data deserialized into a decoded string. + """ + try: + return data.read().decode(self.encoding) + finally: + data.close() + + class BytesDeserializer(BaseDeserializer): """Deserialize a stream of bytes into a bytes object.""" diff --git a/src/sagemaker/predictor.py b/src/sagemaker/predictor.py index 660f375d6c..fad236ea0a 100644 --- a/src/sagemaker/predictor.py +++ b/src/sagemaker/predictor.py @@ -623,35 +623,6 @@ def __call__(self, stream, content_type): csv_deserializer = _CsvDeserializer() -class StringDeserializer(object): - """Return the response as a decoded string. - - Args: - encoding (str): The string encoding to use (default=utf-8). - accept (str): The Accept header to send to the server (optional). - """ - - def __init__(self, encoding="utf-8", accept=None): - """ - Args: - encoding: - accept: - """ - self.encoding = encoding - self.accept = accept - - def __call__(self, stream, content_type): - """ - Args: - stream: - content_type: - """ - try: - return stream.read().decode(self.encoding) - finally: - stream.close() - - class StreamDeserializer(object): """Returns the tuple of the response stream and the content-type of the response. It is the receivers responsibility to close the stream when they're done diff --git a/tests/integ/test_multidatamodel.py b/tests/integ/test_multidatamodel.py index 1366107365..5e29b4607b 100644 --- a/tests/integ/test_multidatamodel.py +++ b/tests/integ/test_multidatamodel.py @@ -24,9 +24,10 @@ from sagemaker import utils from sagemaker.amazon.randomcutforest import RandomCutForest +from sagemaker.deserializers import StringDeserializer from sagemaker.multidatamodel import MultiDataModel from sagemaker.mxnet import MXNet -from sagemaker.predictor import Predictor, StringDeserializer, npy_serializer +from sagemaker.predictor import Predictor, npy_serializer from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES from tests.integ.retry import retries diff --git a/tests/unit/sagemaker/test_deserializers.py b/tests/unit/sagemaker/test_deserializers.py index 7b3bbf6f40..e4e3149b7a 100644 --- a/tests/unit/sagemaker/test_deserializers.py +++ b/tests/unit/sagemaker/test_deserializers.py @@ -14,7 +14,15 @@ import io -from sagemaker.deserializers import BytesDeserializer +from sagemaker.deserializers import StringDeserializer, BytesDeserializer + + +def test_string_deserializer(): + deserializer = StringDeserializer() + + result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json") + + assert result == "[1, 2, 3]" def test_bytes_deserializer(): diff --git a/tests/unit/test_predictor.py b/tests/unit/test_predictor.py index 3417523af9..648758bc3e 100644 --- a/tests/unit/test_predictor.py +++ b/tests/unit/test_predictor.py @@ -26,7 +26,6 @@ json_deserializer, csv_serializer, csv_deserializer, - StringDeserializer, StreamDeserializer, numpy_deserializer, npy_serializer, @@ -183,12 +182,6 @@ def test_json_deserializer_invalid_data(): assert "column" in str(error) -def test_string_deserializer(): - result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") - - assert result == "[1, 2, 3]" - - def test_stream_deserializer(): stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json") result = stream.read()