Skip to content

Commit 5832820

Browse files
feature: add support for async inline error notifications (#3750)
* feature: add support for async inline error notifications * fix:pylint errors * fix: integ tests
1 parent e4d0874 commit 5832820

13 files changed

+357
-106
lines changed

doc/overview.rst

+7-3
Original file line numberDiff line numberDiff line change
@@ -1164,7 +1164,8 @@ More information about SageMaker Asynchronous Inference can be found in the `AWS
11641164

11651165
To deploy an asynchronous inference endpoint, you will need to create an ``AsyncInferenceConfig`` object.
11661166
If you create ``AsyncInferenceConfig`` without specifying its arguments, the default ``S3OutputPath`` will
1167-
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}``. (example shown below):
1167+
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}`` and the default ``S3FailurePath`` will
1168+
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-failures/{UNIQUE-JOB-NAME}`` (example shown below):
11681169

11691170
.. code:: python
11701171
@@ -1174,18 +1175,21 @@ be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME
11741175
async_config = AsyncInferenceConfig()
11751176
11761177
Or you can specify configurations in ``AsyncInferenceConfig`` as you like. All of those configuration parameters
1177-
are optional but if you don’t specify the ``output_path``, Amazon SageMaker will use the default ``S3OutputPath``
1178+
are optional but if you don’t specify the ``output_path`` or ``failure_path``, Amazon SageMaker will use the
1179+
default ``S3OutputPath`` or ``S3FailurePath``
11781180
mentioned above (example shown below):
11791181

11801182
.. code:: python
11811183
1182-
# Specify S3OutputPath, MaxConcurrentInvocationsPerInstance and NotificationConfig in the async config object
1184+
# Specify S3OutputPath, S3FailurePath, MaxConcurrentInvocationsPerInstance and NotificationConfig
1185+
# in the async config object
11831186
async_config = AsyncInferenceConfig(
11841187
output_path="s3://{s3_bucket}/{bucket_prefix}/output",
failure_path="s3://{s3_bucket}/{bucket_prefix}/failure",
11851188
max_concurrent_invocations_per_instance=10,
11861189
notification_config = {
11871190
"SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
11881191
"ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
1192+
"IncludeInferenceResponseIn": ["SUCCESS_NOTIFICATION_TOPIC","ERROR_NOTIFICATION_TOPIC"],
11891193
}
11901194
)
11911195

src/sagemaker/async_inference/async_inference_config.py

+11
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ def __init__(
3131
max_concurrent_invocations_per_instance=None,
3232
kms_key_id=None,
3333
notification_config=None,
34+
failure_path=None,
3435
):
3536
"""Initialize an AsyncInferenceConfig object for async inference configuration.
3637
@@ -45,6 +46,9 @@ def __init__(
4546
kms_key_id (str): Optional. The Amazon Web Services Key Management Service
4647
(Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
4748
asynchronous inference output in Amazon S3. (Default: None)
49+
failure_path (str): Optional. The Amazon S3 location that endpoints upload model
50+
responses for failed requests. If no value is provided, Amazon SageMaker will
51+
use default Amazon S3 Async Inference failure path. (Default: None)
4852
notification_config (dict): Optional. Specifies the configuration for notifications
4953
of inference results for asynchronous inference. Only one notification is generated
5054
per invocation request (Default: None):
@@ -54,17 +58,24 @@ def __init__(
5458
* error_topic (str): Amazon SNS topic to post a notification to when inference
5559
fails. If no topic is provided, no notification is sent on failure.
5660
The key in notification_config is 'ErrorTopic'.
61+
* include_inference_response_in (list): Optional. When provided the inference
62+
response will be included in the notification topics. If not provided,
63+
a notification will still be generated on success/error, but will not
64+
contain the inference response.
65+
Valid options are SUCCESS_NOTIFICATION_TOPIC, ERROR_NOTIFICATION_TOPIC
5766
"""
5867
self.output_path = output_path
5968
self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
6069
self.kms_key_id = kms_key_id
6170
self.notification_config = notification_config
71+
self.failure_path = failure_path
6272

6373
def _to_request_dict(self):
6474
"""Generates a request dictionary using the parameters provided to the class."""
6575
request_dict = {
6676
"OutputConfig": {
6777
"S3OutputPath": self.output_path,
78+
"S3FailurePath": self.failure_path,
6879
},
6980
}
7081

src/sagemaker/async_inference/async_inference_response.py

+30-16
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from botocore.exceptions import ClientError
1818
from sagemaker.s3 import parse_s3_url
1919
from sagemaker.async_inference import WaiterConfig
20-
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError
20+
from sagemaker.exceptions import (
21+
ObjectNotExistedError,
22+
UnexpectedClientError,
23+
AsyncInferenceModelError,
24+
)
2125

2226

2327
class AsyncInferenceResponse(object):
@@ -32,6 +36,7 @@ def __init__(
3236
self,
3337
predictor_async,
3438
output_path,
39+
failure_path,
3540
):
3641
"""Initialize an AsyncInferenceResponse object.
3742
@@ -43,10 +48,13 @@ def __init__(
4348
that return this response.
4449
output_path (str): The Amazon S3 location that endpoints upload inference responses
4550
to.
51+
failure_path (str): The Amazon S3 location that endpoints upload model errors
52+
for failed requests.
4653
"""
4754
self.predictor_async = predictor_async
4855
self.output_path = output_path
4956
self._result = None
57+
self.failure_path = failure_path
5058

5159
def get_result(
5260
self,
@@ -71,28 +79,34 @@ def get_result(
7179

7280
if self._result is None:
7381
if waiter_config is None:
74-
self._result = self._get_result_from_s3(self.output_path)
82+
self._result = self._get_result_from_s3(self.output_path, self.failure_path)
7583
else:
7684
self._result = self.predictor_async._wait_for_output(
77-
self.output_path, waiter_config
85+
self.output_path, self.failure_path, waiter_config
7886
)
7987
return self._result
8088

81-
def _get_result_from_s3(self, output_path, failure_path):
    """Fetch the inference result from S3, checking the failure path if it is absent.

    Args:
        output_path (str): S3 URI of the expected inference output object.
        failure_path (str): S3 URI of the object written when the request fails.

    Returns:
        The output object as processed by the predictor's ``_handle_response``.

    Raises:
        AsyncInferenceModelError: The output object is missing but a failure
            object exists; its processed content is the error message.
        ObjectNotExistedError: Neither the output nor the failure object exists.
        UnexpectedClientError: S3 returned any client error other than ``NoSuchKey``.
    """
    bucket, key = parse_s3_url(output_path)
    try:
        s3_object = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
        return self.predictor_async.predictor._handle_response(s3_object)
    except ClientError as output_error:
        # Anything other than a missing output object is reported as-is.
        if output_error.response["Error"]["Code"] != "NoSuchKey":
            raise UnexpectedClientError(message=output_error.response["Error"]["Message"])

        # No output yet: see whether the endpoint recorded a failure instead.
        try:
            failure_bucket, failure_key = parse_s3_url(failure_path)
            failure_object = self.predictor_async.s3_client.get_object(
                Bucket=failure_bucket, Key=failure_key
            )
            model_error = self.predictor_async.predictor._handle_response(failure_object)
            raise AsyncInferenceModelError(message=model_error)
        except ClientError as failure_error:
            if failure_error.response["Error"]["Code"] != "NoSuchKey":
                raise UnexpectedClientError(message=failure_error.response["Error"]["Message"])
            # Neither object exists, so the request has not completed.
            raise ObjectNotExistedError(
                message="Inference could still be running", output_path=output_path
            )

src/sagemaker/exceptions.py

+9
Original file line numberDiff line numberDiff line change
@@ -77,3 +77,12 @@ def __init__(self, **kwargs):
7777
msg = self.fmt.format(**kwargs)
7878
Exception.__init__(self, msg)
7979
self.kwargs = kwargs
80+
81+
82+
class AsyncInferenceModelError(AsyncInferenceError):
    """Raised when the model returns an error for a failed async inference request."""

    # Message template; ``message`` is supplied by ``__init__`` below.
    fmt = "Model returned error: {message} "

    def __init__(self, message):
        """Initialize the error with the model's failure response.

        Args:
            message: The error content returned by the model.
        """
        super(AsyncInferenceModelError, self).__init__(message=message)

src/sagemaker/model.py

+17-6
Original file line numberDiff line numberDiff line change
@@ -1277,7 +1277,10 @@ def deploy(
12771277

12781278
async_inference_config_dict = None
12791279
if is_async:
1280-
if async_inference_config.output_path is None:
1280+
if (
1281+
async_inference_config.output_path is None
1282+
or async_inference_config.failure_path is None
1283+
):
12811284
async_inference_config = self._build_default_async_inference_config(
12821285
async_inference_config
12831286
)
@@ -1316,11 +1319,19 @@ def deploy(
13161319

13171320
def _build_default_async_inference_config(self, async_inference_config):
    """Fill in any unset S3 paths on the config and return it.

    Both defaults live in the session's default bucket and share one unique
    folder name derived from ``self.name``. A path the caller already set is
    left untouched.

    Args:
        async_inference_config: The config whose ``output_path`` and
            ``failure_path`` are completed in place.

    Returns:
        The same ``async_inference_config`` object.
    """
    folder = unique_name_from_base(self.name)
    path_templates = (
        ("output_path", "s3://{}/async-endpoint-outputs/{}"),
        ("failure_path", "s3://{}/async-endpoint-failures/{}"),
    )
    for attribute, template in path_templates:
        if getattr(async_inference_config, attribute) is not None:
            continue
        default_uri = template.format(self.sagemaker_session.default_bucket(), folder)
        setattr(async_inference_config, attribute, default_uri)

    return async_inference_config
13251336

13261337
def transformer(

src/sagemaker/predictor_async.py

+81-24
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
# language governing permissions and limitations under the License.
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
15-
15+
import threading
16+
import time
1617
import uuid
1718
from botocore.exceptions import WaiterError
18-
from sagemaker.exceptions import PollingTimeoutError
19+
from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError
1920
from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse
2021
from sagemaker.s3 import parse_s3_url
2122
from sagemaker.session import Session
@@ -98,7 +99,10 @@ def predict(
9899
self._input_path = input_path
99100
response = self._submit_async_request(input_path, initial_args, inference_id)
100101
output_location = response["OutputLocation"]
101-
result = self._wait_for_output(output_path=output_location, waiter_config=waiter_config)
102+
failure_location = response["FailureLocation"]
103+
result = self._wait_for_output(
104+
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
105+
)
102106

103107
return result
104108

@@ -141,9 +145,11 @@ def predict_async(
141145
self._input_path = input_path
142146
response = self._submit_async_request(input_path, initial_args, inference_id)
143147
output_location = response["OutputLocation"]
148+
failure_location = response["FailureLocation"]
144149
response_async = AsyncInferenceResponse(
145150
predictor_async=self,
146151
output_path=output_location,
152+
failure_path=failure_location,
147153
)
148154

149155
return response_async
@@ -209,30 +215,81 @@ def _submit_async_request(
209215

210216
return response
211217

212-
def _wait_for_output(self, output_path, failure_path, waiter_config):
    """Wait for the async inference output, or for a recorded model failure.

    Two S3 ``object_exists`` waiters run concurrently, one polling the output
    object and one polling the failure object. The method stops waiting as
    soon as either object is found, or once both waiters have given up.

    Args:
        output_path (str): The S3 URI where the output file is expected to be found.
        failure_path (str): The S3 URI where the failure file is expected to be found.
        waiter_config (sagemaker.async_inference.WaiterConfig): Polling
            configuration passed to both waiters; its ``delay`` and
            ``max_attempts`` are also used to report the timeout duration.

    Returns:
        The output file's content as processed by ``self.predictor._handle_response``.
        If both files are found, the output file takes precedence.

    Raises:
        AsyncInferenceModelError: If the failure file is found and the output file is not.
        PollingTimeoutError: If neither file is found before both waiters stop.
    """
    output_bucket, output_key = parse_s3_url(output_path)
    failure_bucket, failure_key = parse_s3_url(failure_path)

    output_file_found = threading.Event()
    failure_file_found = threading.Event()

    def wait_for_object(bucket, key, found_event):
        """Set ``found_event`` once the S3 object exists; leave it unset on timeout."""
        try:
            self.s3_client.get_waiter("object_exists").wait(
                Bucket=bucket,
                Key=key,
                WaiterConfig=waiter_config._to_request_dict(),
            )
        except WaiterError:
            # The waiter ran out of attempts; the timeout is reported by the caller
            # once both waiters have stopped.
            return
        found_event.set()

    # Daemon threads: the waiter that "loses" must not keep the interpreter alive
    # (and polling S3) after the result has already been returned.
    waiter_threads = [
        threading.Thread(
            target=wait_for_object,
            args=(output_bucket, output_key, output_file_found),
            daemon=True,
        ),
        threading.Thread(
            target=wait_for_object,
            args=(failure_bucket, failure_key, failure_file_found),
            daemon=True,
        ),
    ]
    for waiter_thread in waiter_threads:
        waiter_thread.start()

    while not (output_file_found.is_set() or failure_file_found.is_set()):
        # Without this check the loop never ends when both waiters time out,
        # because neither event is ever set.
        if not any(waiter_thread.is_alive() for waiter_thread in waiter_threads):
            break
        time.sleep(1)

    # The events are re-read here because a waiter may have set one between the
    # loop's last event check and its liveness check.
    if output_file_found.is_set():
        s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
        return self.predictor._handle_response(response=s3_object)

    if failure_file_found.is_set():
        failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
        failure_response = self.predictor._handle_response(response=failure_object)
        raise AsyncInferenceModelError(message=failure_response)

    raise PollingTimeoutError(
        message="Inference could still be running",
        output_path=output_path,
        seconds=waiter_config.delay * waiter_config.max_attempts,
    )
236293

237294
def update_endpoint(
238295
self,

tests/integ/test_async_inference.py

+6
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,9 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set):
6363
assert result_no_wait_with_data.output_path.startswith(
6464
"s3://" + sagemaker_session.default_bucket()
6565
)
66+
assert result_no_wait_with_data.failure_path.startswith(
67+
"s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/"
68+
)
6669
time.sleep(5)
6770
result_no_wait_with_data = result_no_wait_with_data.get_result()
6871
assert len(result_no_wait_with_data) == 5
@@ -97,6 +100,9 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set):
97100
result_not_wait = predictor_async.predict_async(input_path=input_s3_path)
98101
assert isinstance(result_not_wait, AsyncInferenceResponse)
99102
assert result_not_wait.output_path.startswith("s3://" + sagemaker_session.default_bucket())
103+
assert result_not_wait.failure_path.startswith(
104+
"s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/"
105+
)
100106
time.sleep(5)
101107
result_not_wait = result_not_wait.get_result()
102108
assert len(result_not_wait) == 5

0 commit comments

Comments
 (0)