diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5c8ff0b4d4..d12c404f1b 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ CHANGELOG ========= +1.7.2.dev +========= +* bug-fix: Prediction output for the TF_JSON_DESERIALIZER + 1.7.1 ===== diff --git a/setup.py b/setup.py index 22e4313143..77f7dc34fb 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def read(fname): # Declare minimal set for installation install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3>=1.2', - 'PyYAML>=3.2'], + 'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5'], extras_require={ 'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist', diff --git a/src/sagemaker/tensorflow/predictor.py b/src/sagemaker/tensorflow/predictor.py index 41a054999e..5cced2b8bb 100644 --- a/src/sagemaker/tensorflow/predictor.py +++ b/src/sagemaker/tensorflow/predictor.py @@ -16,6 +16,7 @@ import google.protobuf.json_format as json_format from google.protobuf.message import DecodeError +from protobuf_to_dict import protobuf_to_dict from tensorflow.core.framework import tensor_pb2 from tensorflow.python.framework import tensor_util @@ -23,6 +24,10 @@ from sagemaker.predictor import json_serializer, csv_serializer from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2 +_POSSIBLE_RESPONSES = [predict_pb2.PredictResponse, classification_pb2.ClassificationResponse, + inference_pb2.MultiInferenceResponse, regression_pb2.RegressionResponse, + tensor_pb2.TensorProto] + REGRESSION_REQUEST = 'RegressionRequest' MULTI_INFERENCE_REQUEST = 'MultiInferenceRequest' CLASSIFICATION_REQUEST = 'ClassificationRequest' @@ -53,17 +58,12 @@ def __init__(self): self.accept = CONTENT_TYPE_OCTET_STREAM def __call__(self, stream, content_type): - possible_responses = [predict_pb2.PredictResponse, - classification_pb2.ClassificationResponse, - inference_pb2.MultiInferenceResponse, - regression_pb2.RegressionResponse] - try: data = stream.read() 
finally: stream.close() - for possible_response in possible_responses: + for possible_response in _POSSIBLE_RESPONSES: try: response = possible_response() response.ParseFromString(data) @@ -101,10 +101,15 @@ def __call__(self, stream, content_type): data = stream.read() finally: stream.close() - try: - return json_format.Parse(data, tensor_pb2.TensorProto()) - except json_format.ParseError: - return json.loads(data.decode()) + + for possible_response in _POSSIBLE_RESPONSES: + try: + return protobuf_to_dict(json_format.Parse(data, possible_response())) + except (UnicodeDecodeError, DecodeError, json_format.ParseError): + # given that the payload does not have the response type, there is no way to infer + # the response without keeping state, so we iterate over all the options. + pass + return json.loads(data.decode()) tf_json_deserializer = _TFJsonDeserializer() diff --git a/tests/unit/test_tf_predictor.py b/tests/unit/test_tf_predictor.py index 4a2c0ac7db..412fc0579a 100644 --- a/tests/unit/test_tf_predictor.py +++ b/tests/unit/test_tf_predictor.py @@ -16,11 +16,13 @@ import json import sys +from google.protobuf import json_format import numpy as np import pytest -from google.protobuf import json_format -import tensorflow as tf from mock import Mock +import tensorflow as tf +import six +from six import BytesIO from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, PREDICT_INPUTS from sagemaker.predictor import RealTimePredictor @@ -139,6 +141,75 @@ def test_classification_request_csv(sagemaker_session): """ +def test_json_deserializer_should_work_with_predict_response(): + data = b"""{ +"outputs": { + "example_strings": { + "dtype": "DT_STRING", + "tensorShape": { + "dim": [ + { + "size": "3" + } + ] + }, + "stringVal": [ + "YXBwbGU=", + "YmFuYW5h", + "b3Jhbmdl" + ] + }, + "ages": { + "dtype": "DT_FLOAT", + "floatVal": [ + 4.954165935516357 + ], + "tensorShape": { + "dim": [ + { + "size": "1" + } + ] + } + } + }, + "modelSpec": 
{ + "version": "1531758457", + "name": "generic_model", + "signatureName": "serving_default" + } +}""" + + stream = BytesIO(data) + + response = tf_json_deserializer(stream, 'application/json') + + if six.PY2: + string_vals = ['apple', 'banana', 'orange'] + else: + string_vals = [b'apple', b'banana', b'orange'] + + assert response == { + 'model_spec': { + 'name': u'generic_model', + 'signature_name': u'serving_default', + 'version': {'value': 1531758457. if six.PY2 else 1531758457} + }, + 'outputs': { + u'ages': { + 'dtype': 1, + 'float_val': [4.954165935516357], + 'tensor_shape': {'dim': [{'size': 1. if six.PY2 else 1}]} + }, + u'example_strings': { + 'dtype': 7, + 'string_val': string_vals, + 'tensor_shape': {'dim': [{'size': 3. if six.PY2 else 3}]} + } + } + } + + def test_classification_request_pb(sagemaker_session): request = classification_pb2.ClassificationRequest() request.model_spec.name = "generic_model"