@@ -76,6 +76,37 @@ def empty_sagemaker_session():
76
76
return ims
77
77
78
78
79
+ def empty_sagemaker_session_with_null_failure_path ():
80
+ ims = Mock (name = "sagemaker_session" )
81
+ ims .default_bucket = Mock (name = "default_bucket" , return_value = BUCKET_NAME )
82
+ ims .sagemaker_runtime_client = Mock (name = "sagemaker_runtime" )
83
+ ims .sagemaker_client .describe_endpoint = Mock (return_value = ENDPOINT_DESC )
84
+ ims .sagemaker_client .describe_endpoint_config = Mock (return_value = ENDPOINT_CONFIG_DESC )
85
+
86
+ ims .sagemaker_runtime_client .invoke_endpoint_async = Mock (
87
+ name = "invoke_endpoint_async" ,
88
+ return_value = {
89
+ "OutputLocation" : ASYNC_OUTPUT_LOCATION ,
90
+ },
91
+ )
92
+
93
+ polling_timeout_error = PollingTimeoutError (
94
+ message = "Inference could still be running" ,
95
+ output_path = ASYNC_OUTPUT_LOCATION ,
96
+ seconds = DEFAULT_WAITER_CONFIG .delay * DEFAULT_WAITER_CONFIG .max_attempts ,
97
+ )
98
+
99
+ ims .s3_client = Mock (name = "s3_client" )
100
+ ims .s3_client .get_object = Mock (
101
+ name = "get_object" ,
102
+ side_effect = [polling_timeout_error ],
103
+ )
104
+
105
+ ims .s3_client .put_object = Mock (name = "put_object" )
106
+
107
+ return ims
108
+
109
+
79
110
def empty_predictor ():
80
111
predictor = Mock (name = "predictor" )
81
112
predictor .update_endpoint = Mock (name = "update_endpoint" )
@@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path():
161
192
assert result .failure_path == ASYNC_FAILURE_LOCATION
162
193
163
194
195
+ def test_async_predict_call_with_data_and_input_and_null_failure_path ():
196
+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
197
+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
198
+ predictor_async .name = ASYNC_PREDICTOR
199
+ data = DUMMY_DATA
200
+
201
+ result = predictor_async .predict_async (data = data , input_path = ASYNC_INPUT_LOCATION )
202
+ assert sagemaker_session .s3_client .put_object .called
203
+
204
+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
205
+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
206
+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
207
+
208
+ expected_request_args = {
209
+ "Accept" : DEFAULT_ACCEPT ,
210
+ "InputLocation" : ASYNC_INPUT_LOCATION ,
211
+ "EndpointName" : ENDPOINT ,
212
+ }
213
+
214
+ call_args , kwargs = sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .call_args
215
+ assert kwargs == expected_request_args
216
+ assert result .output_path == ASYNC_OUTPUT_LOCATION
217
+ assert result .failure_path is None
218
+
219
+
164
220
def test_async_predict_call_verify_exceptions ():
165
221
sagemaker_session = empty_sagemaker_session ()
166
222
predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
@@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions():
185
241
assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
186
242
187
243
188
- def test_async_predict_call_pass_through_success ():
244
+ def test_async_predict_call_verify_exceptions_with_null_failure_path ():
245
+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
246
+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
247
+
248
+ input_location = "s3://some-input-path"
249
+
250
+ with pytest .raises (
251
+ PollingTimeoutError ,
252
+ match = f"No result at { ASYNC_OUTPUT_LOCATION } after polling for "
253
+ f"{ DEFAULT_WAITER_CONFIG .delay * DEFAULT_WAITER_CONFIG .max_attempts } "
254
+ f" seconds. Inference could still be running" ,
255
+ ):
256
+ predictor_async .predict (input_path = input_location , waiter_config = DEFAULT_WAITER_CONFIG )
257
+
258
+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
259
+ assert sagemaker_session .s3_client .get_object .called
260
+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
261
+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
262
+
263
+
264
+ def test_async_predict_call_pass_through_output_failure_paths ():
189
265
sagemaker_session = empty_sagemaker_session ()
190
266
191
267
response_body = Mock ("body" )
@@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success():
222
298
assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
223
299
224
300
301
+ def test_async_predict_call_pass_through_with_null_failure_path ():
302
+ sagemaker_session = empty_sagemaker_session_with_null_failure_path ()
303
+
304
+ response_body = Mock ("body" )
305
+ response_body .read = Mock ("read" , return_value = RETURN_VALUE )
306
+ response_body .close = Mock ("close" , return_value = None )
307
+
308
+ sagemaker_session .s3_client = Mock (name = "s3_client" )
309
+ sagemaker_session .s3_client .get_object = Mock (
310
+ name = "get_object" ,
311
+ return_value = {"Body" : response_body },
312
+ )
313
+ sagemaker_session .s3_client .put_object = Mock (name = "put_object" )
314
+
315
+ predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
316
+
317
+ sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async = Mock (
318
+ name = "invoke_endpoint_async" ,
319
+ return_value = {
320
+ "OutputLocation" : ASYNC_OUTPUT_LOCATION ,
321
+ },
322
+ )
323
+
324
+ input_location = "s3://some-input-path"
325
+
326
+ result = predictor_async .predict (
327
+ input_path = input_location ,
328
+ )
329
+
330
+ assert result == RETURN_VALUE
331
+ assert sagemaker_session .sagemaker_runtime_client .invoke_endpoint_async .called
332
+ assert sagemaker_session .s3_client .get_waiter .called_with ("object_exists" )
333
+ assert sagemaker_session .sagemaker_client .describe_endpoint .not_called
334
+ assert sagemaker_session .sagemaker_client .describe_endpoint_config .not_called
335
+
336
+
225
337
def test_predict_async_call_invalid_input ():
226
338
sagemaker_session = empty_sagemaker_session ()
227
339
predictor_async = AsyncPredictor (Predictor (ENDPOINT , sagemaker_session ))
0 commit comments