Skip to content

Commit 9a0f8ac

Browse files
authored
breaking: Add BytesDeserializer (#1674)
1 parent e10b29b commit 9a0f8ac

File tree

5 files changed

+50
-35
lines changed

5 files changed

+50
-35
lines changed

src/sagemaker/deserializers.py

+21
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,24 @@ def deserialize(self, data, content_type):
3939
@abc.abstractmethod
4040
def ACCEPT(self):
4141
"""The content type that is expected from the inference endpoint."""
42+
43+
44+
class BytesDeserializer(BaseDeserializer):
45+
"""Deserialize a stream of bytes into a bytes object."""
46+
47+
ACCEPT = "application/octet-stream"
48+
49+
def deserialize(self, data, content_type):
50+
"""Read a stream of bytes returned from an inference endpoint.
51+
52+
Args:
53+
data (object): A stream of bytes.
54+
content_type (str): The MIME type of the data.
55+
56+
Returns:
57+
bytes: The bytes object read from the stream.
58+
"""
59+
try:
60+
return data.read()
61+
finally:
62+
data.close()

src/sagemaker/predictor.py

-26
Original file line numberDiff line numberDiff line change
@@ -623,32 +623,6 @@ def __call__(self, stream, content_type):
623623
csv_deserializer = _CsvDeserializer()
624624

625625

626-
class BytesDeserializer(object):
627-
"""Return the response as an undecoded array of bytes.
628-
629-
Args:
630-
accept (str): The Accept header to send to the server (optional).
631-
"""
632-
633-
def __init__(self, accept=None):
634-
"""
635-
Args:
636-
accept:
637-
"""
638-
self.accept = accept
639-
640-
def __call__(self, stream, content_type):
641-
"""
642-
Args:
643-
stream:
644-
content_type:
645-
"""
646-
try:
647-
return stream.read()
648-
finally:
649-
stream.close()
650-
651-
652626
class StringDeserializer(object):
653627
"""Return the response as a decoded string.
654628

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ def test_import_from_node_should_be_modified_random_import():
117117
def test_import_from_modify_node():
118118
modifier = predictors.PredictorImportFromRenamer()
119119

120-
node = ast_import("from sagemaker.predictor import BytesDeserializer, RealTimePredictor")
120+
node = ast_import(
121+
"from sagemaker.predictor import ClassThatHasntBeenRenamed, RealTimePredictor"
122+
)
121123
modifier.modify_node(node)
122-
expected_result = "from sagemaker.predictor import BytesDeserializer, Predictor"
124+
expected_result = "from sagemaker.predictor import ClassThatHasntBeenRenamed, Predictor"
123125
assert expected_result == pasta.dump(node)
124126

125127
node = ast_import("from sagemaker.predictor import RealTimePredictor as RTP")
+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import io
16+
17+
from sagemaker.deserializers import BytesDeserializer
18+
19+
20+
def test_bytes_deserializer():
21+
deserializer = BytesDeserializer()
22+
23+
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
24+
25+
assert result == b"[1, 2, 3]"

tests/unit/test_predictor.py

-7
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
json_deserializer,
2727
csv_serializer,
2828
csv_deserializer,
29-
BytesDeserializer,
3029
StringDeserializer,
3130
StreamDeserializer,
3231
numpy_deserializer,
@@ -184,12 +183,6 @@ def test_json_deserializer_invalid_data():
184183
assert "column" in str(error)
185184

186185

187-
def test_bytes_deserializer():
188-
result = BytesDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
189-
190-
assert result == b"[1, 2, 3]"
191-
192-
193186
def test_string_deserializer():
194187
result = StringDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
195188

0 commit comments

Comments
 (0)