Skip to content

Commit 8006a1d

Browse files
committed
Bug fix: #123
1 parent 5ea3fd0 commit 8006a1d

File tree

3 files changed

+83
-11
lines changed

3 files changed

+83
-11
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def read(fname):
4949

5050
extras_require={
5151
'test': ['tox', 'flake8', 'pytest', 'pytest-cov', 'pytest-xdist',
52-
'mock', 'tensorflow>=1.3.0', 'contextlib2', 'awslogs', 'pandas']},
52+
'mock', 'tensorflow>=1.3.0', 'protobuf3-to-dict', 'contextlib2', 'awslogs', 'pandas']},
5353

5454
entry_points={
5555
'console_scripts': ['sagemaker=sagemaker.cli.main:main'],

src/sagemaker/tensorflow/predictor.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,18 @@
1616

1717
import google.protobuf.json_format as json_format
1818
from google.protobuf.message import DecodeError
19+
from protobuf_to_dict import protobuf_to_dict
1920
from tensorflow.core.framework import tensor_pb2
2021
from tensorflow.python.framework import tensor_util
2122

2223
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_OCTET_STREAM, CONTENT_TYPE_CSV
2324
from sagemaker.predictor import json_serializer, csv_serializer
2425
from tensorflow_serving.apis import predict_pb2, classification_pb2, inference_pb2, regression_pb2
2526

27+
_POSSIBLE_RESPONSES = [predict_pb2.PredictResponse, classification_pb2.ClassificationResponse,
28+
inference_pb2.MultiInferenceResponse, regression_pb2.RegressionResponse,
29+
tensor_pb2.TensorProto]
30+
2631
REGRESSION_REQUEST = 'RegressionRequest'
2732
MULTI_INFERENCE_REQUEST = 'MultiInferenceRequest'
2833
CLASSIFICATION_REQUEST = 'ClassificationRequest'
@@ -53,17 +58,12 @@ def __init__(self):
5358
self.accept = CONTENT_TYPE_OCTET_STREAM
5459

5560
def __call__(self, stream, content_type):
56-
possible_responses = [predict_pb2.PredictResponse,
57-
classification_pb2.ClassificationResponse,
58-
inference_pb2.MultiInferenceResponse,
59-
regression_pb2.RegressionResponse]
60-
6161
try:
6262
data = stream.read()
6363
finally:
6464
stream.close()
6565

66-
for possible_response in possible_responses:
66+
for possible_response in _POSSIBLE_RESPONSES:
6767
try:
6868
response = possible_response()
6969
response.ParseFromString(data)
@@ -101,10 +101,15 @@ def __call__(self, stream, content_type):
101101
data = stream.read()
102102
finally:
103103
stream.close()
104-
try:
105-
return json_format.Parse(data, tensor_pb2.TensorProto())
106-
except json_format.ParseError:
107-
return json.loads(data.decode())
104+
105+
for possible_response in _POSSIBLE_RESPONSES:
106+
try:
107+
return protobuf_to_dict(json_format.Parse(data, possible_response()))
108+
except (UnicodeDecodeError, DecodeError, json_format.ParseError):
109+
# given that the payload does not have the response type, there no way to infer
110+
# the response without keeping state, so I'm iterating all the options.
111+
pass
112+
return json.loads(data.decode())
108113

109114

110115
tf_json_deserializer = _TFJsonDeserializer()

tests/unit/test_tf_predictor.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from google.protobuf import json_format
2222
import tensorflow as tf
2323
from mock import Mock
24+
from six import StringIO
2425
from tensorflow.python.saved_model.signature_constants import DEFAULT_SERVING_SIGNATURE_DEF_KEY, PREDICT_INPUTS
2526

2627
from sagemaker.predictor import RealTimePredictor
@@ -139,6 +140,72 @@ def test_classification_request_csv(sagemaker_session):
139140
"""
140141

141142

143+
def test_json_deserializer_should_work_with_predict_response():
144+
data = b"""{
145+
"outputs": {
146+
"example_strings": {
147+
"dtype": "DT_STRING",
148+
"tensorShape": {
149+
"dim": [
150+
{
151+
"size": "3"
152+
}
153+
]
154+
},
155+
"stringVal": [
156+
"YXBwbGU=",
157+
"YmFuYW5h",
158+
"b3Jhbmdl"
159+
]
160+
},
161+
"ages": {
162+
"dtype": "DT_FLOAT",
163+
"floatVal": [
164+
4.954165935516357
165+
],
166+
"tensorShape": {
167+
"dim": [
168+
{
169+
"size": "1"
170+
}
171+
]
172+
}
173+
}
174+
},
175+
"modelSpec": {
176+
"version": "1531758457",
177+
"name": "generic_model",
178+
"signatureName": "serving_default"
179+
}
180+
}"""
181+
182+
stream = StringIO(data)
183+
184+
response = tf_json_deserializer(stream, 'application/json')
185+
186+
assert response == {
187+
'model_spec': {
188+
'name': u'generic_model',
189+
'signature_name': u'serving_default',
190+
'version': {'value': 1531758457L}
191+
},
192+
'outputs': {
193+
u'ages': {
194+
'dtype': 1,
195+
'float_val': [4.954165935516357],
196+
'tensor_shape': {'dim': [{'size': 1L}]}
197+
},
198+
u'example_strings': {
199+
'dtype': 7,
200+
'string_val': ['apple',
201+
'banana',
202+
'orange'],
203+
'tensor_shape': {'dim': [{'size': 3L}]}
204+
}
205+
}
206+
}
207+
208+
142209
def test_classification_request_pb(sagemaker_session):
143210
request = classification_pb2.ClassificationRequest()
144211
request.model_spec.name = "generic_model"

0 commit comments

Comments
 (0)