Skip to content

Commit 6094028

Browse files
author
Balaji Veeramani
committed
Merge branch 'add-record-serializer' of github.com:bveeramani/sagemaker-python-sdk into add-record-serializer
2 parents 1592dc7 + 30ec172 commit 6094028

File tree

7 files changed

+128
-109
lines changed

7 files changed

+128
-109
lines changed

src/sagemaker/chainer/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
)
2525
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2626
from sagemaker.chainer import defaults
27-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
27+
from sagemaker.deserializers import NumpyDeserializer
28+
from sagemaker.predictor import Predictor, npy_serializer
2829

2930
logger = logging.getLogger("sagemaker")
3031

@@ -48,7 +49,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4849
using the default AWS configuration chain.
4950
"""
5051
super(ChainerPredictor, self).__init__(
51-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
52+
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
5253
)
5354

5455

src/sagemaker/deserializers.py

+43
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import codecs
18+
import io
19+
import json
20+
21+
import numpy as np
1722

1823

1924
class BaseDeserializer(abc.ABC):
@@ -111,3 +116,41 @@ def deserialize(self, data, content_type):
111116
tuple: A two-tuple containing the stream and content-type.
112117
"""
113118
return data, content_type
119+
120+
121+
class NumpyDeserializer(BaseDeserializer):
122+
"""Deserialize a stream of data in the .npy format."""
123+
124+
ACCEPT = "application/x-npy"
125+
126+
def __init__(self, dtype=None):
127+
"""Initialize the dtype.
128+
129+
Args:
130+
dtype (str): The dtype of the data.
131+
"""
132+
self.dtype = dtype
133+
134+
def deserialize(self, data, content_type):
135+
"""Deserialize data from an inference endpoint into a NumPy array.
136+
137+
Args:
138+
data (botocore.response.StreamingBody): Data to be deserialized.
139+
content_type (str): The MIME type of the data.
140+
141+
Returns:
142+
numpy.ndarray: The data deserialized into a NumPy array.
143+
"""
144+
try:
145+
if content_type == "text/csv":
146+
return np.genfromtxt(
147+
codecs.getreader("utf-8")(data), delimiter=",", dtype=self.dtype
148+
)
149+
if content_type == "application/json":
150+
return np.array(json.load(codecs.getreader("utf-8")(data)), dtype=self.dtype)
151+
if content_type == "application/x-npy":
152+
return np.load(io.BytesIO(data.read()))
153+
finally:
154+
data.close()
155+
156+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

src/sagemaker/predictor.py

-44
Original file line numberDiff line numberDiff line change
@@ -698,50 +698,6 @@ def __call__(self, stream, content_type):
698698
json_deserializer = _JsonDeserializer()
699699

700700

701-
class _NumpyDeserializer(object):
702-
"""Placeholder docstring"""
703-
704-
def __init__(self, accept=CONTENT_TYPE_NPY, dtype=None):
705-
"""
706-
Args:
707-
accept:
708-
dtype:
709-
"""
710-
self.accept = accept
711-
self.dtype = dtype
712-
713-
def __call__(self, stream, content_type=CONTENT_TYPE_NPY):
714-
"""Decode from serialized data into a Numpy array.
715-
716-
Args:
717-
stream (stream): The response stream to be deserialized.
718-
content_type (str): The content type of the response. Can accept
719-
CSV, JSON, or NPY data.
720-
721-
Returns:
722-
object: Body of the response deserialized into a Numpy array.
723-
"""
724-
try:
725-
if content_type == CONTENT_TYPE_CSV:
726-
return np.genfromtxt(
727-
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
728-
)
729-
if content_type == CONTENT_TYPE_JSON:
730-
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
731-
if content_type == CONTENT_TYPE_NPY:
732-
return np.load(BytesIO(stream.read()))
733-
finally:
734-
stream.close()
735-
raise ValueError(
736-
"content_type must be one of the following: CSV, JSON, NPY. content_type: {}".format(
737-
content_type
738-
)
739-
)
740-
741-
742-
numpy_deserializer = _NumpyDeserializer()
743-
744-
745701
class _NPYSerializer(object):
746702
"""Placeholder docstring"""
747703

