Skip to content

Commit 0f37695

Browse files
feature: Handle use case where endpoint is created outside of python … (#3867)
1 parent 4430c8c commit 0f37695

File tree

4 files changed

+235
-3
lines changed

4 files changed

+235
-3
lines changed

src/sagemaker/async_inference/async_inference_response.py

+23
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,31 @@ def get_result(
8787
return self._result
8888

8989
def _get_result_from_s3(self, output_path, failure_path):
    """Retrieve the inference result, dispatching on the presence of ``failure_path``.

    Args:
        output_path (str): Amazon S3 location expected to hold the inference output.
        failure_path (str): Amazon S3 location of the failure record, or ``None``
            when no failure location is available for this request.

    Returns:
        Whatever the selected helper returns for the given path(s).
    """
    if failure_path is None:
        # Without a failure location only the output object can be consulted.
        return self._get_result_from_s3_output_path(output_path)

    return self._get_result_from_s3_output_failure_paths(output_path, failure_path)
95+
96+
def _get_result_from_s3_output_path(self, output_path):
    """Get inference result from the output Amazon S3 path.

    Args:
        output_path (str): Amazon S3 URL of the expected inference output object.

    Returns:
        The value produced by ``self.predictor_async.predictor._handle_response``
        for the fetched S3 object.

    Raises:
        ObjectNotExistedError: If S3 reports ``NoSuchKey`` for the output object.
        UnexpectedClientError: For any other ``ClientError`` raised while fetching.
    """
    bucket, key = parse_s3_url(output_path)
    try:
        response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
        return self.predictor_async.predictor._handle_response(response)
    except ClientError as ex:
        error = ex.response["Error"]
        if error["Code"] == "NoSuchKey":
            # Chain the original error so the S3 response stays visible in tracebacks.
            raise ObjectNotExistedError(
                message="Inference could still be running",
                output_path=output_path,
            ) from ex
        raise UnexpectedClientError(message=error["Message"]) from ex
111+
112+
def _get_result_from_s3_output_failure_paths(self, output_path, failure_path):
113+
"""Get inference result from the output & failure Amazon S3 path"""
114+
bucket, key = parse_s3_url(output_path)
92115
try:
93116
response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
94117
return self.predictor_async.predictor._handle_response(response)

src/sagemaker/predictor_async.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def predict(
9999
self._input_path = input_path
100100
response = self._submit_async_request(input_path, initial_args, inference_id)
101101
output_location = response["OutputLocation"]
102-
failure_location = response["FailureLocation"]
102+
failure_location = response.get("FailureLocation")
103103
result = self._wait_for_output(
104104
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
105105
)
@@ -145,7 +145,7 @@ def predict_async(
145145
self._input_path = input_path
146146
response = self._submit_async_request(input_path, initial_args, inference_id)
147147
output_location = response["OutputLocation"]
148-
failure_location = response["FailureLocation"]
148+
failure_location = response.get("FailureLocation")
149149
response_async = AsyncInferenceResponse(
150150
predictor_async=self,
151151
output_path=output_location,
@@ -216,6 +216,33 @@ def _submit_async_request(
216216
return response
217217

218218
def _wait_for_output(self, output_path, failure_path, waiter_config):
    """Wait for the inference output, dispatching on the presence of ``failure_path``.

    Args:
        output_path (str): Amazon S3 location expected to hold the inference output.
        failure_path (str): Amazon S3 location of the failure record, or ``None``
            when the invocation response carried no failure location.
        waiter_config: Polling configuration forwarded unchanged to the selected helper.

    Returns:
        Whatever the selected helper returns.
    """
    if failure_path is None:
        # No failure location to watch: poll the output path only.
        return self._check_output_path(output_path, waiter_config)

    return self._check_output_and_failure_paths(output_path, failure_path, waiter_config)
224+
225+
def _check_output_path(self, output_path, waiter_config):
    """Check the Amazon S3 output path for the output.

    Periodically check the Amazon S3 output path for the async inference result,
    using the S3 ``object_exists`` waiter. Polling stops once the waiter gives up.

    Args:
        output_path (str): Amazon S3 URL of the expected inference output object.
        waiter_config: Polling configuration; its ``_to_request_dict()``, ``delay``
            and ``max_attempts`` members are used here.

    Returns:
        The value produced by ``self.predictor._handle_response`` for the S3 object.

    Raises:
        PollingTimeoutError: If the waiter raises ``WaiterError``.
    """
    bucket, key = parse_s3_url(output_path)
    s3_waiter = self.s3_client.get_waiter("object_exists")
    try:
        s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict())
    except WaiterError as err:
        # Chain the waiter error so its own failure reason is not lost
        # behind the generic timeout message.
        raise PollingTimeoutError(
            message="Inference could still be running",
            output_path=output_path,
            seconds=waiter_config.delay * waiter_config.max_attempts,
        ) from err

    s3_object = self.s3_client.get_object(Bucket=bucket, Key=key)
    return self.predictor._handle_response(response=s3_object)
244+
245+
def _check_output_and_failure_paths(self, output_path, failure_path, waiter_config):
219246
"""Check the Amazon S3 output path for the output.
220247
221248
This method waits for either the output file or the failure file to be found on the

tests/unit/sagemaker/async_inference/test_async_inference_response.py

+70
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ def empty_s3_client():
7474
return s3_client
7575

7676

77+
def empty_s3_client_to_verify_exceptions_for_null_failure_path():
    """
    Build a mocked S3 client whose `get_object` raises on successive calls.

    Exceptions raised, in call order:
        1. `ObjectNotExistedError`
        2. `UnexpectedClientError`

    """
    # Mock consumes a side_effect list one item per call, raising exception instances.
    failures = [
        ObjectNotExistedError("Inference could still be running", DEFAULT_OUTPUT_PATH),
        UnexpectedClientError("some error message"),
    ]

    client = Mock(name="s3-client")
    client.get_object = Mock(name="get_object", side_effect=failures)
    return client
101+
102+
77103
def mock_s3_client():
78104
"""
79105
This function returns a mocked S3 client object that has a get_object method with a side_effect
@@ -172,3 +198,47 @@ def test_get_result_verify_exceptions():
172198
UnexpectedClientError, match="Encountered unexpected client error: some error message"
173199
):
174200
async_inference_response.get_result()
201+
202+
203+
def test_get_result_with_null_failure_path():
    """
    Verifies that the result is returned and cached when failure_path is None
    and no errors occur.
    """
    async_predictor = AsyncPredictor(Predictor(ENDPOINT_NAME))
    async_predictor.s3_client = mock_s3_client()

    response = AsyncInferenceResponse(
        predictor_async=async_predictor,
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=None,
    )

    fetched = response.get_result()

    # The fetched value must also be stored on the response object.
    assert response._result == fetched
    assert fetched == RETURN_VALUE
217+
218+
219+
def test_get_result_verify_exceptions_with_null_failure_path():
    """
    Verifies that get_result surfaces the expected exceptions when fetching
    the result fails and no failure path is set.
    """
    async_predictor = AsyncPredictor(Predictor(ENDPOINT_NAME))
    async_predictor.s3_client = empty_s3_client_to_verify_exceptions_for_null_failure_path()

    response = AsyncInferenceResponse(
        predictor_async=async_predictor,
        output_path=DEFAULT_OUTPUT_PATH,
        failure_path=None,
    )

    # First call: the mocked client raises ObjectNotExistedError.
    missing_object_pattern = (
        f"Object not exist at {DEFAULT_OUTPUT_PATH}. Inference could still be running"
    )
    with pytest.raises(ObjectNotExistedError, match=missing_object_pattern):
        response.get_result()

    # Second call: the mocked client raises UnexpectedClientError.
    unexpected_pattern = "Encountered unexpected client error: some error message"
    with pytest.raises(UnexpectedClientError, match=unexpected_pattern):
        response.get_result()

tests/unit/test_predictor_async.py

+113-1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,37 @@ def empty_sagemaker_session():
7676
return ims
7777

7878

79+
def empty_sagemaker_session_with_null_failure_path():
    """Build a mocked session whose async invocation response has no "FailureLocation".

    The mocked S3 client's ``get_object`` raises ``PollingTimeoutError`` on its
    first call; ``put_object`` is a plain mock.
    """
    session = Mock(name="sagemaker_session")
    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)

    session.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
    session.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)

    # Only "OutputLocation" is returned, so callers see a missing failure location.
    session.sagemaker_runtime_client = Mock(name="sagemaker_runtime")
    session.sagemaker_runtime_client.invoke_endpoint_async = Mock(
        name="invoke_endpoint_async",
        return_value={"OutputLocation": ASYNC_OUTPUT_LOCATION},
    )

    timeout_error = PollingTimeoutError(
        message="Inference could still be running",
        output_path=ASYNC_OUTPUT_LOCATION,
        seconds=DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts,
    )

    session.s3_client = Mock(name="s3_client")
    session.s3_client.get_object = Mock(name="get_object", side_effect=[timeout_error])
    session.s3_client.put_object = Mock(name="put_object")

    return session
108+
109+
79110
def empty_predictor():
80111
predictor = Mock(name="predictor")
81112
predictor.update_endpoint = Mock(name="update_endpoint")
@@ -161,6 +192,31 @@ def test_async_predict_call_with_data_and_input_path():
161192
assert result.failure_path == ASYNC_FAILURE_LOCATION
162193

163194

195+
def test_async_predict_call_with_data_and_input_and_null_failure_path():
    """predict_async returns a response with failure_path None when none is reported."""
    session = empty_sagemaker_session_with_null_failure_path()
    async_predictor = AsyncPredictor(Predictor(ENDPOINT, session))
    async_predictor.name = ASYNC_PREDICTOR

    response = async_predictor.predict_async(data=DUMMY_DATA, input_path=ASYNC_INPUT_LOCATION)

    assert session.s3_client.put_object.called
    assert session.sagemaker_runtime_client.invoke_endpoint_async.called
    # NOTE(review): `.not_called` is not a Mock assertion API; it is an auto-created
    # child Mock and therefore always truthy, so these two asserts cannot fail.
    # Confirm whether describe_* is really never called before switching to
    # `assert_not_called()`.
    assert session.sagemaker_client.describe_endpoint.not_called
    assert session.sagemaker_client.describe_endpoint_config.not_called

    _, kwargs = session.sagemaker_runtime_client.invoke_endpoint_async.call_args
    assert kwargs == {
        "Accept": DEFAULT_ACCEPT,
        "InputLocation": ASYNC_INPUT_LOCATION,
        "EndpointName": ENDPOINT,
    }
    assert response.output_path == ASYNC_OUTPUT_LOCATION
    assert response.failure_path is None
218+
219+
164220
def test_async_predict_call_verify_exceptions():
165221
sagemaker_session = empty_sagemaker_session()
166222
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))
@@ -185,7 +241,27 @@ def test_async_predict_call_verify_exceptions():
185241
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called
186242

