Skip to content

JSON serializer: predictor.predict accepts dictionaries #62

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 15, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/sagemaker/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.

This is able to serialize Python lists and numpy arrays to multidimensional tensors for MXNet inference."""
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
inference."""

def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``MXNetPredictor``.
Expand Down
11 changes: 8 additions & 3 deletions src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,12 @@ def __call__(self, data):
if isinstance(data, list):
if not len(data) > 0:
raise ValueError("empty array can't be serialized")
return _json_serialize_python_array(data)
return _json_serialize_python_object(data)

if isinstance(data, dict):
if not len(data.keys()) > 0:
raise ValueError("empty dictionary can't be serialized")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any special reason for handling the empty dictionary case? An empty dictionary is valid JSON.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Valid JSON, sure, but as with an empty list, what would I be trying to predict without data? I cannot think of a use case where I'd expect my payload to be empty: my prediction response wouldn't depend on my payload. But maybe it's better to be permissive here and with the list. Let me know what you think; I can change this.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I noticed that you are applying the same previously applied pattern to check if it is empty or not. I'm ok with that.

return _json_serialize_python_object(data)

# files and buffers
if hasattr(data, 'read'):
Expand All @@ -254,10 +259,10 @@ def __call__(self, data):

def _json_serialize_numpy_array(data):
# numpy arrays can't be serialized but we know they have uniform type
return _json_serialize_python_array(data.tolist())
return _json_serialize_python_object(data.tolist())


def _json_serialize_python_array(data):
def _json_serialize_python_object(data):
return _json_serialize_object(data)


Expand Down
4 changes: 3 additions & 1 deletion src/sagemaker/tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@


class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against MXNet ``Endpoint``s."""
"""A ``RealTimePredictor`` for inference against TensorFlow ``Endpoint``s.

This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for TensorFlow
inference."""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize a ``TensorFlowPredictor``.

Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/tensorflow/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self):
self.content_type = CONTENT_TYPE_OCTET_STREAM

def __call__(self, data):
# isintance does not work here because a same protobuf message can be imported from a different module.
# isinstance does not work here because the same protobuf message can be imported from a different module.
# for example sagemaker.tensorflow.tensorflow_serving.regression_pb2 and tensorflow_serving.apis.regression_pb2
predict_type = data.__class__.__name__

Expand Down
9 changes: 7 additions & 2 deletions tests/integ/test_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,13 @@ def test_tf(sagemaker_session):
with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
json_predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')

result = json_predictor.predict([6.4, 3.2, 4.5, 1.5])
print('predict result: {}'.format(result))
features = [6.4, 3.2, 4.5, 1.5]
dict_result = json_predictor.predict({'inputs': features})
print('predict result: {}'.format(dict_result))
list_result = json_predictor.predict(features)
print('predict result: {}'.format(list_result))

assert dict_result == list_result


def test_failed_tf_training(sagemaker_session):
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,26 @@ def test_json_serializer_python_array():
assert result == '[1, 2, 3]'


def test_json_serializer_python_dictionary():
d = {"gender": "m", "age": 22, "city": "Paris"}

result = json_serializer(d)

assert json.loads(result) == d


def test_json_serializer_python_invalid_empty():
with pytest.raises(ValueError) as error:
json_serializer([])
assert "empty array" in str(error)


def test_json_serializer_python_dictionary_invalid_empty():
with pytest.raises(ValueError) as error:
json_serializer({})
assert "empty dictionary" in str(error)


def test_json_serializer_csv_buffer():
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
with open(csv_file_path) as csv_file:
Expand Down