src/sagemaker/pytorch/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import packaging.version
1818

1919
import sagemaker
20+
from sagemaker.deserializers import NumpyDeserializer
2021
from sagemaker.fw_utils import (
2122
create_image_uri,
2223
model_code_key_prefix,
@@ -25,7 +26,7 @@
2526
)
2627
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2728
from sagemaker.pytorch import defaults
28-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
29+
from sagemaker.predictor import Predictor, npy_serializer
2930

3031
logger = logging.getLogger("sagemaker")
3132

@@ -49,7 +50,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4950
using the default AWS configuration chain.
5051
"""
5152
super(PyTorchPredictor, self).__init__(
52-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
53+
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
5354
)
5455

5556

src/sagemaker/sklearn/model.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616
import logging
1717

1818
import sagemaker
19+
from sagemaker.deserializers import NumpyDeserializer
1920
from sagemaker.fw_registry import default_framework_uri
2021
from sagemaker.fw_utils import model_code_key_prefix, validate_version_or_image_args
2122
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
22-
from sagemaker.predictor import Predictor, npy_serializer, numpy_deserializer
23+
from sagemaker.predictor import Predictor, npy_serializer
2324
from sagemaker.sklearn import defaults
2425

2526
logger = logging.getLogger("sagemaker")
@@ -44,7 +45,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4445
using the default AWS configuration chain.
4546
"""
4647
super(SKLearnPredictor, self).__init__(
47-
endpoint_name, sagemaker_session, npy_serializer, numpy_deserializer
48+
endpoint_name, sagemaker_session, npy_serializer, NumpyDeserializer()
4849
)
4950

5051

tests/unit/sagemaker/test_deserializers.py

+76-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,15 @@
1414

1515
import io
1616

17-
from sagemaker.deserializers import StringDeserializer, BytesDeserializer, StreamDeserializer
17+
import numpy as np
18+
import pytest
19+
20+
from sagemaker.deserializers import (
21+
StringDeserializer,
22+
BytesDeserializer,
23+
StreamDeserializer,
24+
NumpyDeserializer,
25+
)
1826

1927

2028
def test_string_deserializer():
@@ -44,3 +52,70 @@ def test_stream_deserializer():
4452

4553
assert result == b"[1, 2, 3]"
4654
assert content_type == "application/json"
55+
56+
57+
@pytest.fixture
58+
def numpy_deserializer():
59+
return NumpyDeserializer()
60+
61+
62+
def test_numpy_deserializer_from_csv(numpy_deserializer):
63+
stream = io.BytesIO(b"1,2,3\n4,5,6")
64+
array = numpy_deserializer.deserialize(stream, "text/csv")
65+
assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6]]))
66+
67+
68+
def test_numpy_deserializer_from_csv_ragged(numpy_deserializer):
69+
stream = io.BytesIO(b"1,2,3\n4,5,6,7")
70+
with pytest.raises(ValueError) as error:
71+
numpy_deserializer.deserialize(stream, "text/csv")
72+
assert "errors were detected" in str(error)
73+
74+
75+
def test_numpy_deserializer_from_csv_alpha():
76+
numpy_deserializer = NumpyDeserializer(dtype="U5")
77+
stream = io.BytesIO(b"hello,2,3\n4,5,6")
78+
array = numpy_deserializer.deserialize(stream, "text/csv")
79+
assert np.array_equal(array, np.array([["hello", 2, 3], [4, 5, 6]]))
80+
81+
82+
def test_numpy_deserializer_from_json(numpy_deserializer):
83+
stream = io.BytesIO(b"[[1,2,3],\n[4,5,6]]")
84+
array = numpy_deserializer.deserialize(stream, "application/json")
85+
assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6]]))
86+
87+
88+
# Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists)
89+
def test_numpy_deserializer_from_json_ragged(numpy_deserializer):
90+
stream = io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]")
91+
array = numpy_deserializer.deserialize(stream, "application/json")
92+
assert np.array_equal(array, np.array([[1, 2, 3], [4, 5, 6, 7]]))
93+
94+
95+
def test_numpy_deserializer_from_json_alpha():
96+
numpy_deserializer = NumpyDeserializer(dtype="U5")
97+
stream = io.BytesIO(b'[["hello",2,3],\n[4,5,6]]')
98+
array = numpy_deserializer.deserialize(stream, "application/json")
99+
assert np.array_equal(array, np.array([["hello", 2, 3], [4, 5, 6]]))
100+
101+
102+
def test_numpy_deserializer_from_npy(numpy_deserializer):
103+
array = np.ones((2, 3))
104+
stream = io.BytesIO()
105+
np.save(stream, array)
106+
stream.seek(0)
107+
108+
result = numpy_deserializer.deserialize(stream, "application/x-npy")
109+
110+
assert np.array_equal(array, result)
111+
112+
113+
def test_numpy_deserializer_from_npy_object_array(numpy_deserializer):
114+
array = np.array(["one", "two"])
115+
stream = io.BytesIO()
116+
np.save(stream, array)
117+
stream.seek(0)
118+
119+
result = numpy_deserializer.deserialize(stream, "application/x-npy")
120+
121+
assert np.array_equal(array, result)

