Skip to content

Commit 235decc

Browse files
committed
Updates to async predictor
1 parent 41d3de8 commit 235decc

File tree

2 files changed

+33
-4
lines changed

2 files changed

+33
-4
lines changed

src/sagemaker/async_inference/waiter_config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ class WaiterConfig(object):
2323

2424
def __init__(
2525
self,
26-
max_attempts=4,
27-
delay=5,
26+
max_attempts=60,
27+
delay=15,
2828
):
2929
"""Initialize a WaiterConfig object that provides parameters to control waiting behavior.
3030

src/sagemaker/predictor_async.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def predict(
9999
self._input_path = input_path
100100
response = self._submit_async_request(input_path, initial_args, inference_id)
101101
output_location = response["OutputLocation"]
102-
failure_location = response["FailureLocation"]
102+
failure_location = response.get("FailureLocation")
103103
result = self._wait_for_output(
104104
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
105105
)
@@ -145,7 +145,7 @@ def predict_async(
145145
self._input_path = input_path
146146
response = self._submit_async_request(input_path, initial_args, inference_id)
147147
output_location = response["OutputLocation"]
148-
failure_location = response["FailureLocation"]
148+
failure_location = response.get("FailureLocation")
149149
response_async = AsyncInferenceResponse(
150150
predictor_async=self,
151151
output_path=output_location,
@@ -216,6 +216,35 @@ def _submit_async_request(
216216
return response
217217

218218
def _wait_for_output(self, output_path, failure_path, waiter_config):
    """Retrieve the output, choosing the strategy by the presence of ``failure_path``.

    Args:
        output_path: Location where the inference output is expected.
        failure_path: Location where a failure record is expected, or ``None``
            when no failure location is available.
        waiter_config: Waiter configuration, forwarded unchanged to the helper.

    Returns:
        Whatever the selected helper returns.
    """
    # Without a failure location only the output location can be checked.
    if failure_path is None:
        return self._check_output_location(output_path, waiter_config)

    return self._check_output_and_failure_locations(output_path, failure_path, waiter_config)
226+
227+
def _check_output_location(self, output_path, waiter_config):
    """Check the Amazon S3 output path for the output.

    Periodically check the Amazon S3 output path for the async inference result,
    using the S3 ``object_exists`` waiter configured from ``waiter_config``.

    Args:
        output_path: S3 URL of the expected output object.
        waiter_config: Object providing ``delay``, ``max_attempts`` and
            ``_to_request_dict()`` for the S3 waiter.

    Returns:
        The result of passing the fetched S3 object to the predictor's
        ``_handle_response``.

    Raises:
        PollingTimeoutError: If the waiter raises ``WaiterError``. The waiter
            error is attached as the cause.
    """
    bucket, key = parse_s3_url(output_path)
    s3_waiter = self.s3_client.get_waiter("object_exists")
    try:
        s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict())
    except WaiterError as err:
        # Chain the waiter error so the underlying reason stays visible in the traceback.
        raise PollingTimeoutError(
            message="Inference could still be running",
            output_path=output_path,
            seconds=waiter_config.delay * waiter_config.max_attempts,
        ) from err
    s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
    return self.predictor._handle_response(response=s3_object)
246+
247+
def _check_output_and_failure_locations(self, output_path, failure_path, waiter_config):
219248
"""Check the Amazon S3 output path for the output.
220249
221250
This method waits for either the output file or the failure file to be found on the

0 commit comments

Comments
 (0)