Skip to content

Commit 4a17b9a

Browse files
authored
Merge branch 'zwei' into add-string-deserializer
2 parents 1dc5157 + 9a0f8ac commit 4a17b9a

File tree

5 files changed

+35
-37
lines changed

5 files changed

+35
-37
lines changed

src/sagemaker/deserializers.py

+21
Original file line numberDiff line numberDiff line change
@@ -68,3 +68,24 @@ def deserialize(self, data, content_type):
6868
return data.read().decode(self.encoding)
6969
finally:
7070
data.close()
71+
72+
73+
class BytesDeserializer(BaseDeserializer):
74+
"""Deserialize a stream of bytes into a bytes object."""
75+
76+
ACCEPT = "application/octet-stream"
77+
78+
def deserialize(self, data, content_type):
79+
"""Read a stream of bytes returned from an inference endpoint.
80+
81+
Args:
82+
data (object): A stream of bytes.
83+
content_type (str): The MIME type of the data.
84+
85+
Returns:
86+
bytes: The bytes object read from the stream.
87+
"""
88+
try:
89+
return data.read()
90+
finally:
91+
data.close()

src/sagemaker/predictor.py

-26
Original file line numberDiff line numberDiff line change
@@ -623,32 +623,6 @@ def __call__(self, stream, content_type):
623623
csv_deserializer = _CsvDeserializer()
624624

625625

626-
class BytesDeserializer(object):
627-
"""Return the response as an undecoded array of bytes.
628-
629-
Args:
630-
accept (str): The Accept header to send to the server (optional).
631-
"""
632-
633-
def __init__(self, accept=None):
634-
"""
635-
Args:
636-
accept:
637-
"""
638-
self.accept = accept
639-
640-
def __call__(self, stream, content_type):
641-
"""
642-
Args:
643-
stream:
644-
content_type:
645-
"""
646-
try:
647-
return stream.read()
648-
finally:
649-
stream.close()
650-
651-
652626
class StreamDeserializer(object):
653627
"""Returns the tuple of the response stream and the content-type of the response.
654628
It is the receivers responsibility to close the stream when they're done

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_predictors.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -117,9 +117,11 @@ def test_import_from_node_should_be_modified_random_import():
117117
def test_import_from_modify_node():
118118
modifier = predictors.PredictorImportFromRenamer()
119119

120-
node = ast_import("from sagemaker.predictor import BytesDeserializer, RealTimePredictor")
120+
node = ast_import(
121+
"from sagemaker.predictor import ClassThatHasntBeenRenamed, RealTimePredictor"
122+
)
121123
modifier.modify_node(node)
122-
expected_result = "from sagemaker.predictor import BytesDeserializer, Predictor"
124+
expected_result = "from sagemaker.predictor import ClassThatHasntBeenRenamed, Predictor"
123125
assert expected_result == pasta.dump(node)
124126

125127
node = ast_import("from sagemaker.predictor import RealTimePredictor as RTP")

tests/unit/sagemaker/test_deserializers.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import io
1616

17-
from sagemaker.deserializers import StringDeserializer
17+
from sagemaker.deserializers import StringDeserializer, BytesDeserializer
1818

1919

2020
def test_string_deserializer():
@@ -23,3 +23,11 @@ def test_string_deserializer():
2323
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
2424

2525
assert result == "[1, 2, 3]"
26+
27+
28+
def test_bytes_deserializer():
29+
deserializer = BytesDeserializer()
30+
31+
result = deserializer.deserialize(io.BytesIO(b"[1, 2, 3]"), "application/json")
32+
33+
assert result == b"[1, 2, 3]"

tests/unit/test_predictor.py

+1-8
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
json_deserializer,
2727
csv_serializer,
2828
csv_deserializer,
29-
BytesDeserializer,
3029
StreamDeserializer,
3130
numpy_deserializer,
3231
npy_serializer,
@@ -181,13 +180,7 @@ def test_json_deserializer_invalid_data():
181180
with pytest.raises(ValueError) as error:
182181
json_deserializer(io.BytesIO(b"[[1]"), "application/json")
183182
assert "column" in str(error)
184-
185-
186-
def test_bytes_deserializer():
187-
result = BytesDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")
188-
189-
assert result == b"[1, 2, 3]"
190-
183+
191184

192185
def test_stream_deserializer():
193186
stream, content_type = StreamDeserializer()(io.BytesIO(b"[1, 2, 3]"), "application/json")

0 commit comments

Comments
 (0)