Skip to content

Commit 6fb3b81

Browse files
haozhx23, Ubuntu, nargokul, and Ubuntu
authored
fix: addWaiterTimeoutHandling (#4951)
* addWaiterTimeoutHandling
* codeStyleUpdate
* updateCodeStyle
* updateCodeStyle
* updateCodeStyle
* updateCodeStyle
* updateCodeStyle
* updateCodeStyle

---------

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Gokul Anantha Narayanan <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
1 parent 3f484d7 commit 6fb3b81

File tree

1 file changed

+16
-13
lines changed

1 file changed

+16
-13
lines changed

src/sagemaker/predictor_async.py

Lines changed: 16 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271271

272272
output_file_found = threading.Event()
273273
failure_file_found = threading.Event()
274+
waiter_error_catched = threading.Event()
274275

275276
def check_output_file():
276277
try:
@@ -282,7 +283,7 @@ def check_output_file():
282283
)
283284
output_file_found.set()
284285
except WaiterError:
285-
pass
286+
waiter_error_catched.set()
286287

287288
def check_failure_file():
288289
try:
@@ -294,33 +295,35 @@ def check_failure_file():
294295
)
295296
failure_file_found.set()
296297
except WaiterError:
297-
pass
298+
waiter_error_catched.set()
298299

299300
output_thread = threading.Thread(target=check_output_file)
300301
failure_thread = threading.Thread(target=check_failure_file)
301302

302303
output_thread.start()
303304
failure_thread.start()
304305

305-
while not output_file_found.is_set() and not failure_file_found.is_set():
306+
while (
307+
not output_file_found.is_set()
308+
and not failure_file_found.is_set()
309+
and not waiter_error_catched.is_set()
310+
):
306311
time.sleep(1)
307312

308313
if output_file_found.is_set():
309314
s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
310315
result = self.predictor._handle_response(response=s3_object)
311316
return result
312317

313-
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
314-
failure_response = self.predictor._handle_response(response=failure_object)
318+
if failure_file_found.is_set():
319+
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
320+
failure_response = self.predictor._handle_response(response=failure_object)
321+
raise AsyncInferenceModelError(message=failure_response)
315322

316-
raise (
317-
AsyncInferenceModelError(message=failure_response)
318-
if failure_file_found.is_set()
319-
else PollingTimeoutError(
320-
message="Inference could still be running",
321-
output_path=output_path,
322-
seconds=waiter_config.delay * waiter_config.max_attempts,
323-
)
323+
raise PollingTimeoutError(
324+
message="Inference could still be running",
325+
output_path=output_path,
326+
seconds=waiter_config.delay * waiter_config.max_attempts,
324327
)
325328

326329
def update_endpoint(

0 commit comments

Comments
 (0)