tests/unit/test_predictor.py

-58
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,7 @@
2626
json_deserializer,
2727
csv_serializer,
2828
csv_deserializer,
29-
numpy_deserializer,
3029
npy_serializer,
31-
_NumpyDeserializer,
3230
)
3331
from tests.unit import DATA_DIR
3432

@@ -259,62 +257,6 @@ def test_npy_serializer_python_invalid_empty():
259257
assert "empty array" in str(error)
260258

261259

262-
def test_numpy_deser_from_csv():
263-
arr = numpy_deserializer(io.BytesIO(b"1,2,3\n4,5,6"), "text/csv")
264-
assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6]]))
265-
266-
267-
def test_numpy_deser_from_csv_ragged():
268-
with pytest.raises(ValueError) as error:
269-
numpy_deserializer(io.BytesIO(b"1,2,3\n4,5,6,7"), "text/csv")
270-
assert "errors were detected" in str(error)
271-
272-
273-
def test_numpy_deser_from_csv_alpha():
274-
arr = _NumpyDeserializer(dtype="U5")(io.BytesIO(b"hello,2,3\n4,5,6"), "text/csv")
275-
assert np.array_equal(arr, np.array([["hello", 2, 3], [4, 5, 6]]))
276-
277-
278-
def test_numpy_deser_from_json():
279-
arr = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6]]"), "application/json")
280-
assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6]]))
281-
282-
283-
# Sadly, ragged arrays work fine in JSON (giving us a 1D array of Python lists
284-
def test_numpy_deser_from_json_ragged():
285-
arr = numpy_deserializer(io.BytesIO(b"[[1,2,3],\n[4,5,6,7]]"), "application/json")
286-
assert np.array_equal(arr, np.array([[1, 2, 3], [4, 5, 6, 7]]))
287-
288-
289-
def test_numpy_deser_from_json_alpha():
290-
arr = _NumpyDeserializer(dtype="U5")(
291-
io.BytesIO(b'[["hello",2,3],\n[4,5,6]]'), "application/json"
292-
)
293-
assert np.array_equal(arr, np.array([["hello", 2, 3], [4, 5, 6]]))
294-
295-
296-
def test_numpy_deser_from_npy():
297-
array = np.ones((2, 3))
298-
stream = io.BytesIO()
299-
np.save(stream, array)
300-
stream.seek(0)
301-
302-
result = numpy_deserializer(stream)
303-
304-
assert np.array_equal(array, result)
305-
306-
307-
def test_numpy_deser_from_npy_object_array():
308-
array = np.array(["one", "two"])
309-
stream = io.BytesIO()
310-
np.save(stream, array)
311-
stream.seek(0)
312-
313-
result = numpy_deserializer(stream)
314-
315-
assert np.array_equal(array, result)
316-
317-
318260
# testing 'predict' invocations
319261

320262

0 commit comments

Comments
 (0)