187243

188-
def test_async_predict_call_pass_through_success():
244+
def test_async_predict_call_verify_exceptions_with_null_failure_path():
    """predict raises PollingTimeoutError when no failure location is reported."""
    session = empty_sagemaker_session_with_null_failure_path()
    async_predictor = AsyncPredictor(Predictor(ENDPOINT, session))

    polled_seconds = DEFAULT_WAITER_CONFIG.delay * DEFAULT_WAITER_CONFIG.max_attempts
    expected_message = (
        f"No result at {ASYNC_OUTPUT_LOCATION} after polling for "
        f"{polled_seconds}"
        f" seconds. Inference could still be running"
    )

    # NOTE(review): the fixture makes `get_object` (not the waiter) raise the timeout,
    # so this presumably does not exercise the WaiterError translation. Verify intent.
    with pytest.raises(PollingTimeoutError, match=expected_message):
        async_predictor.predict(
            input_path="s3://some-input-path", waiter_config=DEFAULT_WAITER_CONFIG
        )

    assert session.sagemaker_runtime_client.invoke_endpoint_async.called
    assert session.s3_client.get_object.called
    # NOTE(review): `.not_called` is an auto-created Mock attribute (always truthy);
    # these two asserts cannot fail. Confirm intent before using `assert_not_called()`.
    assert session.sagemaker_client.describe_endpoint.not_called
    assert session.sagemaker_client.describe_endpoint_config.not_called
