Skip to content

Bug fix: https://github.com/aws/sagemaker-python-sdk/issues/123 #294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 26, 2018
Merged
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@
CHANGELOG
=========

1.7.2.dev
=========
* bug-fix: Prediction output for the TF_JSON_SERIALIZER

1.7.1
=====

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def read(fname):

# Declare minimal set for installation
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3>=1.2',
'PyYAML>=3.2'],
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5'],

extras_require={
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',
Expand Down
25 changes: 15 additions & 10 deletions src/sagemaker/tensorflow/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@

import google.protobuf.json_format as json_format
from google.protobuf.message import DecodeError
from protobuf_to_dict import protobuf_to_dict
from tensorflow.core.framework import tensor_pb2
from tensorflow.python.framework import tensor_util

from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
from sagemaker.predictor import json_serializer, csv_serializer
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2

# Candidate protobuf message types that a TensorFlow Serving endpoint may return.
# The payload alone does not identify which response type was serialized, so the
# deserializers iterate this list and attempt to parse the payload as each type
# in turn until one succeeds.
_POSSIBLE_RESPONSES = [predict_pb2.PredictResponse, classification_pb2.ClassificationResponse,
                       inference_pb2.MultiInferenceResponse, regression_pb2.RegressionResponse,
                       tensor_pb2.TensorProto]

# Type names of the TF Serving request protos, used to dispatch on request kind.
REGRESSION_REQUEST = 'RegressionRequest'
MULTI_INFERENCE_REQUEST = 'MultiInferenceRequest'
CLASSIFICATION_REQUEST = 'ClassificationRequest'
Expand Down Expand Up @@ -53,17 +58,12 @@ def __init__(self):
self.accept = CONTENT_TYPE_OCTET_STREAM

def __call__(self, stream, content_type):
possible_responses = [predict_pb2.PredictResponse,
classification_pb2.ClassificationResponse,
inference_pb2.MultiInferenceResponse,
regression_pb2.RegressionResponse]

try:
data = stream.read()
finally:
stream.close()

for possible_response in possible_responses:
for possible_response in _POSSIBLE_RESPONSES:
try:
response = possible_response()
response.ParseFromString(data)
Expand Down Expand Up @@ -101,10 +101,15 @@ def __call__(self, stream, content_type):
data = stream.read()
finally:
stream.close()
try:
return json_format.Parse(data, tensor_pb2.TensorProto())
except json_format.ParseError:
return json.loads(data.decode())

for possible_response in _POSSIBLE_RESPONSES:
try:
return protobuf_to_dict(json_format.Parse(data, possible_response()))
Copy link
Contributor

@andremoeller andremoeller Jul 16, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would google.protobuf.json_format.MessageToDict (with json.loads) be suitable here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually tried it and it won't work. google.protobuf.json_format.MessageToDict serializes the values again, which is the bug behavior that we want to avoid.

except (UnicodeDecodeError, DecodeError, json_format.ParseError):
# given that the payload does not have the response type, there no way to infer
# the response without keeping state, so I'm iterating all the options.
pass
return json.loads(data.decode())


tf_json_deserializer = _TFJsonDeserializer()
Expand Down
75 changes: 73 additions & 2 deletions tests/unit/test_tf_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
import json
import sys

from google.protobuf import json_format
import numpy as np
import pytest
from google.protobuf import json_format
import tensorflow as tf
from mock import Mock
import tensorflow as tf
import six
from six import BytesIO
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, PREDICT_INPUTS

from sagemaker.predictor import RealTimePredictor
Expand Down Expand Up @@ -139,6 +141,75 @@ def test_classification_request_csv(sagemaker_session):
"""


def test_json_deserializer_should_work_with_predict_response():
    """tf_json_deserializer should turn a JSON-serialized PredictResponse into a plain dict.

    Regression test for the TF_JSON_SERIALIZER prediction-output bug: tensor
    values must come back as native Python values (ints, floats, byte strings),
    not as re-serialized protobuf representations.
    """
    payload = b"""{
    "outputs": {
        "example_strings": {
            "dtype": "DT_STRING",
            "tensorShape": {
                "dim": [
                    {
                        "size": "3"
                    }
                ]
            },
            "stringVal": [
                "YXBwbGU=",
                "YmFuYW5h",
                "b3Jhbmdl"
            ]
        },
        "ages": {
            "dtype": "DT_FLOAT",
            "floatVal": [
                4.954165935516357
            ],
            "tensorShape": {
                "dim": [
                    {
                        "size": "1"
                    }
                ]
            }
        }
    },
    "modelSpec": {
        "version": "1531758457",
        "name": "generic_model",
        "signatureName": "serving_default"
    }
}"""

    deserialized = tf_json_deserializer(BytesIO(payload), 'application/json')

    # On Python 2, protobuf bytes fields deserialize to str and int64 fields to
    # float-compatible values; on Python 3 they are bytes and int respectively.
    py2 = sys.version_info[0] == 2
    expected_strings = ['apple', 'banana', 'orange'] if py2 else [b'apple', b'banana', b'orange']

    expected = {
        'model_spec': {
            'name': u'generic_model',
            'signature_name': u'serving_default',
            'version': {'value': 1531758457.0 if py2 else 1531758457}
        },
        'outputs': {
            u'ages': {
                'dtype': 1,
                'float_val': [4.954165935516357],
                'tensor_shape': {'dim': [{'size': 1.0 if py2 else 1}]}
            },
            u'example_strings': {
                'dtype': 7,
                'string_val': expected_strings,
                'tensor_shape': {'dim': [{'size': 3.0 if py2 else 3}]}
            }
        }
    }

    assert deserialized == expected


def test_classification_request_pb(sagemaker_session):
request = classification_pb2.ClassificationRequest()
request.model_spec.name = "generic_model"
Expand Down