Skip to content

Commit 90b7c01

Browse files
author
Balaji Veeramani
committed
Add JSON serializer
1 parent b837dc2 commit 90b7c01

File tree

9 files changed

+124
-110
lines changed

9 files changed

+124
-110
lines changed

doc/frameworks/tensorflow/upgrade_from_legacy.rst

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,10 @@ For example, if you want to use JSON serialization and deserialization:
245245

246246
.. code:: python
247247
248-
from sagemaker.predictor import json_deserializer, json_serializer
248+
from sagemaker.predictor import json_deserializer
249+
from sagemaker.serializers import JSONSerializer
249250
250-
predictor.content_type = "application/json"
251-
predictor.serializer = json_serializer
251+
predictor.serializer = JSONSerializer()
252252
predictor.accept = "application/json"
253253
predictor.deserializer = json_deserializer
254254

src/sagemaker/mxnet/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
)
2727
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
2828
from sagemaker.mxnet import defaults
29-
from sagemaker.predictor import Predictor, json_serializer, json_deserializer
29+
from sagemaker.predictor import Predictor, json_deserializer
30+
from sagemaker.serializers import JSONSerializer
3031

3132
logger = logging.getLogger("sagemaker")
3233

@@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5051
using the default AWS configuration chain.
5152
"""
5253
super(MXNetPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, json_serializer, json_deserializer
54+
endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer
5455
)
5556

5657

src/sagemaker/predictor.py

Lines changed: 0 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import codecs
1717
import csv
1818
import json
19-
import six
2019
from six import StringIO, BytesIO
2120
import numpy as np
2221

@@ -623,55 +622,6 @@ def __call__(self, stream, content_type):
623622
csv_deserializer = _CsvDeserializer()
624623

625624

626-
class _JsonSerializer(object):
627-
"""Placeholder docstring"""
628-
629-
def __init__(self):
630-
"""Placeholder docstring"""
631-
self.content_type = CONTENT_TYPE_JSON
632-
633-
def __call__(self, data):
634-
"""Take data of various formats and serialize them into the expected
635-
request body. This uses information about supported input formats for
636-
the deployed model.
637-
638-
Args:
639-
data (object): Data to be serialized.
640-
641-
Returns:
642-
object: Serialized data used for the request.
643-
"""
644-
if isinstance(data, dict):
645-
# convert each value in dict from a numpy array to a list if necessary, so they can be
646-
# json serialized
647-
return json.dumps({k: _ndarray_to_list(v) for k, v in six.iteritems(data)})
648-
649-
# files and buffers
650-
if hasattr(data, "read"):
651-
return _json_serialize_from_buffer(data)
652-
653-
return json.dumps(_ndarray_to_list(data))
654-
655-
656-
json_serializer = _JsonSerializer()
657-
658-
659-
def _ndarray_to_list(data):
660-
"""
661-
Args:
662-
data:
663-
"""
664-
return data.tolist() if isinstance(data, np.ndarray) else data
665-
666-
667-
def _json_serialize_from_buffer(buff):
668-
"""
669-
Args:
670-
buff:
671-
"""
672-
return buff.read()
673-
674-
675625
class _JsonDeserializer(object):
676626
"""Placeholder docstring"""
677627

src/sagemaker/serializers.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
from __future__ import absolute_import
1515

1616
import abc
17+
import json
18+
19+
import numpy as np
1720

1821

1922
class BaseSerializer(abc.ABC):
@@ -38,3 +41,34 @@ def serialize(self, data):
3841
@abc.abstractmethod
3942
def CONTENT_TYPE(self):
4043
"""The MIME type of the data sent to the inference endpoint."""
44+
45+
46+
class JSONSerializer(BaseSerializer):
47+
"""Serialize data to a JSON formatted string."""
48+
49+
CONTENT_TYPE = "application/json"
50+
51+
def serialize(self, data):
52+
"""Serialize data of various formats to a JSON formatted string.
53+
54+
Args:
55+
data (object): Data to be serialized.
56+
57+
Returns:
58+
str: The data serialized as a JSON string.
59+
"""
60+
if isinstance(data, dict):
61+
return json.dumps(
62+
{
63+
key: value.tolist() if isinstance(value, np.ndarray) else value
64+
for key, value in data.items()
65+
}
66+
)
67+
68+
if hasattr(data, "read"):
69+
return data.read()
70+
71+
if isinstance(data, np.ndarray):
72+
return json.dumps(data.tolist())
73+
74+
return json.dumps(data)

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@
1818
import sagemaker
1919
from sagemaker.content_types import CONTENT_TYPE_JSON
2020
from sagemaker.fw_utils import create_image_uri
21-
from sagemaker.predictor import json_serializer, json_deserializer, Predictor
21+
from sagemaker.predictor import json_deserializer, Predictor
22+
from sagemaker.serializers import JSONSerializer
2223

2324

2425
class TensorFlowPredictor(Predictor):
@@ -30,7 +31,7 @@ def __init__(
3031
self,
3132
endpoint_name,
3233
sagemaker_session=None,
33-
serializer=json_serializer,
34+
serializer=JSONSerializer(),
3435
deserializer=json_deserializer,
3536
content_type=None,
3637
model_name=None,

tests/integ/test_inference_pipeline.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
from sagemaker.content_types import CONTENT_TYPE_CSV
2727
from sagemaker.model import Model
2828
from sagemaker.pipeline import PipelineModel
29-
from sagemaker.predictor import Predictor, json_serializer
29+
from sagemaker.predictor import Predictor
30+
from sagemaker.serializers import JSONSerializer
3031
from sagemaker.sparkml.model import SparkMLModel
3132
from sagemaker.utils import sagemaker_timestamp
3233

@@ -128,7 +129,7 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
128129
predictor = Predictor(
129130
endpoint_name=endpoint_name,
130131
sagemaker_session=sagemaker_session,
131-
serializer=json_serializer,
132+
serializer=JSONSerializer,
132133
content_type=CONTENT_TYPE_CSV,
133134
accept=CONTENT_TYPE_CSV,
134135
)

tests/integ/test_multidatamodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
290290
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
291291

292292
# Define a predictor to set `serializer` parameter with npy_serializer
293-
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
293+
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
294294
# Since we are using a placeholder container image the prediction results are not accurate.
295295
predictor = Predictor(
296296
endpoint_name=endpoint_name,
@@ -391,7 +391,7 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
391391
assert PRETRAINED_MODEL_PATH_2 in endpoint_models
392392

393393
# Define a predictor to set `serializer` parameter with npy_serializer
394-
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
394+
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
395395
# Since we are using a placeholder container image the prediction results are not accurate.
396396
predictor = Predictor(
397397
endpoint_name=endpoint_name,
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
18+
import numpy as np
19+
import pytest
20+
21+
from sagemaker.serializers import JSONSerializer
22+
from tests.unit import DATA_DIR
23+
24+
25+
@pytest.fixture
26+
def json_serializer():
27+
return JSONSerializer()
28+
29+
30+
def test_json_serializer_numpy_valid(json_serializer):
31+
result = json_serializer.serialize(np.array([1, 2, 3]))
32+
33+
assert result == "[1, 2, 3]"
34+
35+
36+
def test_json_serializer_numpy_valid_2dimensional(json_serializer):
37+
result = json_serializer.serialize(np.array([[1, 2, 3], [3, 4, 5]]))
38+
39+
assert result == "[[1, 2, 3], [3, 4, 5]]"
40+
41+
42+
def test_json_serializer_empty(json_serializer):
43+
assert json_serializer.serialize(np.array([])) == "[]"
44+
45+
46+
def test_json_serializer_python_array(json_serializer):
47+
result = json_serializer.serialize([1, 2, 3])
48+
49+
assert result == "[1, 2, 3]"
50+
51+
52+
def test_json_serializer_python_dictionary(json_serializer):
53+
d = {"gender": "m", "age": 22, "city": "Paris"}
54+
55+
result = json_serializer.serialize(d)
56+
57+
assert json.loads(result) == d
58+
59+
60+
def test_json_serializer_python_invalid_empty(json_serializer):
61+
assert json_serializer.serialize([]) == "[]"
62+
63+
64+
def test_json_serializer_python_dictionary_invalid_empty(json_serializer):
65+
assert json_serializer.serialize({}) == "{}"
66+
67+
68+
def test_json_serializer_csv_buffer(json_serializer):
69+
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
70+
with open(csv_file_path) as csv_file:
71+
validation_value = csv_file.read()
72+
csv_file.seek(0)
73+
result = json_serializer.serialize(csv_file)
74+
assert result == validation_value

tests/unit/test_predictor.py

Lines changed: 2 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -22,64 +22,17 @@
2222

2323
from sagemaker.predictor import Predictor
2424
from sagemaker.predictor import (
25-
json_serializer,
2625
json_deserializer,
2726
csv_serializer,
2827
csv_deserializer,
2928
npy_serializer,
3029
)
30+
from sagemaker.serializers import JSONSerializer
3131
from tests.unit import DATA_DIR
3232

3333
# testing serialization functions
3434

3535

36-
def test_json_serializer_numpy_valid():
37-
result = json_serializer(np.array([1, 2, 3]))
38-
39-
assert result == "[1, 2, 3]"
40-
41-
42-
def test_json_serializer_numpy_valid_2dimensional():
43-
result = json_serializer(np.array([[1, 2, 3], [3, 4, 5]]))
44-
45-
assert result == "[[1, 2, 3], [3, 4, 5]]"
46-
47-
48-
def test_json_serializer_empty():
49-
assert json_serializer(np.array([])) == "[]"
50-
51-
52-
def test_json_serializer_python_array():
53-
result = json_serializer([1, 2, 3])
54-
55-
assert result == "[1, 2, 3]"
56-
57-
58-
def test_json_serializer_python_dictionary():
59-
d = {"gender": "m", "age": 22, "city": "Paris"}
60-
61-
result = json_serializer(d)
62-
63-
assert json.loads(result) == d
64-
65-
66-
def test_json_serializer_python_invalid_empty():
67-
assert json_serializer([]) == "[]"
68-
69-
70-
def test_json_serializer_python_dictionary_invalid_empty():
71-
assert json_serializer({}) == "{}"
72-
73-
74-
def test_json_serializer_csv_buffer():
75-
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
76-
with open(csv_file_path) as csv_file:
77-
validation_value = csv_file.read()
78-
csv_file.seek(0)
79-
result = json_serializer(csv_file)
80-
assert result == validation_value
81-
82-
8336
def test_csv_serializer_str():
8437
original = "1,2,3"
8538
result = csv_serializer("1,2,3")
@@ -404,7 +357,7 @@ def test_predict_call_with_headers_and_json():
404357
sagemaker_session,
405358
content_type="not/json",
406359
accept="also/not-json",
407-
serializer=json_serializer,
360+
serializer=JSONSerializer(),
408361
)
409362

410363
data = [1, 2]

0 commit comments

Comments
 (0)