Commit 3d042bd

feature: add support for async inference
1 parent 554d735 commit 3d042bd

22 files changed (+1482, -2 lines)

doc/api/inference/async_inference.rst

+19
@@ -0,0 +1,19 @@
Async Inference
---------------

This module contains classes related to Amazon SageMaker Async Inference

.. automodule:: sagemaker.async_inference.async_inference_config
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.async_inference_response
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.waiter_config
    :members:
    :undoc-members:
    :show-inheritance:

doc/api/inference/predictor_async.rst

+9
@@ -0,0 +1,9 @@
AsyncPredictor
--------------

Make async predictions against SageMaker endpoints with Python objects

.. autoclass:: sagemaker.predictor_async.AsyncPredictor
    :members:
    :undoc-members:
    :show-inheritance:
src/sagemaker/async_inference/__init__.py

+19
@@ -0,0 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Imports the classes in this module to simplify customer imports"""

from __future__ import absolute_import

from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig  # noqa: F401
from sagemaker.async_inference.waiter_config import WaiterConfig  # noqa: F401
from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse  # noqa: F401
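With these re-exports in place, the async inference classes can be imported from the package root rather than from their individual submodules. A minimal sketch of the intended import style:

# Package-level imports made possible by the re-exports above.
from sagemaker.async_inference import (
    AsyncInferenceConfig,
    AsyncInferenceResponse,
    WaiterConfig,
)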
src/sagemaker/async_inference/async_inference_config.py

+81
@@ -0,0 +1,81 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceConfig

Codes are used for configuring async inference endpoint. Use it when deploying
the model to the endpoints.
"""
from __future__ import print_function, absolute_import


class AsyncInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to async endpoint. Use this configuration
    when trying to create async endpoint and make async inference
    """

    def __init__(
        self,
        output_path=None,
        max_concurrent_invocations_per_instance=None,
        kms_key_id=None,
        notification_config=None,
    ):
        """Initialize an AsyncInferenceConfig object for async inference related configuration.

        Args:
            output_path (str): Optional. The Amazon S3 location that endpoints upload
                inference responses to. If no value is provided, Amazon SageMaker will
                use default Amazon S3 Async Inference output path. (Default: None)
            max_concurrent_invocations_per_instance (int): Optional. The maximum number of
                concurrent requests sent by the SageMaker client to the model container. If
                no value is provided, Amazon SageMaker will choose an optimal value for you.
                (Default: None)
            kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                asynchronous inference output in Amazon S3. (Default: None)
            notification_config (dict): Optional. Specifies the configuration for notifications
                of inference results for asynchronous inference (Default: None):
                * success_topic (str): Amazon SNS topic to post a notification to when inference
                    completes successfully. If no topic is provided, no notification is sent on
                    success. The key in notification_config is 'SuccessTopic'.
                * error_topic (str): Amazon SNS topic to post a notification to when inference
                    fails. If no topic is provided, no notification is sent on failure.
                    The key in notification_config is 'ErrorTopic'.
        """
        self.output_path = output_path
        self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
        self.kms_key_id = kms_key_id
        self.notification_config = notification_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        request_dict = {
            "OutputConfig": {
                "S3OutputPath": self.output_path,
            },
        }

        if self.max_concurrent_invocations_per_instance:
            request_dict["ClientConfig"] = {
                "MaxConcurrentInvocationsPerInstance": self.max_concurrent_invocations_per_instance
            }

        if self.kms_key_id:
            request_dict["OutputConfig"]["KmsKeyId"] = self.kms_key_id

        if self.notification_config:
            request_dict["OutputConfig"]["NotificationConfig"] = self.notification_config

        return request_dict
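As a rough usage sketch of the class above (the bucket name, KMS key alias, and SNS topic ARNs are placeholders, not values from this commit), a fully specified config and the request dictionary it assembles would look like this:

from sagemaker.async_inference import AsyncInferenceConfig

config = AsyncInferenceConfig(
    output_path="s3://example-bucket/async-output/",  # placeholder bucket
    max_concurrent_invocations_per_instance=4,
    kms_key_id="alias/example-key",  # placeholder KMS key
    notification_config={
        "SuccessTopic": "arn:aws:sns:us-east-1:111122223333:success",  # placeholder ARN
        "ErrorTopic": "arn:aws:sns:us-east-1:111122223333:error",  # placeholder ARN
    },
)

# _to_request_dict() nests KmsKeyId and NotificationConfig under OutputConfig and
# puts the concurrency limit under ClientConfig, per the method above:
# {
#     "OutputConfig": {
#         "S3OutputPath": "s3://example-bucket/async-output/",
#         "KmsKeyId": "alias/example-key",
#         "NotificationConfig": {"SuccessTopic": "...", "ErrorTopic": "..."},
#     },
#     "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 4},
# }
print(config._to_request_dict())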
src/sagemaker/async_inference/async_inference_response.py

+98
@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceResponse"""

from __future__ import print_function, absolute_import

from botocore.exceptions import ClientError
from sagemaker.s3 import parse_s3_url
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError


class AsyncInferenceResponse(object):
    """Response from Async Inference endpoint

    This response object provides a method to check the async Amazon S3
    output path. If result object exists in that path, decode and return
    the result
    """

    def __init__(
        self,
        predictor_async,
        output_path,
    ):
        """Initialize an AsyncInferenceResponse object.

        AsyncInferenceResponse can help users to get async inference result
        from the Amazon S3 output path

        Args:
            predictor_async (sagemaker.predictor.AsyncPredictor): The ``AsyncPredictor``
                that return this response.
            output_path (str): The Amazon S3 location that endpoints upload inference
                responses to.
        """
        self.predictor_async = predictor_async
        self.output_path = output_path
        self._result = None

    def get_result(
        self,
        waiter_config=None,
    ):
        """Get result from the async Amazon S3 output path

        Args:
            waiter_config (sagemaker.async_inference.waiter_config.WaiterConfig): Configuration
                for the waiter. The pre-defined value for the delay between poll is 15 seconds
                and the default max attempts is 60
        Raises:
            ValueError: If a wrong type of object is provided as ``waiter_config``
        Returns:
            object: Inference result in the given Amazon S3 output path. If a deserializer was
                specified when creating the AsyncPredictor, the result of the deserializer is
                returned. Otherwise the response returns the sequence of bytes
                as is.
        """
        if waiter_config is not None and not isinstance(waiter_config, WaiterConfig):
            raise ValueError("waiter_config should be a WaiterConfig object")

        if self._result is None:
            if waiter_config is None:
                self._result = self._get_result_from_s3(self.output_path)
            else:
                self._result = self.predictor_async._wait_for_output(
                    self.output_path, waiter_config
                )
        return self._result

    def _get_result_from_s3(
        self,
        output_path,
    ):
        """Get inference result from the output Amazon S3 path"""
        bucket, key = parse_s3_url(output_path)
        try:
            response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
            return self.predictor_async.predictor._handle_response(response)
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running",
                    output_path=output_path,
                )
            raise UnexpectedClientError(
                message=ex.response["Error"]["Message"],
            )
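In normal use an AsyncInferenceResponse is handed back by the AsyncPredictor (defined in the predictor_async module, part of this commit but not reproduced in this excerpt) rather than constructed by hand. A minimal sketch of get_result, assuming ``response`` is such an object:

from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError

# Without a WaiterConfig, get_result() checks the S3 output path once and raises
# ObjectNotExistedError if no result object has been written yet.
try:
    result = response.get_result()
except ObjectNotExistedError:
    result = None  # inference may still be running

# With a WaiterConfig, the predictor's _wait_for_output polls the path instead
# (here up to 20 attempts, 10 seconds apart).
result = response.get_result(WaiterConfig(max_attempts=20, delay=10))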
src/sagemaker/async_inference/waiter_config.py

+46
@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for WaiterConfig used in async inference

Use it when using async inference and wait for the result.
"""

from __future__ import absolute_import


class WaiterConfig(object):
    """Configuration object passed in when using async inference and wait for the result."""

    def __init__(
        self,
        max_attempts=60,
        delay=15,
    ):
        """Initialize a WaiterConfig object that provides parameters to control waiting behavior.

        Args:
            max_attempts (int): The maximum number of attempts to be made. (Default: 60)
            delay (int): The amount of time in seconds to wait between attempts. (Default: 15)
        """

        self.max_attempts = max_attempts
        self.delay = delay

    def _to_waiter_dict(self):
        """Generates a dictionary using the parameters provided to the class."""
        waiter_dict = {
            "Delay": self.delay,
            "MaxAttempts": self.max_attempts,
        }

        return waiter_dict
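The keys produced by _to_waiter_dict mirror the Delay/MaxAttempts parameters used by botocore waiters, so the maximum time spent polling is roughly delay * max_attempts seconds. A small sketch:

from sagemaker.async_inference import WaiterConfig

waiter_config = WaiterConfig(max_attempts=40, delay=5)  # poll every 5 s, at most 40 times
print(waiter_config._to_waiter_dict())  # {"Delay": 5, "MaxAttempts": 40}
# Upper bound on polling before giving up: about 40 * 5 = 200 seconds.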

src/sagemaker/estimator.py

+7
@@ -864,6 +864,7 @@ def deploy(
         kms_key=None,
         data_capture_config=None,
         tags=None,
+        async_inference_config=None,
         **kwargs,
     ):
         """Deploy the trained model to an Amazon SageMaker endpoint.
@@ -910,6 +911,11 @@ def deploy(
             data_capture_config (sagemaker.model_monitor.DataCaptureConfig): Specifies
                 configuration related to Endpoint data capture for use with
                 Amazon SageMaker Model Monitoring. Default: None.
+            async_inference_config (sagemaker.async_inference.AsyncInferenceConfig): Specifies
+                configuration related to async endpoint. Use this configuration when trying
+                to create async endpoint and make async inference. If an empty config object is
+                passed through, we will use a default config to deploy the async endpoint.
+                (default: None)
             tags(List[dict[str, str]]): Optional. The list of tags to attach to this specific
                 endpoint. Example:
                 >>> tags = [{'Key': 'tagname', 'Value': 'tagvalue'}]
@@ -959,6 +965,7 @@ def deploy(
             wait=wait,
             kms_key=kms_key,
             data_capture_config=data_capture_config,
+            async_inference_config=async_inference_config,
         )

     def register(
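Putting the pieces together, deploying an async endpoint from a fitted estimator would look roughly like the sketch below; the estimator, instance type, and bucket are placeholders, and the returned predictor type comes from the predictor_async module documented earlier, not from this excerpt. Per the docstring above, passing AsyncInferenceConfig() with no arguments deploys an async endpoint with default settings.

from sagemaker.async_inference import AsyncInferenceConfig

# ``estimator`` stands in for any fitted SageMaker estimator.
async_predictor = estimator.deploy(
    initial_instance_count=1,
    instance_type="ml.m5.xlarge",  # placeholder instance type
    async_inference_config=AsyncInferenceConfig(
        output_path="s3://example-bucket/async-output/",  # placeholder bucket
        max_concurrent_invocations_per_instance=4,
    ),
)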

src/sagemaker/exceptions.py

+40
@@ -21,3 +21,43 @@ def __init__(self, message, allowed_statuses, actual_status):
         self.allowed_statuses = allowed_statuses
         self.actual_status = actual_status
         super(UnexpectedStatusException, self).__init__(message)
+
+
+class AsyncInferenceError(Exception):
+    """The base exception class for Async Inference exceptions."""
+
+    fmt = "An unspecified error occurred"
+
+    def __init__(self, **kwargs):
+        msg = self.fmt.format(**kwargs)
+        Exception.__init__(self, msg)
+        self.kwargs = kwargs
+
+
+class ObjectNotExistedError(AsyncInferenceError):
+    """Raised when Amazon S3 object not exist in the given path"""
+
+    fmt = "Object not exist at {output_path}. {message}"
+
+    def __init__(self, message, output_path):
+        super(ObjectNotExistedError, self).__init__(message=message, output_path=output_path)
+
+
+class PollingTimeoutError(AsyncInferenceError):
+    """Raised when wait longer than expected and no result object in Amazon S3 bucket yet"""
+
+    fmt = "No result at {output_path} after polling for {seconds} seconds. {message}"
+
+    def __init__(self, message, output_path, seconds):
+        super(PollingTimeoutError, self).__init__(
+            message=message, output_path=output_path, seconds=seconds
+        )
+
+
+class UnexpectedClientError(AsyncInferenceError):
+    """Raised when ClientError's error code is not expected"""
+
+    fmt = "Encountered unexpected client error: {message}"
+
+    def __init__(self, message):
+        super(UnexpectedClientError, self).__init__(message=message)
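Because AsyncInferenceError stores its keyword arguments on the instance, callers can branch on the specific failure mode and inspect e.kwargs. A brief sketch, assuming get_result surfaces these exceptions while polling as outlined earlier:

from sagemaker.exceptions import (
    ObjectNotExistedError,
    PollingTimeoutError,
    UnexpectedClientError,
)

try:
    result = async_response.get_result(waiter_config)  # ``async_response`` as sketched above
except ObjectNotExistedError as e:
    print("No result yet at", e.kwargs["output_path"])
except PollingTimeoutError as e:
    print("Gave up after", e.kwargs["seconds"], "seconds of polling")
except UnexpectedClientError:
    raise  # message comes from the underlying botocore ClientError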
