
Commit 936e9a1

mufaddal-rohawala authored and akrishna1995 committed
feature: Add support for Streaming Inference
1 parent 5828ad4 commit 936e9a1

File tree: 7 files changed (+473, -0 lines)


src/sagemaker/base_predictor.py (+82)
@@ -52,6 +52,7 @@
     JSONSerializer,
     NumpySerializer,
 )
+from sagemaker.iterators import LineIterator
 from sagemaker.session import production_variant, Session
 from sagemaker.utils import name_from_base, stringify_object, format_tags

@@ -225,6 +226,7 @@ def _create_request_args(
         target_variant=None,
         inference_id=None,
         custom_attributes=None,
+        target_container_hostname=None,
     ):
         """Placeholder docstring"""

@@ -286,9 +288,89 @@ def _create_request_args(
         if self._get_component_name():
             args["InferenceComponentName"] = self.component_name

+        if target_container_hostname:
+            args["TargetContainerHostname"] = target_container_hostname
+
         args["Body"] = data
         return args

+    def predict_stream(
+        self,
+        data,
+        initial_args=None,
+        target_variant=None,
+        inference_id=None,
+        custom_attributes=None,
+        component_name: Optional[str] = None,
+        target_container_hostname=None,
+        iterator=LineIterator,
+    ):
+        """Return the inference from the specified endpoint.
+
+        Args:
+            data (object): Input data for which you want the model to provide
+                inference. If a serializer was specified when creating the
+                Predictor, the result of the serializer is sent as input
+                data. Otherwise the data must be a sequence of bytes, and the
+                predict method then sends the bytes in the request body as is.
+            initial_args (dict[str,str]): Optional. Default arguments for the boto3
+                ``invoke_endpoint_with_response_stream`` call. Default is None (no default
+                arguments).
+            target_variant (str): The name of the production variant to run an inference
+                request on (Default: None). Note that the ProductionVariant identifies the
+                model you want to host and the resources you want to deploy for hosting it.
+            inference_id (str): If you provide a value, it is added to the captured data
+                when you enable data capture on the endpoint (Default: None).
+            custom_attributes (str): Provides additional information about a request for an
+                inference submitted to a model hosted at an Amazon SageMaker endpoint.
+                The information is an opaque value that is forwarded verbatim. You could use this
+                value, for example, to provide an ID that you can use to track a request or to
+                provide other metadata that a service endpoint was programmed to process. The value
+                must consist of no more than 1024 visible US-ASCII characters.
+
+                The code in your model is responsible for setting or updating any custom attributes
+                in the response. If your code does not set this value in the response, an empty
+                value is returned. For example, if a custom attribute represents the trace ID, your
+                model can prepend the custom attribute with "Trace ID:" in your post-processing
+                function (Default: None).
+            component_name (str): Optional. Name of the Amazon SageMaker inference component
+                corresponding to the predictor.
+            target_container_hostname (str): If the endpoint hosts multiple containers and is
+                configured to use direct invocation, this parameter specifies the host name of the
+                container to invoke (Default: None).
+            iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides
+                an iterable interface to deserialize a stream response from the inference endpoint.
+                An object of the provided iterator class is returned by the predict_stream
+                method (Default: :class:`~sagemaker.iterators.LineIterator`). Iterators defined in
+                :class:`~sagemaker.iterators` or custom iterators (which must inherit from
+                :class:`~sagemaker.iterators.BaseIterator`) can be specified as an input.
+
+        Returns:
+            object (:class:`~sagemaker.iterators.BaseIterator`): An iterator object that allows
+                iteration over the EventStream response. The object is instantiated from the
+                `predict_stream` method's `iterator` parameter.
+        """
+        # [TODO]: clean up component_name in _create_request_args
+        request_args = self._create_request_args(
+            data=data,
+            initial_args=initial_args,
+            target_variant=target_variant,
+            inference_id=inference_id,
+            custom_attributes=custom_attributes,
+            target_container_hostname=target_container_hostname,
+        )
+
+        inference_component_name = component_name or self._get_component_name()
+        if inference_component_name:
+            request_args["InferenceComponentName"] = inference_component_name
+
+        response = (
+            self.sagemaker_session.sagemaker_runtime_client.invoke_endpoint_with_response_stream(
+                **request_args
+            )
+        )
+        return iterator(response["Body"])
+
     def update_endpoint(
         self,
         initial_instance_count=None,
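A minimal usage sketch of the new predict_stream API (the endpoint name and request payload below are illustrative assumptions, not part of this commit):

import json

from sagemaker.iterators import LineIterator
from sagemaker.predictor import Predictor

# Assumes a streaming-capable endpoint is already deployed; the name is hypothetical.
predictor = Predictor(endpoint_name="my-streaming-endpoint")
stream = predictor.predict_stream(
    data=json.dumps({"inputs": "What is streaming inference?", "parameters": {"max_new_tokens": 64}}),
    initial_args={"ContentType": "application/json"},
    iterator=LineIterator,  # the default; any sagemaker.iterators.BaseIterator subclass works
)
for line in stream:
    print(json.loads(line))

Each item yielded by the iterator is the raw bytes of one newline-delimited chunk, so deserialization (json.loads above) stays with the caller.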

src/sagemaker/exceptions.py (+16)
@@ -86,3 +86,19 @@ class AsyncInferenceModelError(AsyncInferenceError):

     def __init__(self, message):
         super().__init__(message=message)
+
+
+class ModelStreamError(Exception):
+    def __init__(self, message="An error occurred", code=None):
+        self.message = message
+        self.code = code
+        if code is not None:
+            super().__init__(f"{message} (Code: {code})")
+        else:
+            super().__init__(message)
+
+
+class InternalStreamFailure(Exception):
+    def __init__(self, message="An error occurred"):
+        self.message = message
+        super().__init__(self.message)
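A hedged sketch of how a caller might handle the new exceptions while consuming a stream (the endpoint name and payload are illustrative):

import json

from sagemaker.exceptions import InternalStreamFailure, ModelStreamError
from sagemaker.predictor import Predictor

predictor = Predictor(endpoint_name="my-streaming-endpoint")  # hypothetical endpoint
try:
    for line in predictor.predict_stream(
        data=json.dumps({"inputs": "Hello"}),
        initial_args={"ContentType": "application/json"},
    ):
        print(json.loads(line))
except ModelStreamError as err:
    # Raised when a ModelStreamError event appears in the response EventStream.
    print(f"Model stream error {err.code}: {err.message}")
except InternalStreamFailure as err:
    # Raised when SageMaker reports an internal failure on the stream.
    print(f"Internal stream failure: {err.message}")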

src/sagemaker/iterators.py (+154)
@@ -0,0 +1,154 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+"""Implements iterators for deserializing data returned from an inference streaming endpoint."""
+from __future__ import absolute_import
+
+from abc import ABC, abstractmethod
+import io
+
+from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
+
+
+def handle_stream_errors(chunk):
+    """Handle API response errors within the `invoke_endpoint_with_response_stream` API, if any.
+
+    Args:
+        chunk (dict): A chunk of response received as part of a `botocore.eventstream.EventStream`
+            response object.
+
+    Raises:
+        ModelStreamError: If a `ModelStreamError` error is detected in a chunk of the
+            `botocore.eventstream.EventStream` response object.
+        InternalStreamFailure: If an `InternalStreamFailure` error is detected in a chunk of the
+            `botocore.eventstream.EventStream` response object.
+    """
+    if "ModelStreamError" in chunk:
+        raise ModelStreamError(
+            chunk["ModelStreamError"]["Message"], code=chunk["ModelStreamError"]["ErrorCode"]
+        )
+    if "InternalStreamFailure" in chunk:
+        raise InternalStreamFailure(chunk["InternalStreamFailure"]["Message"])
+
+
+class BaseIterator(ABC):
+    """Abstract base class for creation of new iterators.
+
+    Provides a skeleton for customization requiring the overriding of the iterator methods
+    __iter__ and __next__.
+
+    Tenets of an iterator class for the Streaming Inference API response
+    (https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/
+    sagemaker-runtime/client/invoke_endpoint_with_response_stream.html):
+    1. Needs to accept a botocore.eventstream.EventStream response.
+    2. Needs to implement logic in __next__ to:
+        2.1. Concatenate and provide the next chunk of response from botocore.eventstream.EventStream.
+            While doing so, parse the response_chunk["PayloadPart"]["Bytes"].
+        2.2. Perform deserialization of the response chunk based on the expected response type.
+        2.3. If PayloadPart is not in the EventStream response, handle errors.
+    """
+
+    def __init__(self, stream):
+        """Initialises an Iterator object to help parse the byte event stream input.
+
+        Args:
+            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
+        """
+        self.stream = stream
+
+    @abstractmethod
+    def __iter__(self):
+        """Abstract __iter__ method, returns an iterator object itself."""
+        return self
+
+    @abstractmethod
+    def __next__(self):
+        """Abstract __next__ method, responsible for returning the next element in the
+        iteration."""
+        pass
+
+
+class LineIterator(BaseIterator):
+    """
+    A helper class for parsing the byte stream input and providing iteration on lines with
+    '\n' separators.
+    """
+
+    def __init__(self, stream):
+        """Initialises an Iterator object to help parse the byte stream input and
+        provide iteration on lines with '\n' separators.
+
+        Args:
+            stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
+        """
+        super().__init__(stream)
+        self.byte_iterator = iter(self.stream)
+        self.buffer = io.BytesIO()
+        self.read_pos = 0
+
+    def __iter__(self):
+        """Returns an iterator object itself, which allows the object to be iterated.
+
+        Returns:
+            iter : object
+                An iterator object representing the iterable.
+        """
+        return self
+
+    def __next__(self):
+        """
+        The output of the event stream will be in the following format:
+
+        ```
+        b'{"outputs": [" a"]}\n'
+        b'{"outputs": [" challenging"]}\n'
+        b'{"outputs": [" problem"]}\n'
+        ...
+        ```
+
+        While usually each PayloadPart event from the event stream will contain a byte array
+        with a full json, this is not guaranteed and some of the json objects may be split across
+        PayloadPart events. For example:
+        ```
+        {'PayloadPart': {'Bytes': b'{"outputs": '}}
+        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
+        ```
+
+        This class accounts for this by concatenating bytes written via the 'write' function
+        and then exposing a method which will return lines (ending with a '\n' character) within
+        the buffer via the 'scan_lines' function. It maintains the last read position to ensure
+        that previous bytes are not exposed again.
+
+        Returns:
+            bytes: One line read from the event stream, without the trailing '\n'.
+        """
+        # Even with the "while True" loop the function still behaves like a generator
+        # and returns the next concatenated line on each call.
+        while True:
+            self.buffer.seek(self.read_pos)
+            line = self.buffer.readline()
+            if line and line[-1] == ord("\n"):
+                self.read_pos += len(line)
+                return line[:-1]
+            try:
+                chunk = next(self.byte_iterator)
+            except StopIteration:
+                if self.read_pos < self.buffer.getbuffer().nbytes:
+                    continue
+                raise
+            if "PayloadPart" not in chunk:
+                # Handle errors within the API response, if any.
+                handle_stream_errors(chunk)
+                print("Unknown event type: " + str(chunk))
+                continue
+            self.buffer.seek(0, io.SEEK_END)
+            self.buffer.write(chunk["PayloadPart"]["Bytes"])
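Custom deserialization can be layered on these primitives by subclassing the iterators above. A short sketch (the class name and behavior are illustrative, not part of this commit); it inherits BaseIterator via LineIterator and decodes each newline-delimited line as JSON:

import json

from sagemaker.iterators import LineIterator


class JSONLineIterator(LineIterator):
    """Hypothetical iterator that yields one parsed JSON object per streamed line."""

    def __next__(self):
        # LineIterator.__next__ returns the raw bytes of one '\n'-terminated line.
        return json.loads(super().__next__())


# Usage: predictor.predict_stream(data=payload, iterator=JSONLineIterator)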
Binary file (382 Bytes) not shown.

tests/integ/test_predict_stream.py (+88)
@@ -0,0 +1,88 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"). You
+# may not use this file except in compliance with the License. A copy of
+# the License is located at
+#
+#     http://aws.amazon.com/apache2.0/
+#
+# or in the "license" file accompanying this file. This file is
+# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
+# ANY KIND, either express or implied. See the License for the specific
+# language governing permissions and limitations under the License.
+from __future__ import absolute_import
+
+import json
+import os
+import pytest
+
+import tests.integ
+import tests.integ.timeout
+
+from sagemaker import image_uris
+from sagemaker.iterators import LineIterator
+from sagemaker.model import Model
+from sagemaker.predictor import Predictor
+from sagemaker.utils import unique_name_from_base
+
+from tests.integ import DATA_DIR
+
+
+ROLE = "SageMakerRole"
+INSTANCE_COUNT = 1
+INSTANCE_TYPE = "ml.g5.2xlarge"
+LMI_FALCON_7B_DATA_PATH = os.path.join(DATA_DIR, "lmi-model-falcon-7b")
+
+
+@pytest.yield_fixture(scope="module")
+def endpoint_name(sagemaker_session):
+    lmi_endpoint_name = unique_name_from_base("lmi-model-falcon-7b")
+    model_data = sagemaker_session.upload_data(
+        path=os.path.join(LMI_FALCON_7B_DATA_PATH, "mymodel-7B.tar.gz"),
+        key_prefix="large-model-lmi/code",
+    )
+
+    image_uri = image_uris.retrieve(
+        framework="djl-deepspeed", region=sagemaker_session.boto_region_name, version="0.23.0"
+    )
+
+    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
+        endpoint_name=lmi_endpoint_name, sagemaker_session=sagemaker_session, hours=2
+    ):
+        lmi_model = Model(
+            sagemaker_session=sagemaker_session,
+            model_data=model_data,
+            image_uri=image_uri,
+            name=lmi_endpoint_name,  # model name
+            role=ROLE,
+        )
+        lmi_model.deploy(
+            INSTANCE_COUNT,
+            INSTANCE_TYPE,
+            endpoint_name=lmi_endpoint_name,
+            container_startup_health_check_timeout=900,
+        )
+        yield lmi_endpoint_name
+
+
+def test_predict_stream(sagemaker_session, endpoint_name):
+    data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}}
+    initial_args = {"ContentType": "application/json"}
+    predictor = Predictor(
+        endpoint_name=endpoint_name,
+        sagemaker_session=sagemaker_session,
+    )
+
+    # Validate that predict_stream returns a usable iterator and that no exception is raised.
+    stream_iterator = predictor.predict_stream(
+        data=json.dumps(data),
+        initial_args=initial_args,
+        iterator=LineIterator,
+    )
+
+    response = ""
+    for line in stream_iterator:
+        resp = json.loads(line)
+        response += resp.get("outputs")[0]
+
+    assert "AWS stands for Amazon Web Services." in response
