Skip to content

Commit db0e09c

Browse files
committed
add support for serializing python dictionaries to json
1 parent ea0c5f9 commit db0e09c

File tree

5 files changed

+28
-6
lines changed

5 files changed

+28
-6
lines changed

src/sagemaker/mxnet/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
class MXNetPredictor(RealTimePredictor):
2121
"""A RealTimePredictor for inference against MXNet Endpoints.
2222
23-
This is able to serialize Python lists and numpy arrays to multidimensional tensors for MXNet inference."""
23+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
24+
inference."""
2425

2526
def __init__(self, endpoint_name, sagemaker_session=None):
2627
"""Initialize an ``MXNetPredictor``.

src/sagemaker/predictor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,12 @@ def __call__(self, data):
240240
if isinstance(data, list):
241241
if not len(data) > 0:
242242
raise ValueError("empty array can't be serialized")
243-
return _json_serialize_python_array(data)
243+
return _json_serialize_python_object(data)
244+
245+
if isinstance(data, dict):
246+
if not len(data.keys()) > 0:
247+
raise ValueError("empty dictionary can't be serialized")
248+
return _json_serialize_python_object(data)
244249

245250
# files and buffers
246251
if hasattr(data, 'read'):
@@ -254,10 +259,10 @@ def __call__(self, data):
254259

255260
def _json_serialize_numpy_array(data):
256261
# numpy arrays can't be serialized but we know they have uniform type
257-
return _json_serialize_python_array(data.tolist())
262+
return _json_serialize_python_object(data.tolist())
258263

259264

260-
def _json_serialize_python_array(data):
265+
def _json_serialize_python_object(data):
261266
return _json_serialize_object(data)
262267

263268

src/sagemaker/tensorflow/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020

2121
class TensorFlowPredictor(RealTimePredictor):
22-
"""A ``RealTimePredictor`` for inference against MXNet ``Endpoint``s."""
22+
"""A ``RealTimePredictor`` for inference against TensorFlow ``Endpoint``s.
2323
24+
This is able to serialize Python lists, dictionaries, and numpy arrays to multidimensional tensors for MXNet
25+
inference"""
2426
def __init__(self, endpoint_name, sagemaker_session=None):
2527
"""Initialize an ``TensorFlowPredictor``.
2628

src/sagemaker/tensorflow/predictor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def __init__(self):
3232
self.content_type = CONTENT_TYPE_OCTET_STREAM
3333

3434
def __call__(self, data):
35-
# isintance does not work here because a same protobuf message can be imported from a different module.
35+
# isinstance does not work here because a same protobuf message can be imported from a different module.
3636
# for example sagemaker.tensorflow.tensorflow_serving.regression_pb2 and tensorflow_serving.apis.regression_pb2
3737
predict_type = data.__class__.__name__
3838

tests/unit/test_predictor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,26 @@ def test_json_serializer_python_array():
5151
assert result == '[1, 2, 3]'
5252

5353

54+
def test_json_serializer_python_dictionary():
55+
d = {"gender": "m", "age": 22, "city": "Paris"}
56+
57+
result = json_serializer(d)
58+
59+
assert json.loads(result) == d
60+
61+
5462
def test_json_serializer_python_invalid_empty():
5563
with pytest.raises(ValueError) as error:
5664
json_serializer([])
5765
assert "empty array" in str(error)
5866

5967

68+
def test_json_serializer_python_dictionary_invalid_empty():
69+
with pytest.raises(ValueError) as error:
70+
json_serializer({})
71+
assert "empty dictionary" in str(error)
72+
73+
6074
def test_json_serializer_csv_buffer():
6175
csv_file_path = os.path.join(DATA_DIR, "with_integers.csv")
6276
with open(csv_file_path) as csv_file:

0 commit comments

Comments
 (0)