
Commit fada4bf

feature: Add support for Streaming Inference (#4497)
* feature: Add support for Streaming Inference
* fix: codestyle-docs-test
* fix: codestyle-docs-test
1 parent 15a40ff commit fada4bf

File tree

8 files changed: +601 -1 lines changed

src/sagemaker/base_predictor.py (+82)
@@ -52,6 +52,7 @@
     JSONSerializer,
     NumpySerializer,
 )
+from sagemaker.iterators import ByteIterator
 from sagemaker.session import production_variant, Session
 from sagemaker.utils import name_from_base, stringify_object, format_tags

@@ -225,6 +226,7 @@ def _create_request_args(
         target_variant=None,
         inference_id=None,
         custom_attributes=None,
+        target_container_hostname=None,
     ):
         """Placeholder docstring"""

@@ -286,9 +288,89 @@ def _create_request_args(
         if self._get_component_name():
             args["InferenceComponentName"] = self.component_name

+        if target_container_hostname:
+            args["TargetContainerHostname"] = target_container_hostname
+
         args["Body"] = data
         return args

+    def predict_stream(
+        self,
+        data,
+        initial_args=None,
+        target_variant=None,
+        inference_id=None,
+        custom_attributes=None,
+        component_name: Optional[str] = None,
+        target_container_hostname=None,
+        iterator=ByteIterator,
+    ):
+        """Return the inference from the specified endpoint.
+
+        Args:
+            data (object): Input data for which you want the model to provide
+                inference. If a serializer was specified when creating the
+                Predictor, the result of the serializer is sent as input
+                data. Otherwise the data must be a sequence of bytes, and the
+                predict method then sends the bytes in the request body as is.
+            initial_args (dict[str,str]): Optional. Default arguments for the boto3
+                ``invoke_endpoint_with_response_stream`` call. (Default: None).
+            target_variant (str): Optional. The name of the production variant to run an
+                inference request on (Default: None). Note that the ProductionVariant
+                identifies the model you want to host and the resources you want to
+                deploy for hosting it.
+            inference_id (str): Optional. If you provide a value, it is added to the
+                captured data when you enable data capture on the endpoint (Default: None).
+            custom_attributes (str): Optional. Provides additional information about a
+                request for an inference submitted to a model hosted at an Amazon SageMaker
+                endpoint. The information is an opaque value that is forwarded verbatim. You
+                could use this value, for example, to provide an ID that you can use to track
+                a request or to provide other metadata that a service endpoint was programmed
+                to process. The value must consist of no more than 1024 visible US-ASCII
+                characters.
+
+                The code in your model is responsible for setting or updating any custom
+                attributes in the response. If your code does not set this value in the
+                response, an empty value is returned. For example, if a custom attribute
+                represents the trace ID, your model can prepend the custom attribute with
+                ``Trace ID:`` in your post-processing function (Default: None).
+            component_name (str): Optional. Name of the Amazon SageMaker inference component
+                corresponding to the predictor. (Default: None).
+            target_container_hostname (str): Optional. If the endpoint hosts multiple
+                containers and is configured to use direct invocation, this parameter
+                specifies the host name of the container to invoke. (Default: None).
+            iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class that
+                provides an iterable interface for iterating over the event stream response
+                from the inference endpoint. An object of the provided iterator class is
+                returned by the ``predict_stream`` method
+                (Default: :class:`~sagemaker.iterators.ByteIterator`). Iterators defined in
+                :mod:`~sagemaker.iterators`, or custom iterators (which must inherit from
+                :class:`~sagemaker.iterators.BaseIterator`), can be specified as input.
+
+        Returns:
+            object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object that
+                allows iteration over the EventStream response. The object is instantiated
+                from the ``predict_stream`` method's ``iterator`` parameter.
+        """
+        # [TODO]: clean up component_name in _create_request_args
+        request_args = self._create_request_args(
+            data=data,
+            initial_args=initial_args,
+            target_variant=target_variant,
+            inference_id=inference_id,
+            custom_attributes=custom_attributes,
+            target_container_hostname=target_container_hostname,
+        )
+
+        inference_component_name = component_name or self._get_component_name()
+        if inference_component_name:
+            request_args["InferenceComponentName"] = inference_component_name
+
+        response = (
+            self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream(
+                **request_args
+            )
+        )
+        return iterator(response["Body"])
+
     def update_endpoint(
         self,
         initial_instance_count=None,
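
For orientation, a minimal usage sketch of the new ``predict_stream`` method (not part of the commit). The endpoint name and payload are hypothetical, and it assumes the endpoint's container streams newline-delimited JSON so that ``LineIterator`` applies.

    from sagemaker.predictor import Predictor
    from sagemaker.iterators import LineIterator

    # Hypothetical endpoint name; the container behind it must support response streaming.
    predictor = Predictor(endpoint_name="my-streaming-endpoint")

    # Iterate over complete lines as they arrive from the event stream.
    for line in predictor.predict_stream(
        data=b'{"inputs": "What is streaming inference?"}',
        iterator=LineIterator,
    ):
        print(line.decode("utf-8"))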

src/sagemaker/exceptions.py (+20)
@@ -86,3 +86,23 @@ class AsyncInferenceModelError(AsyncInferenceError):

     def __init__(self, message):
         super().__init__(message=message)
+
+
+class ModelStreamError(Exception):
+    """Raised when the invoke_endpoint_with_response_stream response returns ModelStreamError."""
+
+    def __init__(self, message="An error occurred", code=None):
+        self.message = message
+        self.code = code
+        if code is not None:
+            super().__init__(f"{message} (Code: {code})")
+        else:
+            super().__init__(message)
+
+
+class InternalStreamFailure(Exception):
+    """Raised when the invoke_endpoint_with_response_stream response returns InternalStreamFailure."""
+
+    def __init__(self, message="An error occurred"):
+        self.message = message
+        super().__init__(self.message)
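
A brief sketch (not part of the commit) of how a caller might handle these streaming exceptions while consuming a response; ``predictor`` and ``payload`` stand in for the hypothetical objects from the earlier sketch.

    from sagemaker.exceptions import InternalStreamFailure, ModelStreamError

    payload = b'{"inputs": "What is streaming inference?"}'

    try:
        for chunk in predictor.predict_stream(data=payload):
            print(chunk)
    except ModelStreamError as err:
        # Error surfaced by the model container mid-stream; carries the container's error code.
        print(f"Model stream error: {err.message} (code={err.code})")
    except InternalStreamFailure as err:
        # The stream failed inside the platform before completing.
        print(f"Internal stream failure: {err.message}")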

src/sagemaker/iterators.py (+186, new file)
@@ -0,0 +1,186 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
+from __future__ import absolute_import
+
+from abc import ABC, abstractmethod
+import io
+
+from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
+
+
+def handle_stream_errors(chunk):
+    """Handle API response errors within the `invoke_endpoint_with_response_stream` API, if any.
+
+    Args:
+        chunk (dict): A chunk of the response received as part of a
+            `botocore.eventstream.EventStream` response object.
+
+    Raises:
+        ModelStreamError: If a `ModelStreamError` error is detected in a chunk of the
+            `botocore.eventstream.EventStream` response object.
+        InternalStreamFailure: If an `InternalStreamFailure` error is detected in a chunk of
+            the `botocore.eventstream.EventStream` response object.
+    """
+    if "ModelStreamError" in chunk:
+        raise ModelStreamError(
+            chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"]
+        )
+    if "InternalStreamFailure" in chunk:
+        raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"])
+
+
+class BaseIterator(ABC):
+    """Abstract base class for Inference Streaming iterators.
+
+    Provides a skeleton for customization by overriding the iterator methods
+    `__iter__` and `__next__`.
+
+    Tenets of an iterator class for the Streaming Inference API response
+    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
+    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):
+    1. It must accept a `botocore.eventstream.EventStream` response.
+    2. It must implement logic in `__next__` to:
+        2.1. Concatenate and provide the next chunk of the response from the
+             `botocore.eventstream.EventStream`, parsing response_chunk["PayloadPart"]["Bytes"].
+        2.2. Handle errors when `PayloadPart` is not in the EventStream response
+             (using `iterators.handle_stream_errors` is recommended).
+    """
+
+    def __init__(self, event_stream):
+        """Initializes an Iterator object to help parse the byte event stream input.
+
+        Args:
+            event_stream (botocore.eventstream.EventStream): Event stream object to be iterated.
+        """
+        self.event_stream = event_stream
+
+    @abstractmethod
+    def __iter__(self):
+        """Abstract method; returns the iterator object itself."""
+        return self
+
+    @abstractmethod
+    def __next__(self):
+        """Abstract method; responsible for returning the next element in the iteration."""
+
+
+class ByteIterator(BaseIterator):
+    """A helper class for parsing the byte event stream input to provide byte iteration."""
+
+    def __init__(self, event_stream):
+        """Initializes a ByteIterator object.
+
+        Args:
+            event_stream (botocore.eventstream.EventStream): Event stream object to be iterated.
+        """
+        super().__init__(event_stream)
+        self.byte_iterator = iter(event_stream)
+
+    def __iter__(self):
+        """Returns the iterator object itself, which allows the object to be iterated.
+
+        Returns:
+            iter: An iterator object representing the iterable.
+        """
+        return self
+
+    def __next__(self):
+        """Returns the next chunk of bytes directly."""
+        # Even with the "while True" loop, the function still behaves like a generator
+        # and sends the next new byte chunk.
+        while True:
+            chunk = next(self.byte_iterator)
+            if "PayloadPart" not in chunk:
+                # Handle API response errors and force terminate.
+                handle_stream_errors(chunk)
+                # Print and move on to the next response chunk.
+                print("Unknown event type:" + str(chunk))
+                continue
+            return chunk["PayloadPart"]["Bytes"]
+
+
+class LineIterator(BaseIterator):
+    """A helper class for parsing the byte event stream input to provide line iteration."""
+
+    def __init__(self, event_stream):
+        """Initializes a LineIterator object.
+
+        Args:
+            event_stream (botocore.eventstream.EventStream): Event stream object to be iterated.
+        """
+        super().__init__(event_stream)
+        self.byte_iterator = iter(self.event_stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        """Returns the iterator object itself, which allows the object to be iterated.
+
+        Returns:
+            iter: An iterator object representing the iterable.
+        """
+        return self
+
+    def __next__(self):
+        r"""Returns the next line for a line iterable.
+
+        The output of the event stream will be in the following format:
+
+        ```
+        b'{"outputs": [" a"]}\n'
+        b'{"outputs": [" challenging"]}\n'
+        b'{"outputs": [" problem"]}\n'
+        ...
+        ```
+
+        While usually each PayloadPart event from the event stream contains a byte array
+        with a full json, this is not guaranteed and some of the json objects may be split
+        across PayloadPart events. For example:
+        ```
+        {'PayloadPart': {'Bytes': b'{"outputs": '}}
+        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+        ```
+
+        This class accounts for this by concatenating bytes written via the 'write' function
+        and then exposing a method which will return lines (ending with a '\n' character)
+        within the buffer via the 'scan_lines' function. It maintains the position of the
+        last read position to ensure that previous bytes are not exposed again.
+
+        Returns:
+            str: Read and return one line from the event stream.
+        """
+        # Even with the "while True" loop, the function still behaves like a generator
+        # and sends the next new concatenated line.
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Handle API response errors and force terminate.
+                handle_stream_errors(chunk)
+                # Print and move on to the next response chunk.
+                print("Unknown event type:" + str(chunk))
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
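
To illustrate the ``BaseIterator`` tenets above, a hypothetical custom iterator (not part of the commit) that builds on ``LineIterator`` to decode each streamed line as JSON; the ``outputs`` field in the usage comment is an assumption about the container's response format.

    import json

    from sagemaker.iterators import LineIterator


    class JSONIterator(LineIterator):
        """Hypothetical iterator that decodes each complete streamed line as a JSON object."""

        def __next__(self):
            # LineIterator returns one full line of bytes (without the trailing newline);
            # parse it into a Python object before handing it to the caller.
            line = super().__next__()
            return json.loads(line)


    # Usage sketch: for obj in predictor.predict_stream(data=payload, iterator=JSONIterator):
    #                   print(obj["outputs"])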

src/sagemaker/jumpstart/cache.py (+3 -1)
@@ -65,7 +65,9 @@ def __init__(
         self,
         region: Optional[str] = None,
         max_s3_cache_items: int = JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
-        s3_cache_expiration_horizon: datetime.timedelta = JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON,
+        s3_cache_expiration_horizon: datetime.timedelta = (
+            JUMPSTART_DEFAULT_S3_CACHE_EXPIRATION_HORIZON
+        ),
         max_semantic_version_cache_items: int = JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
         semantic_version_cache_expiration_horizon: datetime.timedelta = (
             JUMPSTART_DEFAULT_SEMANTIC_VERSION_CACHE_EXPIRATION_HORIZON
(Binary file, 382 Bytes: not shown)
