@@ -96,7 +96,7 @@ def test_transform(validate, retrieve_content_type_header, run_handler, accept_k
96
96
validate .assert_called_once ()
97
97
retrieve_content_type_header .assert_called_once_with (request_property )
98
98
run_handler .assert_called_once_with (
99
- transformer ._transform_fn , MODEL , INPUT_DATA , CONTENT_TYPE , ACCEPT
99
+ transformer ._transform_fn , MODEL , [ INPUT_DATA ] , CONTENT_TYPE , ACCEPT
100
100
)
101
101
context .set_response_content_type .assert_called_once_with (0 , ACCEPT )
102
102
assert isinstance (result , list )
@@ -125,16 +125,13 @@ def test_batch_transform(validate, retrieve_content_type_header, run_handler, ac
125
125
result = transformer .transform (data , context )
126
126
127
127
validate .assert_called_once ()
128
- retrieve_content_type_header .assert_called_with (request_property )
129
- assert retrieve_content_type_header .call_count == 2
130
- run_handler .assert_called_with (
131
- transformer ._transform_fn , MODEL , INPUT_DATA , CONTENT_TYPE , ACCEPT
128
+ retrieve_content_type_header .assert_called_once_with (request_property )
129
+ run_handler .assert_called_once_with (
130
+ transformer ._transform_fn , MODEL , [INPUT_DATA , INPUT_DATA ], CONTENT_TYPE , ACCEPT
132
131
)
133
- assert run_handler .call_count == 2
134
- context .set_response_content_type .assert_called_with (0 , ACCEPT )
135
- assert context .set_response_content_type .call_count == 2
132
+ context .set_response_content_type .assert_called_once_with (0 , ACCEPT )
136
133
assert isinstance (result , list )
137
- assert result == [ RESULT , RESULT ]
134
+ assert result [ 0 ] == RESULT
138
135
139
136
140
137
@patch ("sagemaker_inference.transformer.Transformer._run_handler_function" )
@@ -161,7 +158,7 @@ def test_transform_no_accept(validate, retrieve_content_type_header, run_handler
161
158
162
159
validate .assert_called_once ()
163
160
run_handler .assert_called_once_with (
164
- transformer ._transform_fn , MODEL , INPUT_DATA , CONTENT_TYPE , DEFAULT_ACCEPT
161
+ transformer ._transform_fn , MODEL , [ INPUT_DATA ] , CONTENT_TYPE , DEFAULT_ACCEPT
165
162
)
166
163
167
164
@@ -189,7 +186,7 @@ def test_transform_any_accept(validate, retrieve_content_type_header, run_handle
189
186
190
187
validate .assert_called_once ()
191
188
run_handler .assert_called_once_with (
192
- transformer ._transform_fn , MODEL , INPUT_DATA , CONTENT_TYPE , DEFAULT_ACCEPT
189
+ transformer ._transform_fn , MODEL , [ INPUT_DATA ] , CONTENT_TYPE , DEFAULT_ACCEPT
193
190
)
194
191
195
192
@@ -218,7 +215,7 @@ def test_transform_decode(validate, retrieve_content_type_header, run_handler, c
218
215
219
216
input_data .decode .assert_called_once_with ("utf-8" )
220
217
run_handler .assert_called_once_with (
221
- transformer ._transform_fn , MODEL , INPUT_DATA , content_type , ACCEPT
218
+ transformer ._transform_fn , MODEL , [ INPUT_DATA ] , content_type , ACCEPT
222
219
)
223
220
224
221
@@ -245,7 +242,7 @@ def test_transform_tuple(validate, retrieve_content_type_header, run_handler):
245
242
result = transformer .transform (data , context )
246
243
247
244
run_handler .assert_called_once_with (
248
- transformer ._transform_fn , MODEL , INPUT_DATA , CONTENT_TYPE , ACCEPT
245
+ transformer ._transform_fn , MODEL , [ INPUT_DATA ] , CONTENT_TYPE , ACCEPT
249
246
)
250
247
context .set_response_content_type .assert_called_once_with (0 , run_handler ()[1 ])
251
248
assert isinstance (result , list )
0 commit comments