Skip to content

Commit d47f6d1

Browse files
andremoeller authored and lukmis committed
JSON serializer: predictor.predict accepts dictionaries (#62)
Add support for serializing Python dictionaries to JSON. Add prediction with a dictionary in the TF iris integ test.
1 parent 795b030 commit d47f6d1

File tree

6 files changed

+35
-8
lines changed

6 files changed

+35
-8
lines changed

src/sagemaker/mxnet/model.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
class MXNetPredictor(RealTimePredictor):
2121
"""A RealTimePredictor for inference against MXNet Endpoints.
2222
23-
This is able to serialize Python lists and numpy arrays to multidimensional tensors for MXNet inference."""
23+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
24+
inference."""
2425

2526
def __init__(self, endpoint_name, sagemaker_session=None):
2627
"""Initialize an ``MXNetPredictor``.

src/sagemaker/predictor.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def __call__(self, data):
240240
if isinstance(data, list):
241241
if not len(data) > 0:
242242
raise ValueError("empty array can't be serialized")
243-
return _json_serialize_python_array(data)
243+
return _json_serialize_python_object(data)
244+
245+
if isinstance(data, dict):
246+
if not len(data.keys()) > 0:
247+
raise ValueError("empty dictionary can't be serialized")
248+
return _json_serialize_python_object(data)
244249

245250
# files and buffers
246251
if hasattr(data, 'read'):
@@ -254,10 +259,10 @@ def __call__(self, data):
254259

255260
def _json_serialize_numpy_array(data):
256261
# numpy arrays can't be serialized but we know they have uniform type
257-
return _json_serialize_python_array(data.tolist())
262+
return _json_serialize_python_object(data.tolist())
258263

259264

260-
def _json_serialize_python_array(data):
265+
def _json_serialize_python_object(data):
261266
return _json_serialize_object(data)
262267

263268

src/sagemaker/tensorflow/model.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020

2121
class TensorFlowPredictor(RealTimePredictor):
22-
"""A ``RealTimePredictor`` for inference against MXNet ``Endpoint``s."""
22+
"""A ``RealTimePredictor`` for inference against TensorFlow ``Endpoint``s.
2323
24+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for TensorFlow
25+
inference"""
2426
def __init__(self, endpoint_name, sagemaker_session=None):
2527
"""Initialize an ``TensorFlowPredictor``.
2628

src/sagemaker/tensorflow/predictor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self):
3232
self.content_type = CONTENT_TYPE_OCTET_STREAM
3333

3434
def __call__(self, data):
35-
# isintance does not work here because a same protobuf message can be imported from a different module.
35+
# isinstance does not work here because a same protobuf message can be imported from a different module.
3636
# for example sagemaker.tensorflow.tensorflow_serving.regression_pb2 and tensorflow_serving.apis.regression_pb2
3737
predict_type = data.__class__.__name__
3838

tests/integ/test_tf.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,13 @@ def test_tf(sagemaker_session):
4949
with timeout_and_delete_endpoint(estimator=estimator, minutes=20):
5050
json_predictor = estimator.deploy(initial_instance_count=1, instance_type='ml.c4.xlarge')
5151

52-
result = json_predictor.predict([6.4, 3.2, 4.5, 1.5])
53-
print('predict result: {}'.format(result))
52+
features = [6.4, 3.2, 4.5, 1.5]
53+
dict_result = json_predictor.predict({'inputs': features})
54+
print('predict result: {}'.format(dict_result))
55+
list_result = json_predictor.predict(features)
56+
print('predict result: {}'.format(list_result))
57+
58+
assert dict_result == list_result
5459

5560

5661
def test_tf_async(sagemaker_session):

tests/unit/test_predictor.py

+14
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,26 @@ def test_json_serializer_python_array():
5151
assert result == '[1, 2, 3]'
5252

5353

54+
def test_json_serializer_python_dictionary():
55+
d = {"gender": "m", "age": 22, "city": "Paris"}
56+
57+
result = json_serializer(d)
58+
59+
assert json.loads(result) == d
60+
61+
5462
def test_json_serializer_python_invalid_empty():
5563
with pytest.raises(ValueError) as error:
5664
json_serializer([])
5765
assert "empty array" in str(error)
5866

5967

68+
def test_json_serializer_python_dictionary_invalid_empty():
69+
with pytest.raises(ValueError) as error:
70+
json_serializer({})
71+
assert "empty dictionary" in str(error)
72+
73+
6074
def test_json_serializer_csv_buffer():
6175
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
6276
with open(csv_file_path) as csv_file:

0 commit comments

Comments
 (0)