Skip to content

Commit 4cf728c

Browse files
author
태영돈
committed
Fix transform function to support proper batch inference
1 parent 6b33274 commit 4cf728c

File tree

2 files changed

+30
-40
lines changed

2 files changed

+30
-40
lines changed

src/sagemaker_inference/transformer.py

Lines changed: 20 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -122,40 +122,33 @@ def transform(self, data, context):
122122
model_dir = properties.get("model_dir")
123123
self.validate_and_initialize(model_dir=model_dir, context=context)
124124

125-
response_list = []
125+
input_data = [req.get("body") for req in data]
126126

127-
for i in range(len(data)):
128-
input_data = data[i].get("body")
127+
request_processor = context.request_processor[0]
129128

130-
request_processor = context.request_processor[0]
129+
request_property = request_processor.get_request_properties()
130+
content_type = utils.retrieve_content_type_header(request_property)
131+
accept = request_property.get("Accept") or request_property.get("accept")
131132

132-
request_property = request_processor.get_request_properties()
133-
content_type = utils.retrieve_content_type_header(request_property)
134-
accept = request_property.get("Accept") or request_property.get("accept")
133+
if not accept or accept == content_types.ANY:
134+
accept = self._environment.default_accept
135135

136-
if not accept or accept == content_types.ANY:
137-
accept = self._environment.default_accept
138-
139-
if content_type in content_types.UTF8_TYPES:
140-
input_data = input_data.decode("utf-8")
141-
142-
result = self._run_handler_function(
143-
self._transform_fn, *(self._model, input_data, content_type, accept)
144-
)
145-
146-
response = result
147-
response_content_type = accept
148-
149-
if isinstance(result, tuple):
150-
# handles tuple for backwards compatibility
151-
response = result[0]
152-
response_content_type = result[1]
136+
if content_type in content_types.UTF8_TYPES:
137+
input_data = [item.decode("utf-8") for item in input_data]
138+
result = self._run_handler_function(
139+
self._transform_fn, *(self._model, input_data, content_type, accept)
140+
)
153141

154-
context.set_response_content_type(0, response_content_type)
142+
response = result
143+
response_content_type = accept
155144

156-
response_list.append(response)
145+
if isinstance(result, tuple):
146+
# handles tuple for backwards compatibility
147+
response = result[0]
148+
response_content_type = result[1]
157149

158-
return response_list
150+
context.set_response_content_type(0, response_content_type)
151+
return [response]
159152
except Exception as e: # pylint: disable=broad-except
160153
trace = traceback.format_exc()
161154
if isinstance(e, BaseInferenceToolkitError):

test/unit/test_transfomer.py

Lines changed: 10 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -96,7 +96,7 @@ def test_transform(validate, retrieve_content_type_header, run_handler, accept_k
9696
validate.assert_called_once()
9797
retrieve_content_type_header.assert_called_once_with(request_property)
9898
run_handler.assert_called_once_with(
99-
transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT
99+
transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, ACCEPT
100100
)
101101
context.set_response_content_type.assert_called_once_with(0, ACCEPT)
102102
assert isinstance(result, list)
@@ -125,16 +125,13 @@ def test_batch_transform(validate, retrieve_content_type_header, run_handler, ac
125125
result = transformer.transform(data, context)
126126

127127
validate.assert_called_once()
128-
retrieve_content_type_header.assert_called_with(request_property)
129-
assert retrieve_content_type_header.call_count == 2
130-
run_handler.assert_called_with(
131-
transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT
128+
retrieve_content_type_header.assert_called_once_with(request_property)
129+
run_handler.assert_called_once_with(
130+
transformer._transform_fn, MODEL, [INPUT_DATA, INPUT_DATA], CONTENT_TYPE, ACCEPT
132131
)
133-
assert run_handler.call_count == 2
134-
context.set_response_content_type.assert_called_with(0, ACCEPT)
135-
assert context.set_response_content_type.call_count == 2
132+
context.set_response_content_type.assert_called_once_with(0, ACCEPT)
136133
assert isinstance(result, list)
137-
assert result == [RESULT, RESULT]
134+
assert result[0] == RESULT
138135

139136

140137
@patch("sagemaker_inference.transformer.Transformer._run_handler_function")
@@ -161,7 +158,7 @@ def test_transform_no_accept(validate, retrieve_content_type_header, run_handler
161158

162159
validate.assert_called_once()
163160
run_handler.assert_called_once_with(
164-
transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT
161+
transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, DEFAULT_ACCEPT
165162
)
166163

167164

@@ -189,7 +186,7 @@ def test_transform_any_accept(validate, retrieve_content_type_header, run_handle
189186

190187
validate.assert_called_once()
191188
run_handler.assert_called_once_with(
192-
transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, DEFAULT_ACCEPT
189+
transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, DEFAULT_ACCEPT
193190
)
194191

195192

@@ -218,7 +215,7 @@ def test_transform_decode(validate, retrieve_content_type_header, run_handler, c
218215

219216
input_data.decode.assert_called_once_with("utf-8")
220217
run_handler.assert_called_once_with(
221-
transformer._transform_fn, MODEL, INPUT_DATA, content_type, ACCEPT
218+
transformer._transform_fn, MODEL, [INPUT_DATA], content_type, ACCEPT
222219
)
223220

224221

@@ -245,7 +242,7 @@ def test_transform_tuple(validate, retrieve_content_type_header, run_handler):
245242
result = transformer.transform(data, context)
246243

247244
run_handler.assert_called_once_with(
248-
transformer._transform_fn, MODEL, INPUT_DATA, CONTENT_TYPE, ACCEPT
245+
transformer._transform_fn, MODEL, [INPUT_DATA], CONTENT_TYPE, ACCEPT
249246
)
250247
context.set_response_content_type.assert_called_once_with(0, run_handler()[1])
251248
assert isinstance(result, list)

0 commit comments

Comments
 (0)