Skip to content

Commit 45d71bd

Browse files
author
Balaji Veeramani
committed
Add StringDeserializer
1 parent 1487b22 commit 45d71bd

File tree

7 files changed

+99
-37
lines changed

7 files changed

+99
-37
lines changed

src/sagemaker/deserializers.py

+36
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import abc
1717

18+
from sagemaker.utils import parse_mime_type
19+
1820

1921
class BaseDeserializer(abc.ABC):
2022
"""Abstract base class for creation of new deserializers.
@@ -39,3 +41,37 @@ def deserialize(self, data, content_type):
3941
@abc.abstractmethod
4042
def ACCEPT(self):
4143
"""The content type that is expected from the inference endpoint."""
44+
45+
46+
class StringDeserializer(object):
47+
"""Deserialize data from an inference endpoint into a decoded string."""
48+
49+
def __init__(self, encoding="UTF-8"):
50+
"""Initialize the default encoding.
51+
52+
Args:
53+
encoding (str): The string encoding to use, if a charset is not
54+
provided by the server (default: UTF-8).
55+
"""
56+
self.encoding = encoding
57+
58+
def deserialize(self, data, content_type):
59+
"""Deserialize data from an inference endpoint into a decoded string.
60+
61+
Args:
62+
data (object): A string or a byte stream.
63+
content_type (str): The MIME type of the data.
64+
65+
Returns:
66+
str: The data deserialized into a decoded string.
67+
"""
68+
category, _, parameters = parse_mime_type(content_type)
69+
70+
if category == "text":
71+
return data
72+
73+
try:
74+
encoding = parameters.get("charset", self.encoding)
75+
return data.read().decode(encoding)
76+
finally:
77+
data.close()

src/sagemaker/predictor.py

-29
Original file line numberDiff line numberDiff line change
@@ -649,35 +649,6 @@ def __call__(self, stream, content_type):
649649
stream.close()
650650

651651

652-
class StringDeserializer(object):
653-
"""Return the response as a decoded string.
654-
655-
Args:
656-
encoding (str): The string encoding to use (default=utf-8).
657-
accept (str): The Accept header to send to the server (optional).
658-
"""
659-
660-
def __init__(self, encoding="utf-8", accept=None):
661-
"""
662-
Args:
663-
encoding:
664-
accept:
665-
"""
666-
self.encoding = encoding
667-
self.accept = accept
668-
669-
def __call__(self, stream, content_type):
670-
"""
671-
Args:
672-
stream:
673-
content_type:
674-
"""
675-
try:
676-
return stream.read().decode(self.encoding)
677-
finally:
678-
stream.close()
679-
680-
681652
class StreamDeserializer(object):
682653
"""Returns the tuple of the response stream and the content-type of the response.
683654
It is the receivers responsibility to close the stream when they're done

src/sagemaker/test_deserializers.py

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import io
16+
17+
from sagemaker.deserializers import StringDeserializer
18+
19+
20+
def test_string_deserializer_plain_text():
21+
deserializer = StringDeserializer()
22+
23+
result = deserializer.deserialize("Hello, world!", "text/plain")
24+
25+
assert result == "Hello, world!"
26+
27+
28+
def test_string_deserializer_octet_stream():
29+
deserializer = StringDeserializer()
30+
31+
result = deserializer.deserialize(io.BytesIO(b"Hello, world!"), "application/octet-stream")
32+
33+
assert result == "Hello, world!"

src/sagemaker/utils.py

+20
Original file line numberDiff line numberDiff line change
@@ -745,3 +745,23 @@ def _module_import_error(py_module, feature, extras):
745745
"to install all required dependencies."
746746
)
747747
return error_msg.format(py_module, feature, extras)
748+
749+
750+
def parse_mime_type(mime_type):
751+
"""Parse a MIME type and return the type, subtype, and parameters.
752+
753+
Args:
754+
mime_type (str): A MIME type.
755+
756+
Returns:
757+
tuple: A three-tuple containing the type, subtype, and parameters. The
758+
type and subtype are strings, and the parameters are stored in a
759+
dictionary.
760+
"""
761+
category, remaining = mime_type.split("/")
762+
subtype = remaining.split(";")[0]
763+
parameters = {}
764+
for parameter in remaining.split(";")[1:]:
765+
attribute, value = parameter.split("=")
766+
parameters[attribute] = value
767+
return category, subtype, parameters

tests/integ/test_multidatamodel.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424

2525
from sagemaker import utils
2626
from sagemaker.amazon.randomcutforest import RandomCutForest
27+
from sagemaker.deserializers import StringDeserializer
2728
from sagemaker.multidatamodel import MultiDataModel
2829
from sagemaker.mxnet import MXNet
29-
from sagemaker.predictor import Predictor, StringDeserializer, npy_serializer
30+
from sagemaker.predictor import Predictor, npy_serializer
3031
from sagemaker.utils import sagemaker_timestamp, unique_name_from_base, get_ecr_image_uri_prefix
3132
from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES
3233
from tests.integ.retry import retries

tests/unit/test_predictor.py

-7
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
csv_serializer,
2828
csv_deserializer,
2929
BytesDeserializer,
30-
StringDeserializer,
3130
StreamDeserializer,
3231
numpy_deserializer,
3332
npy_serializer,
@@ -190,12 +189,6 @@ def test_bytes_deserializer():
190189
assert result == b"[1, 2, 3]"
191190

192191

193-
def test_string_deserializer():
194-
result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
195-
196-
assert result == "[1, 2, 3]"
197-
198-
199192
def test_stream_deserializer():
200193
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
201194
result = stream.read()

tests/unit/test_utils.py

+8
Original file line numberDiff line numberDiff line change
@@ -779,3 +779,11 @@ def test_partition_by_region():
779779
assert sagemaker.utils._aws_partition("us-gov-east-1") == "aws-us-gov"
780780
assert sagemaker.utils._aws_partition("us-iso-east-1") == "aws-iso"
781781
assert sagemaker.utils._aws_partition("us-isob-east-1") == "aws-iso-b"
782+
783+
784+
def test_parse_mime_type():
785+
mime_type = "application/octet-stream;charset=UTF-8"
786+
category, subtype, parameters = sagemaker.utils.parse_mime_type(mime_type)
787+
assert category == "application"
788+
assert subtype == "octet-stream"
789+
assert parameters == {"charset": "UTF-8"}

0 commit comments

Comments
 (0)