Skip to content

breaking: Move _JsonSerializer to sagemaker.serializers.JSONSerializer #1698

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions doc/frameworks/tensorflow/upgrade_from_legacy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,10 @@ For example, if you want to use JSON serialization and deserialization:

.. code:: python

from sagemaker.predictor import json_deserializer, json_serializer
from sagemaker.predictor import json_deserializer
from sagemaker.serializers import JSONSerializer

predictor.content_type = "application/json"
predictor.serializer = json_serializer
predictor.serializer = JSONSerializer()
predictor.accept = "application/json"
predictor.deserializer = json_deserializer

Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
)
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet import defaults
from sagemaker.predictor import Predictor, json_serializer, json_deserializer
from sagemaker.predictor import Predictor, json_deserializer
from sagemaker.serializers import JSONSerializer

logger = logging.getLogger("sagemaker")

Expand All @@ -50,7 +51,7 @@ def __init__(self, endpoint_name, sagemaker_session=None):
using the default AWS configuration chain.
"""
super(MXNetPredictor, self).__init__(
endpoint_name, sagemaker_session, json_serializer, json_deserializer
endpoint_name, sagemaker_session, JSONSerializer(), json_deserializer
)


Expand Down
50 changes: 0 additions & 50 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import codecs
import csv
import json
import six
from six import StringIO, BytesIO
import numpy as np

Expand Down Expand Up @@ -597,55 +596,6 @@ def _row_to_csv(obj):
return ",".join(obj)


class _JsonSerializer(object):
"""Placeholder docstring"""

def __init__(self):
"""Placeholder docstring"""
self.content_type = CONTENT_TYPE_JSON

def __call__(self, data):
"""Take data of various formats and serialize them into the expected
request body. This uses information about supported input formats for
the deployed model.

Args:
data (object): Data to be serialized.

Returns:
object: Serialized data used for the request.
"""
if isinstance(data, dict):
# convert each value in dict from a numpy array to a list if necessary, so they can be
# json serialized
return json.dumps({k: _ndarray_to_list(v) for k, v in six.iteritems(data)})

# files and buffers
if hasattr(data, "read"):
return _json_serialize_from_buffer(data)

return json.dumps(_ndarray_to_list(data))


json_serializer = _JsonSerializer()


def _ndarray_to_list(data):
"""
Args:
data:
"""
return data.tolist() if isinstance(data, np.ndarray) else data


def _json_serialize_from_buffer(buff):
"""
Args:
buff:
"""
return buff.read()


class _JsonDeserializer(object):
"""Placeholder docstring"""

Expand Down
34 changes: 34 additions & 0 deletions src/sagemaker/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from __future__ import absolute_import

import abc
import json

import numpy as np


class BaseSerializer(abc.ABC):
Expand All @@ -38,3 +41,34 @@ def serialize(self, data):
@abc.abstractmethod
def CONTENT_TYPE(self):
"""The MIME type of the data sent to the inference endpoint."""


class JSONSerializer(BaseSerializer):
"""Serialize data to a JSON formatted string."""

CONTENT_TYPE = "application/json"

def serialize(self, data):
"""Serialize data of various formats to a JSON formatted string.

Args:
data (object): Data to be serialized.

