Skip to content

Commit e7a60f6

Browse files
authored
Change error classes used by _default_input_fn() and _default_output_fn() (#65)
This will make the containers return more appropriate HTTP status codes when _default_input_fn() receives an unsupported content type or when _default_output_fn() receives an unsupported accpet type.
1 parent e342aa6 commit e7a60f6

File tree

2 files changed

+32
-16
lines changed

2 files changed

+32
-16
lines changed

src/tf_container/serve.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
from tf_container import proxy_client
2626
from six import StringIO
2727
import csv
28-
from container_support.serving import JSON_CONTENT_TYPE, CSV_CONTENT_TYPE, OCTET_STREAM_CONTENT_TYPE, ANY_CONTENT_TYPE
28+
from container_support.serving import UnsupportedContentTypeError, UnsupportedAcceptTypeError, \
29+
JSON_CONTENT_TYPE, CSV_CONTENT_TYPE, \
30+
OCTET_STREAM_CONTENT_TYPE, ANY_CONTENT_TYPE
2931
from tf_container.run import logger
3032
import time
3133

@@ -227,19 +229,17 @@ def _default_output_fn(data, accepts):
227229
if accepts == OCTET_STREAM_CONTENT_TYPE:
228230
return data.SerializeToString()
229231

230-
raise ValueError('invalid accept type {}'.format(accepts))
232+
raise UnsupportedAcceptTypeError('invalid accept type {}'.format(accepts))
231233

232234
def _default_input_fn(self, serialized_data, content_type):
233235
if content_type == JSON_CONTENT_TYPE:
234-
data = self._parse_json_request(serialized_data)
235-
elif content_type == CSV_CONTENT_TYPE:
236-
data = self._parse_csv_request(serialized_data)
237-
elif content_type == OCTET_STREAM_CONTENT_TYPE:
238-
data = self.proxy_client.parse_request(serialized_data)
239-
else:
240-
raise ValueError("Unsupported content-type {}".format(content_type))
236+
return self._parse_json_request(serialized_data)
237+
if content_type == CSV_CONTENT_TYPE:
238+
return self._parse_csv_request(serialized_data)
239+
if content_type == OCTET_STREAM_CONTENT_TYPE:
240+
return self.proxy_client.parse_request(serialized_data)
241241

242-
return data
242+
raise UnsupportedContentTypeError('Unsupported content-type {}'.format(content_type))
243243

244244
@classmethod
245245
def from_module(cls, m, grpc_proxy_client):

test/unit/test_serve.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2-
#
2+
#
33
# Licensed under the Apache License, Version 2.0 (the "License").
44
# You may not use this file except in compliance with the License.
55
# A copy of the License is located at
6-
#
6+
#
77
# http://www.apache.org/licenses/LICENSE-2.0
8-
#
9-
# or in the "license" file accompanying this file. This file is distributed
10-
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11-
# express or implied. See the License for the specific language governing
8+
#
9+
# or in the "license" file accompanying this file. This file is distributed
10+
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
11+
# express or implied. See the License for the specific language governing
1212
# permissions and limitations under the License.
1313

1414
import json
@@ -18,6 +18,8 @@
1818
from test.unit.utils import mock_import_modules
1919
from types import ModuleType
2020

21+
from container_support.serving import UnsupportedAcceptTypeError, UnsupportedContentTypeError
22+
2123
JSON_CONTENT_TYPE = "application/json"
2224

2325

@@ -279,6 +281,20 @@ def test_wait_model_to_load(proxy_client, serve):
279281
client.cache_prediction_metadata.assert_called_once_with()
280282

281283

284+
def test_transformer_default_output_fn_unsupported_type(serve):
285+
accept_type = 'fake/accept-type'
286+
287+
with pytest.raises(UnsupportedAcceptTypeError):
288+
serve.Transformer._default_output_fn(None, accept_type)
289+
290+
291+
def test_transformer_default_input_fn_unsupported_type(serve):
292+
content_type = 'fake/content-type'
293+
294+
with pytest.raises(UnsupportedContentTypeError):
295+
serve.Transformer(None)._default_input_fn(None, content_type)
296+
297+
282298
class DummyTransformer(object):
283299
def transform(self, content, mimetype):
284300
if content.startswith('500'):

0 commit comments

Comments
 (0)