
Commit 4873622

feature: Adds support for async inference
1 parent d9e2567 commit 4873622

23 files changed: +1659 −6 lines

doc/api/inference/async_inference.rst

+19
@@ -0,0 +1,19 @@
Async Inference
---------------

This module contains classes related to Amazon SageMaker Async Inference

.. automodule:: sagemaker.async_inference.async_inference_config
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.async_inference_response
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.waiter_config
    :members:
    :undoc-members:
    :show-inheritance:

doc/api/inference/predictor_async.rst

+9
@@ -0,0 +1,9 @@
AsyncPredictor
--------------

Make async predictions against SageMaker endpoints with Python objects

.. autoclass:: sagemaker.predictor_async.AsyncPredictor
    :members:
    :undoc-members:
    :show-inheritance:

doc/overview.rst

+92
@@ -684,6 +684,98 @@ For more detailed explanations of the classes that this library provides for aut
- `API docs for HyperparameterTuner and parameter range classes <https://sagemaker.readthedocs.io/en/stable/tuner.html>`__
- `API docs for analytics classes <https://sagemaker.readthedocs.io/en/stable/analytics.html>`__

**********************************
SageMaker Asynchronous Inference
**********************************
Amazon SageMaker Asynchronous Inference is a new capability in SageMaker that queues incoming requests and processes them asynchronously.
This option is ideal for requests with large payload sizes (up to 1GB), long processing times, and near real-time latency requirements.
You can configure Asynchronous Inference to scale the instance count to zero when there are no requests to process, thereby saving costs.
More information about SageMaker Asynchronous Inference can be found in the `AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html>`__.

To deploy an asynchronous inference endpoint, you will need to create an ``AsyncInferenceConfig`` object.
If you create ``AsyncInferenceConfig`` without specifying its arguments, the default ``S3OutputPath`` will
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-outputs/{UNIQUE-JOB-NAME}`` (example shown below):

.. code:: python

    from sagemaker.async_inference import AsyncInferenceConfig

    # Create an empty AsyncInferenceConfig object to use default values
    async_config = AsyncInferenceConfig()

Or you can specify configurations in ``AsyncInferenceConfig`` as you like. All of those configuration parameters
are optional, but if you don't specify the ``output_path``, Amazon SageMaker will use the default ``S3OutputPath``
mentioned above (example shown below):

.. code:: python

    # Specify S3OutputPath, MaxConcurrentInvocationsPerInstance
    # and NotificationConfig in the async config object
    async_config = AsyncInferenceConfig(
        output_path="s3://{s3_bucket}/{bucket_prefix}/output",
        max_concurrent_invocations_per_instance=10,
        notification_config={
            "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
            "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
        },
    )

Then use the ``AsyncInferenceConfig`` in the estimator's ``deploy()`` method to deploy an asynchronous inference endpoint:

.. code:: python

    # Deploys the model that was generated by fit() to a SageMaker asynchronous inference endpoint
    async_predictor = estimator.deploy(async_inference_config=async_config)

After deployment is complete, ``deploy()`` will return an ``AsyncPredictor`` object. To perform asynchronous inference, you first
need to upload data to Amazon S3 and then use the ``predict_async()`` method with the S3 URI as the input. It will return an
``AsyncInferenceResponse`` object:

.. code:: python

    # Upload data to S3 bucket then use that as input
    async_response = async_predictor.predict_async(input_path=input_s3_path)

The Amazon SageMaker SDK also enables you to serialize the data and pass the payload data directly to the
``predict_async()`` method. For this pattern of invocation, the Amazon SageMaker SDK will upload the data to an Amazon
S3 bucket under ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-endpoint-inputs/``.

.. code:: python

    # Serializes data and makes a prediction request to the SageMaker asynchronous endpoint
    async_response = async_predictor.predict_async(data=data)

Then you can switch to other work and wait for the inference to complete. After it is completed, you can check
the result using ``AsyncInferenceResponse``:

.. code:: python

    # Switch back to check the result
    result = async_response.get_result()
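
By default, ``get_result()`` checks the output path once and raises an error if the result object is not there yet. You can instead pass a ``WaiterConfig`` (also added in this commit) to poll until the result appears; the values below are illustrative:

.. code:: python

    from sagemaker.async_inference import WaiterConfig

    # Poll every 5 seconds, up to 24 attempts, before a PollingTimeoutError is raised
    waiter_config = WaiterConfig(max_attempts=24, delay=5)
    result = async_response.get_result(waiter_config=waiter_config)
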
Alternatively, if you would like to check for a result periodically and return it upon generation, use the
``predict()`` method:

.. code:: python

    # Use predict() to wait for the result
    response = async_predictor.predict(data=data)

    # Or use Amazon S3 input path
    response = async_predictor.predict(input_path=input_s3_path)

Clean up the endpoint and model if needed after inference:

.. code:: python

    # Tears down the SageMaker endpoint and endpoint configuration
    async_predictor.delete_endpoint()

    # Deletes the SageMaker model
    async_predictor.delete_model()

For more details about Asynchronous Inference,
see the API docs for `Asynchronous Inference <https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html>`__
*******************************
SageMaker Serverless Inference
*******************************

src/sagemaker/async_inference/__init__.py

+19

@@ -0,0 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Imports the classes in this module to simplify customer imports"""

from __future__ import absolute_import

from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig  # noqa: F401
from sagemaker.async_inference.waiter_config import WaiterConfig  # noqa: F401
from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse  # noqa: F401

src/sagemaker/async_inference/async_inference_config.py

@@ -0,0 +1,82 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceConfig

Used for configuring async inference endpoints. Use AsyncInferenceConfig when deploying
the model to async inference endpoints.
"""
from __future__ import print_function, absolute_import


class AsyncInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to an async endpoint. Use this configuration
    when creating an async endpoint and making async inference requests.
    """

    def __init__(
        self,
        output_path=None,
        max_concurrent_invocations_per_instance=None,
        kms_key_id=None,
        notification_config=None,
    ):
        """Initialize an AsyncInferenceConfig object for async inference configuration.

        Args:
            output_path (str): Optional. The Amazon S3 location that endpoints upload
                inference responses to. If no value is provided, Amazon SageMaker will
                use the default Amazon S3 Async Inference output path. (Default: None)
            max_concurrent_invocations_per_instance (int): Optional. The maximum number of
                concurrent requests sent by the SageMaker client to the model container. If
                no value is provided, Amazon SageMaker will choose an optimal value for you.
                (Default: None)
            kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                asynchronous inference output in Amazon S3. (Default: None)
            notification_config (dict): Optional. Specifies the configuration for notifications
                of inference results for asynchronous inference. Only one notification is
                generated per invocation request. (Default: None)
                * success_topic (str): Amazon SNS topic to post a notification to when inference
                    completes successfully. If no topic is provided, no notification is sent
                    on success. The key in notification_config is 'SuccessTopic'.
                * error_topic (str): Amazon SNS topic to post a notification to when inference
                    fails. If no topic is provided, no notification is sent on failure.
                    The key in notification_config is 'ErrorTopic'.
        """
        self.output_path = output_path
        self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
        self.kms_key_id = kms_key_id
        self.notification_config = notification_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        request_dict = {
            "OutputConfig": {
                "S3OutputPath": self.output_path,
            },
        }

        if self.max_concurrent_invocations_per_instance:
            request_dict["ClientConfig"] = {
                "MaxConcurrentInvocationsPerInstance": self.max_concurrent_invocations_per_instance
            }

        if self.kms_key_id:
            request_dict["OutputConfig"]["KmsKeyId"] = self.kms_key_id

        if self.notification_config:
            request_dict["OutputConfig"]["NotificationConfig"] = self.notification_config

        return request_dict
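
As a quick illustration of the mapping above (the bucket name is hypothetical), ``_to_request_dict()`` produces the nested request structure used when the async endpoint configuration is created:

.. code:: python

    from sagemaker.async_inference import AsyncInferenceConfig

    config = AsyncInferenceConfig(
        output_path="s3://my-bucket/async-output",
        max_concurrent_invocations_per_instance=4,
    )
    print(config._to_request_dict())
    # {'OutputConfig': {'S3OutputPath': 's3://my-bucket/async-output'},
    #  'ClientConfig': {'MaxConcurrentInvocationsPerInstance': 4}}
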
src/sagemaker/async_inference/async_inference_response.py

@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceResponse"""

from __future__ import print_function, absolute_import

from botocore.exceptions import ClientError
from sagemaker.s3 import parse_s3_url
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError


class AsyncInferenceResponse(object):
    """Response from an Async Inference endpoint

    This response object provides a method to check for an async inference result in the
    specified Amazon S3 output path. If the result object exists in that path, it is fetched
    and returned.
    """

    def __init__(
        self,
        predictor_async,
        output_path,
    ):
        """Initialize an AsyncInferenceResponse object.

        AsyncInferenceResponse can help users get the async inference result
        from the Amazon S3 output path.

        Args:
            predictor_async (sagemaker.predictor.AsyncPredictor): The ``AsyncPredictor``
                that returned this response.
            output_path (str): The Amazon S3 location that endpoints upload inference responses
                to.
        """
        self.predictor_async = predictor_async
        self.output_path = output_path
        self._result = None

    def get_result(
        self,
        waiter_config=None,
    ):
        """Get the async inference result in the specified Amazon S3 output path.

        Args:
            waiter_config (sagemaker.async_inference.waiter_config.WaiterConfig): Configuration
                for the waiter. The pre-defined value for the delay between polls is 15 seconds
                and the default max attempts is 60.
        Raises:
            ValueError: If a wrong type of object is provided as ``waiter_config``
        Returns:
            object: Inference result in the given Amazon S3 output path. If a deserializer was
                specified when creating the AsyncPredictor, the result of the deserializer is
                returned. Otherwise the response returns the sequence of bytes
                as is.
        """
        if waiter_config is not None and not isinstance(waiter_config, WaiterConfig):
            raise ValueError("waiter_config should be a WaiterConfig object")

        if self._result is None:
            if waiter_config is None:
                self._result = self._get_result_from_s3(self.output_path)
            else:
                self._result = self.predictor_async._wait_for_output(
                    self.output_path, waiter_config
                )
        return self._result

    def _get_result_from_s3(
        self,
        output_path,
    ):
        """Get the inference result from the Amazon S3 output path."""
        bucket, key = parse_s3_url(output_path)
        try:
            response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
            return self.predictor_async.predictor._handle_response(response)
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running",
                    output_path=output_path,
                )
            raise UnexpectedClientError(
                message=ex.response["Error"]["Message"],
            )
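
A minimal polling sketch built on the behavior above (``async_response`` is assumed to be an ``AsyncInferenceResponse`` returned by ``predict_async()``): calling ``get_result()`` without a waiter raises ``ObjectNotExistedError`` while the output object is absent, so a caller can catch it and retry:

.. code:: python

    import time

    from sagemaker.exceptions import ObjectNotExistedError

    while True:
        try:
            result = async_response.get_result()
            break
        except ObjectNotExistedError:
            # The output object is not in S3 yet; inference may still be running
            time.sleep(15)
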
src/sagemaker/async_inference/waiter_config.py

@@ -0,0 +1,47 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for WaiterConfig used in async inference

Use it when making async inference requests and waiting for the result.
"""

from __future__ import absolute_import


class WaiterConfig(object):
    """Configuration object passed in when using async inference and waiting for the result."""

    def __init__(
        self,
        max_attempts=60,
        delay=15,
    ):
        """Initialize a WaiterConfig object that provides parameters to control waiting behavior.

        Args:
            max_attempts (int): The maximum number of attempts to be made. If the max attempts
                is exceeded, Amazon SageMaker will raise ``PollingTimeoutError``. (Default: 60)
            delay (int): The amount of time in seconds to wait between attempts. (Default: 15)
        """

        self.max_attempts = max_attempts
        self.delay = delay

    def _to_request_dict(self):
        """Generates a dictionary using the parameters provided to the class."""
        waiter_dict = {
            "Delay": self.delay,
            "MaxAttempts": self.max_attempts,
        }

        return waiter_dict
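
For example, ``_to_request_dict()`` simply maps the two constructor arguments to the keys the waiter expects (values here are illustrative):

.. code:: python

    from sagemaker.async_inference import WaiterConfig

    config = WaiterConfig(max_attempts=24, delay=5)
    print(config._to_request_dict())  # {'Delay': 5, 'MaxAttempts': 24}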
