@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271
271
272
272
output_file_found = threading .Event ()
273
273
failure_file_found = threading .Event ()
274
+ waiter_error_catched = threading .Event ()
274
275
275
276
def check_output_file ():
276
277
try :
@@ -282,7 +283,7 @@ def check_output_file():
282
283
)
283
284
output_file_found .set ()
284
285
except WaiterError :
285
- pass
286
+ waiter_error_catched . set ()
286
287
287
288
def check_failure_file ():
288
289
try :
@@ -294,33 +295,35 @@ def check_failure_file():
294
295
)
295
296
failure_file_found .set ()
296
297
except WaiterError :
297
- pass
298
+ waiter_error_catched . set ()
298
299
299
300
output_thread = threading .Thread (target = check_output_file )
300
301
failure_thread = threading .Thread (target = check_failure_file )
301
302
302
303
output_thread .start ()
303
304
failure_thread .start ()
304
305
305
- while not output_file_found .is_set () and not failure_file_found .is_set ():
306
+ while (
307
+ not output_file_found .is_set ()
308
+ and not failure_file_found .is_set ()
309
+ and not waiter_error_catched .is_set ()
310
+ ):
306
311
time .sleep (1 )
307
312
308
313
if output_file_found .is_set ():
309
314
s3_object = self .s3_client .get_object (Bucket = output_bucket , Key = output_key )
310
315
result = self .predictor ._handle_response (response = s3_object )
311
316
return result
312
317
313
- failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
314
- failure_response = self .predictor ._handle_response (response = failure_object )
318
+ if failure_file_found .is_set ():
319
+ failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
320
+ failure_response = self .predictor ._handle_response (response = failure_object )
321
+ raise AsyncInferenceModelError (message = failure_response )
315
322
316
- raise (
317
- AsyncInferenceModelError (message = failure_response )
318
- if failure_file_found .is_set ()
319
- else PollingTimeoutError (
320
- message = "Inference could still be running" ,
321
- output_path = output_path ,
322
- seconds = waiter_config .delay * waiter_config .max_attempts ,
323
- )
323
+ raise PollingTimeoutError (
324
+ message = "Inference could still be running" ,
325
+ output_path = output_path ,
326
+ seconds = waiter_config .delay * waiter_config .max_attempts ,
324
327
)
325
328
326
329
def update_endpoint (
0 commit comments