262+
263+
264+
def test_async_predict_call_pass_through_output_failure_paths():
189265
sagemaker_session = empty_sagemaker_session()
190266

191267
response_body = Mock("body")
@@ -222,6 +298,42 @@ def test_async_predict_call_pass_through_success():
222298
assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called
223299

224300

301+
def test_async_predict_call_pass_through_with_null_failure_path():
    """predict returns the S3 payload when the response carries no failure location."""
    sagemaker_session = empty_sagemaker_session_with_null_failure_path()

    # NOTE(review): the first positional argument of Mock is `spec`, not `name`;
    # kept as-is to match the sibling tests. Confirm whether `name=` was intended.
    response_body = Mock("body")
    response_body.read = Mock("read", return_value=RETURN_VALUE)
    response_body.close = Mock("close", return_value=None)

    # Replace the fixture's failing S3 client with one that returns the payload.
    sagemaker_session.s3_client = Mock(name="s3_client")
    sagemaker_session.s3_client.get_object = Mock(
        name="get_object",
        return_value={"Body": response_body},
    )
    sagemaker_session.s3_client.put_object = Mock(name="put_object")

    predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

    sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async = Mock(
        name="invoke_endpoint_async",
        return_value={
            "OutputLocation": ASYNC_OUTPUT_LOCATION,
        },
    )

    input_location = "s3://some-input-path"

    result = predictor_async.predict(
        input_path=input_location,
    )

    assert result == RETURN_VALUE
    assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint_async.called
    # `called_with` is an auto-created Mock attribute and always truthy;
    # `assert_called_with` actually checks the waiter name that was requested.
    sagemaker_session.s3_client.get_waiter.assert_called_with("object_exists")
    # NOTE(review): `.not_called` is likewise always truthy, so these two asserts
    # cannot fail. Confirm whether describe_* is really never called before
    # switching to `assert_not_called()`.
    assert sagemaker_session.sagemaker_client.describe_endpoint.not_called
    assert sagemaker_session.sagemaker_client.describe_endpoint_config.not_called
335+
336+
225337
def test_predict_async_call_invalid_input():
226338
sagemaker_session = empty_sagemaker_session()
227339
predictor_async = AsyncPredictor(Predictor(ENDPOINT, sagemaker_session))

0 commit comments

Comments
 (0)