Skip to content

Commit 6638474

Browse files
authored
Bug fix: #123 (#294)
* Bug fix: #123
1 parent eb5099c commit 6638474

File tree

4 files changed

+93
-13
lines changed

4 files changed

+93
-13
lines changed

CHANGELOG.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
CHANGELOG
33
=========
44

5+
1.7.2.dev
6+
=========
7+
* bug-fix: Prediction output for the TF_JSON_SERIALIZER
8+
59
1.7.1
610
=====
711

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def read(fname):
4545

4646
# Declare minimal set for installation
4747
install_requires=['boto3>=1.4.8', 'numpy>=1.9.0', 'protobuf>=3.1', 'scipy>=0.19.0', 'urllib3>=1.2',
48-
'PyYAML>=3.2'],
48+
'PyYAML>=3.2', 'protobuf3-to-dict>=0.1.5'],
4949

5050
extras_require={
5151
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',

src/sagemaker/tensorflow/predictor.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616

1717
import google.protobuf.json_format as json_format
1818
from google.protobuf.message import DecodeError
19+
from protobuf_to_dict import protobuf_to_dict
1920
from tensorflow.core.framework import tensor_pb2
2021
from tensorflow.python.framework import tensor_util
2122

2223
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
2324
from sagemaker.predictor import json_serializer, csv_serializer
2425
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
2526

27+
_POSSIBLE_RESPONSES = [predict_pb2.PredictResponse, classification_pb2.ClassificationResponse,
28+
inference_pb2.MultiInferenceResponse, regression_pb2.RegressionResponse,
29+
tensor_pb2.TensorProto]
30+
2631
REGRESSION_REQUEST = 'RegressionRequest'
2732
MULTI_INFERENCE_REQUEST = 'MultiInferenceRequest'
2833
CLASSIFICATION_REQUEST = 'ClassificationRequest'
@@ -53,17 +58,12 @@ def __init__(self):
5358
self.accept = CONTENT_TYPE_OCTET_STREAM
5459

5560
def __call__(self, stream, content_type):
56-
possible_responses = [predict_pb2.PredictResponse,
57-
classification_pb2.ClassificationResponse,
58-
inference_pb2.MultiInferenceResponse,
59-
regression_pb2.RegressionResponse]
60-
6161
try:
6262
data = stream.read()
6363
finally:
6464
stream.close()
6565

66-
for possible_response in possible_responses:
66+
for possible_response in _POSSIBLE_RESPONSES:
6767
try:
6868
response = possible_response()
6969
response.ParseFromString(data)
@@ -101,10 +101,15 @@ def __call__(self, stream, content_type):
101101
data = stream.read()
102102
finally:
103103
stream.close()
104-
try:
105-
return json_format.Parse(data, tensor_pb2.TensorProto())
106-
except json_format.ParseError:
107-
return json.loads(data.decode())
104+
105+
for possible_response in _POSSIBLE_RESPONSES:
106+
try:
107+
return protobuf_to_dict(json_format.Parse(data, possible_response()))
108+
except (UnicodeDecodeError, DecodeError, json_format.ParseError):
109+
# given that the payload does not have the response type, there no way to infer
110+
# the response without keeping state, so I'm iterating all the options.
111+
pass
112+
return json.loads(data.decode())
108113

109114

110115
tf_json_deserializer = _TFJsonDeserializer()

tests/unit/test_tf_predictor.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
import json
1717
import sys
1818

19+
from google.protobuf import json_format
1920
import numpy as np
2021
import pytest
21-
from google.protobuf import json_format
22-
import tensorflow as tf
2322
from mock import Mock
23+
import tensorflow as tf
24+
import six
25+
from six import BytesIO
2426
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, PREDICT_INPUTS
2527

2628
from sagemaker.predictor import RealTimePredictor
@@ -139,6 +141,75 @@ def test_classification_request_csv(sagemaker_session):
139141
"""
140142

141143

144+
def test_json_deserializer_should_work_with_predict_response():
145+
data = b"""{
146+
"outputs": {
147+
"example_strings": {
148+
"dtype": "DT_STRING",
149+
"tensorShape": {
150+
"dim": [
151+
{
152+
"size": "3"
153+
}
154+
]
155+
},
156+
"stringVal": [
157+
"YXBwbGU=",
158+
"YmFuYW5h",
159+
"b3Jhbmdl"
160+
]
161+
},
162+
"ages": {
163+
"dtype": "DT_FLOAT",
164+
"floatVal": [
165+
4.954165935516357
166+
],
167+
"tensorShape": {
168+
"dim": [
169+
{
170+
"size": "1"
171+
}
172+
]
173+
}
174+
}
175+
},
176+
"modelSpec": {
177+
"version": "1531758457",
178+
"name": "generic_model",
179+
"signatureName": "serving_default"
180+
}
181+
}"""
182+
183+
stream = BytesIO(data)
184+
185+
response = tf_json_deserializer(stream, 'application/json')
186+
187+
if six.PY2:
188+
string_vals = ['apple', 'banana', 'orange']
189+
else:
190+
string_vals = [b'apple', b'banana', b'orange']
191+
192+
assert response == {
193+
'model_spec': {
194+
'name': u'generic_model',
195+
'signature_name': u'serving_default',
196+
'version': {'value': 1531758457. if six.PY2 else 1531758457}
197+
},
198+
'outputs': {
199+
u'ages': {
200+
'dtype': 1,
201+
'float_val': [4.954165935516357],
202+
'tensor_shape': {'dim': [{'size': 1. if six.PY2 else 1}]}
203+
},
204+
u'example_strings': {
205+
'dtype': 7,
206+
'string_val': string_vals,
207+
'tensor_shape': {'dim': [{'size': 3. if six.PY2 else 3}]}
208+
}
209+
}
210+
}
211+
212+
142213
def test_classification_request_pb(sagemaker_session):
143214
request = classification_pb2.ClassificationRequest()
144215
request.model_spec.name = "generic_model"

0 commit comments

Comments
 (0)