Returns:
str: The data serialized as a JSON string.
"""
if isinstance(data, dict):
return json.dumps(
{
key: value.tolist() if isinstance(value, np.ndarray) else value
for key, value in data.items()
}
)

if hasattr(data, "read"):
return data.read()

if isinstance(data, np.ndarray):
return json.dumps(data.tolist())

return json.dumps(data)
5 changes: 3 additions & 2 deletions src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import sagemaker
from sagemaker.content_types import CONTENT_TYPE_JSON
from sagemaker.fw_utils import create_image_uri
from sagemaker.predictor import json_serializer, json_deserializer, Predictor
from sagemaker.predictor import json_deserializer, Predictor
from sagemaker.serializers import JSONSerializer


class TensorFlowPredictor(Predictor):
Expand All @@ -30,7 +31,7 @@ def __init__(
self,
endpoint_name,
sagemaker_session=None,
serializer=json_serializer,
serializer=JSONSerializer(),
deserializer=json_deserializer,
content_type=None,
model_name=None,
Expand Down
5 changes: 3 additions & 2 deletions tests/integ/test_inference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
from sagemaker.content_types import CONTENT_TYPE_CSV
from sagemaker.model import Model
from sagemaker.pipeline import PipelineModel
from sagemaker.predictor import Predictor, json_serializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
from sagemaker.sparkml.model import SparkMLModel
from sagemaker.utils import sagemaker_timestamp

Expand Down Expand Up @@ -128,7 +129,7 @@ def test_inference_pipeline_model_deploy(sagemaker_session, cpu_instance_type):
predictor = Predictor(
endpoint_name=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=json_serializer,
serializer=JSONSerializer,
content_type=CONTENT_TYPE_CSV,
accept=CONTENT_TYPE_CSV,
)
Expand Down
4 changes: 2 additions & 2 deletions tests/integ/test_multidatamodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_multi_data_model_deploy_trained_model_from_framework_estimator(
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
Expand Down Expand Up @@ -391,7 +391,7 @@ def test_multi_data_model_deploy_train_model_from_amazon_first_party_estimator(
assert PRETRAINED_MODEL_PATH_2 in endpoint_models

# Define a predictor to set `serializer` parameter with npy_serializer
# instead of `json_serializer` in the default predictor returned by `MXNetPredictor`
# instead of `JSONSerializer` in the default predictor returned by `MXNetPredictor`
# Since we are using a placeholder container image the prediction results are not accurate.
predictor = Predictor(
endpoint_name=endpoint_name,
Expand Down
74 changes: 74 additions & 0 deletions tests/unit/sagemaker/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os

import numpy as np
import pytest

from sagemaker.serializers import JSONSerializer
from tests.unit import DATA_DIR


@pytest.fixture
def json_serializer():
return JSONSerializer()


def test_json_serializer_numpy_valid(json_serializer):
result = json_serializer.serialize(np.array([1, 2, 3]))

assert result == "[1, 2, 3]"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I didn't really notice this before because I know you're just copying the tests over, but I believe the expected result is actually supposed to go first in the comparison. (The order is for the error message when the assertion fails.)



def test_json_serializer_numpy_valid_2dimensional(json_serializer):
result = json_serializer.serialize(np.array([[1, 2, 3], [3, 4, 5]]))

assert result == "[[1, 2, 3], [3, 4, 5]]"


def test_json_serializer_empty(json_serializer):
assert json_serializer.serialize(np.array([])) == "[]"


def test_json_serializer_python_array(json_serializer):
result = json_serializer.serialize([1, 2, 3])

assert result == "[1, 2, 3]"


def test_json_serializer_python_dictionary(json_serializer):
d = {"gender": "m", "age": 22, "city": "Paris"}

result = json_serializer.serialize(d)

assert json.loads(result) == d


def test_json_serializer_python_invalid_empty(json_serializer):
assert json_serializer.serialize([]) == "[]"


def test_json_serializer_python_dictionary_invalid_empty(json_serializer):
assert json_serializer.serialize({}) == "{}"


def test_json_serializer_csv_buffer(json_serializer):
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
with open(csv_file_path) as csv_file:
validation_value = csv_file.read()
csv_file.seek(0)
result = json_serializer.serialize(csv_file)
assert result == validation_value
51 changes: 2 additions & 49 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,63 +22,16 @@

from sagemaker.predictor import Predictor
from sagemaker.predictor import (
json_serializer,
json_deserializer,
csv_serializer,
npy_serializer,
)
from sagemaker.serializers import JSONSerializer
from tests.unit import DATA_DIR

# testing serialization functions


def test_json_serializer_numpy_valid():
result = json_serializer(np.array([1, 2, 3]))

assert result == "[1, 2, 3]"


def test_json_serializer_numpy_valid_2dimensional():
result = json_serializer(np.array([[1, 2, 3], [3, 4, 5]]))

assert result == "[[1, 2, 3], [3, 4, 5]]"


def test_json_serializer_empty():
assert json_serializer(np.array([])) == "[]"


def test_json_serializer_python_array():
result = json_serializer([1, 2, 3])

assert result == "[1, 2, 3]"


def test_json_serializer_python_dictionary():
d = {"gender": "m", "age": 22, "city": "Paris"}

result = json_serializer(d)

assert json.loads(result) == d


def test_json_serializer_python_invalid_empty():
assert json_serializer([]) == "[]"


def test_json_serializer_python_dictionary_invalid_empty():
assert json_serializer({}) == "{}"


def test_json_serializer_csv_buffer():
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
with open(csv_file_path) as csv_file:
validation_value = csv_file.read()
csv_file.seek(0)
result = json_serializer(csv_file)
assert result == validation_value


def test_csv_serializer_str():
original = "1,2,3"
result = csv_serializer("1,2,3")
Expand Down Expand Up @@ -388,7 +341,7 @@ def test_predict_call_with_headers_and_json():
sagemaker_session,
content_type="not/json",
accept="also/not-json",
serializer=json_serializer,
serializer=JSONSerializer(),
)

data = [1, 2]
Expand Down