Skip to content

feature: add support for async inline error notifications #3750

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions doc/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1164,7 +1164,8 @@ More information about SageMaker Asynchronous Inference can be found in the `AWS

To deploy an asynchronous inference endpoint, you will need to create an ``AsyncInferenceConfig`` object.
If you create ``AsyncInferenceConfig`` without specifying its arguments, the default ``S3OutputPath`` will
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}``. (example shown below):
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}``, ``S3FailurePath`` will
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-failures/{UNIQUE-JOB-NAME}`` (example shown below):

.. code:: python

Expand All @@ -1174,18 +1175,21 @@ be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME
async_config = AsyncInferenceConfig()

Or you can specify configurations in ``AsyncInferenceConfig`` as you like. All of those configuration parameters
are optional but if you don’t specify the ``output_path``, Amazon SageMaker will use the default ``S3OutputPath``
are optional but if you don’t specify the ``output_path`` or ``failure_path``, Amazon SageMaker will use the
default ``S3OutputPath`` or ``S3FailurePath``
mentioned above (example shown below):

.. code:: python

# Specify S3OutputPath, MaxConcurrentInvocationsPerInstance and NotificationConfig in the async config object
# Specify S3OutputPath, S3FailurePath, MaxConcurrentInvocationsPerInstance and NotificationConfig
# in the async config object
async_config = AsyncInferenceConfig(
output_path="s3://{s3_bucket}/{bucket_prefix}/output",
max_concurrent_invocations_per_instance=10,
notification_config = {
"SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
"ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
"IncludeInferenceResponseIn": ["SUCCESS_NOTIFICATION_TOPIC","ERROR_NOTIFICATION_TOPIC"],
}
)

Expand Down
11 changes: 11 additions & 0 deletions src/sagemaker/async_inference/async_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
max_concurrent_invocations_per_instance=None,
kms_key_id=None,
notification_config=None,
failure_path=None,
):
"""Initialize an AsyncInferenceConfig object for async inference configuration.

Expand All @@ -45,6 +46,9 @@ def __init__(
kms_key_id (str): Optional. The Amazon Web Services Key Management Service
(Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
asynchronous inference output in Amazon S3. (Default: None)
failure_path (str): Optional. The Amazon S3 location that endpoints upload model
responses for failed requests. If no value is provided, Amazon SageMaker will
use default Amazon S3 Async Inference failure path. (Default: None)
notification_config (dict): Optional. Specifies the configuration for notifications
of inference results for asynchronous inference. Only one notification is generated
per invocation request (Default: None):
Expand All @@ -54,17 +58,24 @@ def __init__(
* error_topic (str): Amazon SNS topic to post a notification to when inference
fails. If no topic is provided, no notification is sent on failure.
The key in notification_config is 'ErrorTopic'.
* include_inference_response_in (list): Optional. When provided the inference
response will be included in the notification topics. If not provided,
a notification will still be generated on success/error, but will not
contain the inference response.
Valid options are SUCCESS_NOTIFICATION_TOPIC, ERROR_NOTIFICATION_TOPIC
"""
self.output_path = output_path
self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
self.kms_key_id = kms_key_id
self.notification_config = notification_config
self.failure_path = failure_path

def _to_request_dict(self):
"""Generates a request dictionary using the parameters provided to the class."""
request_dict = {
"OutputConfig": {
"S3OutputPath": self.output_path,
"S3FailurePath": self.failure_path,
},
}

Expand Down
46 changes: 30 additions & 16 deletions src/sagemaker/async_inference/async_inference_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from botocore.exceptions import ClientError
from sagemaker.s3 import parse_s3_url
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError
from sagemaker.exceptions import (
ObjectNotExistedError,
UnexpectedClientError,
AsyncInferenceModelError,
)


class AsyncInferenceResponse(object):
Expand All @@ -32,6 +36,7 @@ def __init__(
self,
predictor_async,
output_path,
failure_path,
):
"""Initialize an AsyncInferenceResponse object.

Expand All @@ -43,10 +48,13 @@ def __init__(
that return this response.
output_path (str): The Amazon S3 location that endpoints upload inference responses
to.
failure_path (str): The Amazon S3 location that endpoints upload model errors
for failed requests.
"""
self.predictor_async = predictor_async
self.output_path = output_path
self._result = None
self.failure_path = failure_path

def get_result(
self,
Expand All @@ -71,28 +79,34 @@ def get_result(

if self._result is None:
if waiter_config is None:
self._result = self._get_result_from_s3(self.output_path)
self._result = self._get_result_from_s3(self.output_path, self.failure_path)
else:
self._result = self.predictor_async._wait_for_output(
self.output_path, waiter_config
self.output_path, self.failure_path, waiter_config
)
return self._result

def _get_result_from_s3(self, output_path, failure_path):
    """Retrieve the async inference result from the Amazon S3 output path.

    If the output object does not exist yet, the failure path is checked to
    distinguish a model error (failure object present) from an inference that
    is still running (neither object present).

    Args:
        output_path (str): S3 URI where the endpoint uploads successful
            inference responses.
        failure_path (str): S3 URI where the endpoint uploads model errors
            for failed requests.

    Returns:
        The deserialized inference result read from the output object.

    Raises:
        AsyncInferenceModelError: If a failure object exists at
            ``failure_path`` (the model returned an error for this request).
        ObjectNotExistedError: If neither the output nor the failure object
            exists yet (inference could still be running).
        UnexpectedClientError: For any other S3 client error.
    """
    bucket, key = parse_s3_url(output_path)
    try:
        response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
        return self.predictor_async.predictor._handle_response(response)
    except ClientError as e:
        if e.response["Error"]["Code"] != "NoSuchKey":
            raise UnexpectedClientError(message=e.response["Error"]["Message"]) from e
        # Output object is missing: look for a failure object to decide
        # whether the model failed or the inference is simply not done yet.
        try:
            failure_bucket, failure_key = parse_s3_url(failure_path)
            failure_object = self.predictor_async.s3_client.get_object(
                Bucket=failure_bucket, Key=failure_key
            )
            failure_response = self.predictor_async.predictor._handle_response(failure_object)
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running", output_path=output_path
                ) from ex
            raise UnexpectedClientError(message=ex.response["Error"]["Message"]) from ex
        raise AsyncInferenceModelError(message=failure_response)
9 changes: 9 additions & 0 deletions src/sagemaker/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,12 @@ def __init__(self, **kwargs):
msg = self.fmt.format(**kwargs)
Exception.__init__(self, msg)
self.kwargs = kwargs


class AsyncInferenceModelError(AsyncInferenceError):
    """Error raised when the model itself fails an async inference request."""

    # Format template consumed by the AsyncInferenceError base, which
    # formats ``fmt`` with the keyword arguments it receives.
    fmt = "Model returned error: {message} "

    def __init__(self, message):
        """Build the exception from the model's error payload."""
        super().__init__(message=message)
23 changes: 17 additions & 6 deletions src/sagemaker/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1277,7 +1277,10 @@ def deploy(

async_inference_config_dict = None
if is_async:
if async_inference_config.output_path is None:
if (
async_inference_config.output_path is None
or async_inference_config.failure_path is None
):
async_inference_config = self._build_default_async_inference_config(
async_inference_config
)
Expand Down Expand Up @@ -1316,11 +1319,19 @@ def deploy(

def _build_default_async_inference_config(self, async_inference_config):
    """Fill in default S3 paths on an ``AsyncInferenceConfig`` and return it.

    Only the paths the caller left as ``None`` are populated. Both defaults
    share a single unique folder name derived from the model name, so a
    request's output and failure objects land in sibling prefixes.

    Args:
        async_inference_config (AsyncInferenceConfig): The config to complete.

    Returns:
        AsyncInferenceConfig: The same object with any missing ``output_path``
        / ``failure_path`` set to defaults in the session's default bucket.
    """
    unique_folder = unique_name_from_base(self.name)
    # Hoisted: avoid calling default_bucket() once per missing path.
    default_bucket = self.sagemaker_session.default_bucket()

    if async_inference_config.output_path is None:
        async_inference_config.output_path = "s3://{}/async-endpoint-outputs/{}".format(
            default_bucket, unique_folder
        )

    if async_inference_config.failure_path is None:
        async_inference_config.failure_path = "s3://{}/async-endpoint-failures/{}".format(
            default_bucket, unique_folder
        )

    return async_inference_config

def transformer(
Expand Down
105 changes: 81 additions & 24 deletions src/sagemaker/predictor_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# language governing permissions and limitations under the License.
"""Placeholder docstring"""
from __future__ import absolute_import

import threading
import time
import uuid
from botocore.exceptions import WaiterError
from sagemaker.exceptions import PollingTimeoutError
from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError
from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse
from sagemaker.s3 import parse_s3_url
from sagemaker.session import Session
Expand Down Expand Up @@ -98,7 +99,10 @@ def predict(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
result = self._wait_for_output(output_path=output_location, waiter_config=waiter_config)
failure_location = response["FailureLocation"]
result = self._wait_for_output(
output_path=output_location, failure_path=failure_location, waiter_config=waiter_config
)

return result

Expand Down Expand Up @@ -141,9 +145,11 @@ def predict_async(
self._input_path = input_path
response = self._submit_async_request(input_path, initial_args, inference_id)
output_location = response["OutputLocation"]
failure_location = response["FailureLocation"]
response_async = AsyncInferenceResponse(
predictor_async=self,
output_path=output_location,
failure_path=failure_location,
)

return response_async
Expand Down Expand Up @@ -209,30 +215,81 @@ def _submit_async_request(

return response

def _wait_for_output(self, output_path, failure_path, waiter_config):
    """Poll Amazon S3 for either the output object or the failure object.

    Two waiter threads watch ``output_path`` and ``failure_path``
    concurrently; whichever object appears first decides the outcome.

    Args:
        output_path (str): The S3 path where the output file is expected.
        failure_path (str): The S3 path where the failure file is expected.
        waiter_config (WaiterConfig): Delay / max-attempts configuration for
            each S3 ``object_exists`` waiter.

    Returns:
        Any: The deserialized result from the output file, if the output
        file is found first.

    Raises:
        AsyncInferenceModelError: If the failure file is found before the
            output file.
        PollingTimeoutError: If neither file is found before both waiters
            exhaust their attempts.
    """
    output_bucket, output_key = parse_s3_url(output_path)
    failure_bucket, failure_key = parse_s3_url(failure_path)

    output_file_found = threading.Event()
    failure_file_found = threading.Event()

    def _wait_for_object(bucket, key, found_event):
        # Each waiter polls until the object exists or its attempts run out;
        # a WaiterError simply means this object never appeared.
        try:
            waiter = self.s3_client.get_waiter("object_exists")
            waiter.wait(
                Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict()
            )
            found_event.set()
        except WaiterError:
            pass

    output_thread = threading.Thread(
        target=_wait_for_object, args=(output_bucket, output_key, output_file_found)
    )
    failure_thread = threading.Thread(
        target=_wait_for_object, args=(failure_bucket, failure_key, failure_file_found)
    )
    output_thread.start()
    failure_thread.start()

    # Return as soon as EITHER file shows up, but also exit once both waiter
    # threads have finished. Bug fix: the original bare
    # ``while ... time.sleep(1)`` looped forever when both waiters timed out,
    # making the PollingTimeoutError below unreachable.
    while not output_file_found.is_set() and not failure_file_found.is_set():
        if not output_thread.is_alive() and not failure_thread.is_alive():
            break
        time.sleep(1)

    if output_file_found.is_set():
        s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
        return self.predictor._handle_response(response=s3_object)

    if failure_file_found.is_set():
        failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
        failure_response = self.predictor._handle_response(response=failure_object)
        raise AsyncInferenceModelError(message=failure_response)

    # Neither object appeared before both waiters gave up.
    raise PollingTimeoutError(
        message="Inference could still be running",
        output_path=output_path,
        seconds=waiter_config.delay * waiter_config.max_attempts,
    )

def update_endpoint(
self,
Expand Down
6 changes: 6 additions & 0 deletions tests/integ/test_async_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set):
assert result_no_wait_with_data.output_path.startswith(
"s3://" + sagemaker_session.default_bucket()
)
assert result_no_wait_with_data.failure_path.startswith(
"s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/"
)
time.sleep(5)
result_no_wait_with_data = result_no_wait_with_data.get_result()
assert len(result_no_wait_with_data) == 5
Expand Down Expand Up @@ -97,6 +100,9 @@ def test_async_walkthrough(sagemaker_session, cpu_instance_type, training_set):
result_not_wait = predictor_async.predict_async(input_path=input_s3_path)
assert isinstance(result_not_wait, AsyncInferenceResponse)
assert result_not_wait.output_path.startswith("s3://" + sagemaker_session.default_bucket())
assert result_not_wait.failure_path.startswith(
"s3://" + sagemaker_session.default_bucket() + "/async-endpoint-failures/"
)
time.sleep(5)
result_not_wait = result_not_wait.get_result()
assert len(result_not_wait) == 5
Expand Down
Loading