
Commit d5fb328

feature: add support for async inference

1 parent 70308b1 commit d5fb328

23 files changed: +1640 −6 lines

doc/api/inference/async_inference.rst

+19
@@ -0,0 +1,19 @@
Async Inference
-----------------

This module contains classes related to Amazon SageMaker Async Inference.

.. automodule:: sagemaker.async_inference.async_inference_config
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.async_inference_response
    :members:
    :undoc-members:
    :show-inheritance:

.. automodule:: sagemaker.async_inference.waiter_config
    :members:
    :undoc-members:
    :show-inheritance:

doc/api/inference/predictor_async.rst

+9
@@ -0,0 +1,9 @@
AsyncPredictor
--------------------

Make async predictions against SageMaker endpoints with Python objects.

.. autoclass:: sagemaker.predictor_async.AsyncPredictor
    :members:
    :undoc-members:
    :show-inheritance:

doc/overview.rst

+90
@@ -684,6 +684,96 @@ For more detailed explanations of the classes that this library provides for aut
- `API docs for HyperparameterTuner and parameter range classes <https://sagemaker.readthedocs.io/en/stable/tuner.html>`__
- `API docs for analytics classes <https://sagemaker.readthedocs.io/en/stable/analytics.html>`__

**********************************
SageMaker Asynchronous Inference
**********************************
Amazon SageMaker Asynchronous Inference is a new capability in SageMaker that queues incoming requests and processes them asynchronously.
This option is ideal for requests with large payload sizes (up to 1GB), long processing times, and near-real-time latency requirements.
Asynchronous Inference also enables you to save on costs by autoscaling the instance count to zero when there are no requests to process,
so you only pay when your endpoint is processing requests. More information about
SageMaker Asynchronous Inference can be found in the `AWS documentation <https://docs.aws.amazon.com/sagemaker/latest/dg/async-inference.html>`__.

To deploy an asynchronous endpoint, you will need to create an ``AsyncInferenceConfig`` object.
If you create an ``AsyncInferenceConfig`` without specifying any of its arguments, the default ``S3OutputPath`` will
be ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-output/{UNIQUE-JOB-NAME}`` (example shown below):

.. code:: python

    from sagemaker.async_inference import AsyncInferenceConfig

    # Create an empty AsyncInferenceConfig object to use default values
    async_config = AsyncInferenceConfig()

Or you can specify configurations in ``AsyncInferenceConfig`` as you like (example shown below):

.. code:: python

    # Specify S3OutputPath, MaxConcurrentInvocationsPerInstance
    # and NotificationConfig in the async config object
    async_config = AsyncInferenceConfig(
        output_path="s3://{s3_bucket}/{bucket_prefix}/output",
        max_concurrent_invocations_per_instance=10,
        notification_config={
            "SuccessTopic": "arn:aws:sns:aws-region:account-id:topic-name",
            "ErrorTopic": "arn:aws:sns:aws-region:account-id:topic-name",
        },
    )

Then use the ``AsyncInferenceConfig`` in the estimator's ``deploy()`` method to deploy an asynchronous endpoint:

.. code:: python

    # Deploys the model that was generated by fit() to a SageMaker asynchronous endpoint
    async_predictor = estimator.deploy(async_inference_config=async_config)
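
``deploy()`` also accepts the usual instance arguments. A minimal sketch, assuming a hypothetical instance count and type:

.. code:: python

    # Instance count and type below are placeholder values
    async_predictor = estimator.deploy(
        initial_instance_count=1,
        instance_type="ml.m5.xlarge",
        async_inference_config=async_config,
    )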

After deployment is complete, ``deploy()`` will return an ``AsyncPredictor``. You can use it to perform asynchronous inference
with ``predict_async()`` and then get the result later. For input data, you can upload the data to an S3 bucket
and pass its path:

.. code:: python

    # Upload data to an S3 bucket, then use its path as input
    async_response = async_predictor.predict_async(input_path=input_s3_path)

Or you can serialize the data and pass it directly, just like real-time inference. With this option, the Amazon SageMaker SDK
uploads the data to an Amazon S3 bucket under ``s3://sagemaker-{REGION}-{ACCOUNTID}/async-input/``:

.. code:: python

    # Serializes data and makes a prediction request to the SageMaker asynchronous endpoint
    async_response = async_predictor.predict_async(data=data)

Then you can switch to other work while the inference completes. Once it has completed, you can check
the result:

.. code:: python

    # Switch back to check the result
    result = async_response.get_result()
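
``get_result()`` also accepts a ``WaiterConfig`` (added in this commit, see ``waiter_config.py`` below) that controls how often and how long it polls the output path. A minimal sketch:

.. code:: python

    from sagemaker.async_inference import WaiterConfig

    # Poll every 10 seconds, giving up after 30 attempts
    waiter = WaiterConfig(max_attempts=30, delay=10)
    result = async_response.get_result(waiter_config=waiter)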

If you want to wait for the result right away, you can use the ``predict()`` method instead. It checks the output
Amazon S3 path periodically and returns the result once it appears:

.. code:: python

    # Use predict() to wait for the result
    response = async_predictor.predict(data=data)

    # Or use an Amazon S3 input path
    response = async_predictor.predict(input_path=input_s3_path)

Clean up the endpoint and model if needed after inference:

.. code:: python

    # Tears down the SageMaker endpoint and endpoint configuration
    async_predictor.delete_endpoint()

    # Deletes the SageMaker model
    async_predictor.delete_model()

For more details about Asynchronous Inference,
see the API docs for `Asynchronous Inference <https://sagemaker.readthedocs.io/en/stable/api/inference/async_inference.html>`__.

*******************************
SageMaker Serverless Inference
*******************************
src/sagemaker/async_inference/__init__.py

+19
@@ -0,0 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Imports the classes in this module to simplify customer imports"""

from __future__ import absolute_import

from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig  # noqa: F401
from sagemaker.async_inference.waiter_config import WaiterConfig  # noqa: F401
from sagemaker.async_inference.async_inference_response import AsyncInferenceResponse  # noqa: F401
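
With these re-exports, both of the following import styles resolve to the same class:

.. code:: python

    # Package-root import enabled by this __init__ module
    from sagemaker.async_inference import AsyncInferenceConfig

    # Equivalent submodule import
    from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig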
src/sagemaker/async_inference/async_inference_config.py

@@ -0,0 +1,81 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceConfig

This config is used when configuring an async inference endpoint. Use it when
deploying the model to an endpoint.
"""
from __future__ import print_function, absolute_import


class AsyncInferenceConfig(object):
    """Configuration object passed in when deploying models to Amazon SageMaker Endpoints.

    This object specifies configuration related to async endpoints. Use this configuration
    when creating an async endpoint and making async inference requests.
    """

    def __init__(
        self,
        output_path=None,
        max_concurrent_invocations_per_instance=None,
        kms_key_id=None,
        notification_config=None,
    ):
        """Initialize an AsyncInferenceConfig object for async inference configuration.

        Args:
            output_path (str): Optional. The Amazon S3 location that endpoints upload
                inference responses to. If no value is provided, Amazon SageMaker will
                use the default Amazon S3 Async Inference output path. (Default: None)
            max_concurrent_invocations_per_instance (int): Optional. The maximum number of
                concurrent requests sent by the SageMaker client to the model container. If
                no value is provided, Amazon SageMaker will choose an optimal value for you.
                (Default: None)
            kms_key_id (str): Optional. The Amazon Web Services Key Management Service
                (Amazon Web Services KMS) key that Amazon SageMaker uses to encrypt the
                asynchronous inference output in Amazon S3. (Default: None)
            notification_config (dict): Optional. Specifies the configuration for
                notifications of inference results for asynchronous inference.
                (Default: None)

                * success_topic (str): Amazon SNS topic to post a notification to when an
                  inference completes successfully. If no topic is provided, no notification
                  is sent on success. The key in notification_config is 'SuccessTopic'.
                * error_topic (str): Amazon SNS topic to post a notification to when an
                  inference fails. If no topic is provided, no notification is sent on
                  failure. The key in notification_config is 'ErrorTopic'.
        """
        self.output_path = output_path
        self.max_concurrent_invocations_per_instance = max_concurrent_invocations_per_instance
        self.kms_key_id = kms_key_id
        self.notification_config = notification_config

    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""
        request_dict = {
            "OutputConfig": {
                "S3OutputPath": self.output_path,
            },
        }

        if self.max_concurrent_invocations_per_instance:
            request_dict["ClientConfig"] = {
                "MaxConcurrentInvocationsPerInstance": self.max_concurrent_invocations_per_instance
            }

        if self.kms_key_id:
            request_dict["OutputConfig"]["KmsKeyId"] = self.kms_key_id

        if self.notification_config:
            request_dict["OutputConfig"]["NotificationConfig"] = self.notification_config

        return request_dict
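
A minimal sketch of the request dictionary this method produces, using hypothetical values:

.. code:: python

    config = AsyncInferenceConfig(
        output_path="s3://my-bucket/output",  # hypothetical bucket
        max_concurrent_invocations_per_instance=10,
    )
    # config._to_request_dict() returns:
    # {
    #     "OutputConfig": {"S3OutputPath": "s3://my-bucket/output"},
    #     "ClientConfig": {"MaxConcurrentInvocationsPerInstance": 10},
    # }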
src/sagemaker/async_inference/async_inference_response.py

@@ -0,0 +1,98 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for AsyncInferenceResponse"""

from __future__ import print_function, absolute_import

from botocore.exceptions import ClientError
from sagemaker.s3 import parse_s3_url
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import ObjectNotExistedError, UnexpectedClientError


class AsyncInferenceResponse(object):
    """Response from an Async Inference endpoint.

    This response object provides a method to check the async Amazon S3
    output path. If the result object exists in that path, decode and return
    the result.
    """

    def __init__(
        self,
        predictor_async,
        output_path,
    ):
        """Initialize an AsyncInferenceResponse object.

        AsyncInferenceResponse helps users get the async inference result
        from the Amazon S3 output path.

        Args:
            predictor_async (sagemaker.predictor.AsyncPredictor): The ``AsyncPredictor``
                that returned this response.
            output_path (str): The Amazon S3 location that endpoints upload inference
                responses to.
        """
        self.predictor_async = predictor_async
        self.output_path = output_path
        self._result = None

    def get_result(
        self,
        waiter_config=None,
    ):
        """Get the result from the async Amazon S3 output path.

        Args:
            waiter_config (sagemaker.async_inference.waiter_config.WaiterConfig): Configuration
                for the waiter. The pre-defined delay between polls is 15 seconds
                and the default maximum number of attempts is 60.
        Raises:
            ValueError: If a wrong type of object is provided as ``waiter_config``.
        Returns:
            object: Inference result in the given Amazon S3 output path. If a deserializer was
                specified when creating the AsyncPredictor, the result of the deserializer is
                returned. Otherwise the response is returned as a sequence of bytes.
        """
        if waiter_config is not None and not isinstance(waiter_config, WaiterConfig):
            raise ValueError("waiter_config should be a WaiterConfig object")

        if self._result is None:
            if waiter_config is None:
                self._result = self._get_result_from_s3(self.output_path)
            else:
                self._result = self.predictor_async._wait_for_output(
                    self.output_path, waiter_config
                )
        return self._result

    def _get_result_from_s3(
        self,
        output_path,
    ):
        """Get the inference result from the output Amazon S3 path."""
        bucket, key = parse_s3_url(output_path)
        try:
            response = self.predictor_async.s3_client.get_object(Bucket=bucket, Key=key)
            return self.predictor_async.predictor._handle_response(response)
        except ClientError as ex:
            if ex.response["Error"]["Code"] == "NoSuchKey":
                raise ObjectNotExistedError(
                    message="Inference could still be running",
                    output_path=output_path,
                )
            raise UnexpectedClientError(
                message=ex.response["Error"]["Message"],
            )
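
A short sketch of how a caller might poll with this error model, assuming ``async_response`` came from ``predict_async()`` as in the overview docs:

.. code:: python

    import time

    from sagemaker.exceptions import ObjectNotExistedError

    # Poll get_result() manually instead of passing a WaiterConfig;
    # ObjectNotExistedError means the output object is not in S3 yet.
    result = None
    while result is None:
        try:
            result = async_response.get_result()
        except ObjectNotExistedError:
            time.sleep(15)  # inference could still be running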
src/sagemaker/async_inference/waiter_config.py

@@ -0,0 +1,46 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""A class for WaiterConfig used in async inference

Use it when making async inference requests and waiting for the result.
"""

from __future__ import absolute_import


class WaiterConfig(object):
    """Configuration object passed in when using async inference and waiting for the result."""

    def __init__(
        self,
        max_attempts=60,
        delay=15,
    ):
        """Initialize a WaiterConfig object that provides parameters to control waiting behavior.

        Args:
            max_attempts (int): The maximum number of attempts to be made. (Default: 60)
            delay (int): The amount of time in seconds to wait between attempts. (Default: 15)
        """
        self.max_attempts = max_attempts
        self.delay = delay

    def _to_request_dict(self):
        """Generates a dictionary using the parameters provided to the class."""
        waiter_dict = {
            "Delay": self.delay,
            "MaxAttempts": self.max_attempts,
        }

        return waiter_dict
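
For illustration, the dictionary this method produces for a custom configuration (values here are arbitrary examples):

.. code:: python

    from sagemaker.async_inference import WaiterConfig

    # Poll every 5 seconds, up to 120 attempts (~10 minutes total)
    config = WaiterConfig(max_attempts=120, delay=5)
    assert config._to_request_dict() == {"Delay": 5, "MaxAttempts": 120}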
