Skip to content

Commit 211b2ed

Browse files
fix: codestyle-docs-test
1 parent 96484fa commit 211b2ed

File tree

5 files changed

+162
-37
lines changed

5 files changed

+162
-37
lines changed

src/sagemaker/base_predictor.py

+13-13
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
JSONSerializer,
5353
NumpySerializer,
5454
)
55-
from sagemaker.iterators import LineIterator
55+
from sagemaker.iterators import ByteIterator
5656
from sagemaker.session import production_variant, Session
5757
from sagemaker.utils import name_from_base, stringify_object, format_tags
5858

@@ -303,7 +303,7 @@ def predict_stream(
303303
custom_attributes=None,
304304
component_name: Optional[str] = None,
305305
target_container_hostname=None,
306-
iterator=LineIterator,
306+
iterator=ByteIterator,
307307
):
308308
"""Return the inference from the specified endpoint.
309309
@@ -315,14 +315,14 @@ def predict_stream(
315315
predict method then sends the bytes in the request body as is.
316316
initial_args (dict[str,str]): Optional. Default arguments for boto3
317317
``invoke_endpoint_with_response_stream`` call. Default is None (no default
318-
arguments).
319-
target_variant (str): The name of the production variant to run an inference
318+
arguments). (Default: None)
319+
target_variant (str): Optional. The name of the production variant to run an inference
320320
request on (Default: None). Note that the ProductionVariant identifies the
321321
model you want to host and the resources you want to deploy for hosting it.
322-
inference_id (str): If you provide a value, it is added to the captured data
322+
inference_id (str): Optional. If you provide a value, it is added to the captured data
323323
when you enable data capture on the endpoint (Default: None).
324-
custom_attributes (str): Provides additional information about a request for an
325-
inference submitted to a model hosted at an Amazon SageMaker endpoint.
324+
custom_attributes (str): Optional. Provides additional information about a request for
325+
an inference submitted to a model hosted at an Amazon SageMaker endpoint.
326326
The information is an opaque value that is forwarded verbatim. You could use this
327327
value, for example, to provide an ID that you can use to track a request or to
328328
provide other metadata that a service endpoint was programmed to process. The value
@@ -334,14 +334,14 @@ def predict_stream(
334334
model can prepend the custom attribute with Trace ID: in your post-processing
335335
function (Default: None).
336336
component_name (str): Optional. Name of the Amazon SageMaker inference component
337-
corresponding the predictor.
338-
target_container_hostname (str): If the endpoint hosts multiple containers and is
339-
configured to use direct invocation, this parameter specifies the host name of the
340-
container to invoke. (Default: None).
337+
corresponding to the predictor. (Default: None)
338+
target_container_hostname (str): Optional. If the endpoint hosts multiple containers
339+
and is configured to use direct invocation, this parameter specifies the host name
340+
of the container to invoke. (Default: None).
341341
iterator (:class:`~sagemaker.iterators.BaseIterator`): An iterator class which provides
342-
an iterable interface to deserialize a stream response from Inference Endpoint.
342+
an iterable interface to iterate Event stream response from Inference Endpoint.
343343
An object of the iterator class provided will be returned by the predict_stream
344-
method (Default::class:`~sagemaker.iterators.LineIterator`). Iterators defined in
344+
method (Default: :class:`~sagemaker.iterators.ByteIterator`). Iterators defined in
345345
:class:`~sagemaker.iterators` or custom iterators (needs to inherit
346346
:class:`~sagemaker.iterators.BaseIterator`) can be specified as an input.
347347

src/sagemaker/exceptions.py

+4
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,8 @@ def __init__(self, message):
8989

9090

9191
class ModelStreamError(Exception):
92+
"""Raised when invoke_endpoint_with_response_stream Response returns ModelStreamError"""
93+
9294
def __init__(self, message="An error occurred", code=None):
9395
self.message = message
9496
self.code = code
@@ -99,6 +101,8 @@ def __init__(self, message="An error occurred", code=None):
99101

100102

101103
class InternalStreamFailure(Exception):
104+
"""Raised when invoke_endpoint_with_response_stream Response returns InternalStreamFailure"""
105+
102106
def __init__(self, message="An error occurred"):
103107
self.message = message
104108
super().__init__(self.message)

src/sagemaker/iterators.py

+54-22
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def handle_stream_errors(chunk):
4141

4242

4343
class BaseIterator(ABC):
44-
"""Abstract base class for creation of new iterators.
44+
"""Abstract base class for Inference Streaming iterators.
4545
4646
Provides a skeleton for customization requiring the overriding of iterator methods
4747
__iter__ and __next__.
@@ -53,45 +53,75 @@ class BaseIterator(ABC):
5353
2. Needs to implement logic in __next__ to:
5454
2.1. Concatenate and provide next chunk of response from botocore.eventstream.EventStream.
5555
While doing so parse the response_chunk["PayloadPart"]["Bytes"].
56-
2.2. Perform deserialization of response chunk based on expected response type.
57-
2.3. If PayloadPart not in EventStream response, handle Errors.
56+
2.2. If PayloadPart not in EventStream response, handle Errors
57+
[Recommended to use `iterators.handle_stream_errors` method].
5858
"""
5959

60-
def __init__(self, stream):
60+
def __init__(self, event_stream):
6161
"""Initialises an Iterator object to help parse the byte event stream input.
6262
6363
Args:
64-
stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
64+
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
6565
"""
66-
self.stream = stream
66+
self.event_stream = event_stream
6767

6868
@abstractmethod
6969
def __iter__(self):
70-
"""Abstract __iter__ method, returns an iterator object itself"""
70+
"""Abstract method, returns an iterator object itself"""
7171
return self
7272

7373
@abstractmethod
7474
def __next__(self):
75-
"""Abstract __next__ method, is responsible for returning the next element in the
76-
iteration"""
77-
pass
75+
"""Abstract method, is responsible for returning the next element in the iteration"""
76+
77+
78+
class ByteIterator(BaseIterator):
79+
"""A helper class for parsing the byte Event Stream input to provide Byte iteration."""
80+
81+
def __init__(self, event_stream):
82+
"""Initialises a ByteIterator object
83+
84+
Args:
85+
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
86+
"""
87+
super().__init__(event_stream)
88+
self.byte_iterator = iter(event_stream)
89+
90+
def __iter__(self):
91+
"""Returns an iterator object itself, which allows the object to be iterated.
92+
93+
Returns:
94+
iter : object
95+
An iterator object representing the iterable.
96+
"""
97+
return self
98+
99+
def __next__(self):
100+
"""Returns the next chunk of bytes directly."""
101+
# Even with "while True" loop the function still behaves like a generator
102+
# and sends the next new byte chunk.
103+
while True:
104+
chunk = next(self.byte_iterator)
105+
if "PayloadPart" not in chunk:
106+
# handle API response errors and force terminate.
107+
handle_stream_errors(chunk)
108+
# print and move on to next response byte
109+
print("Unknown event type:" + chunk)
110+
continue
111+
return chunk["PayloadPart"]["Bytes"]
78112

79113

80114
class LineIterator(BaseIterator):
81-
"""
82-
A helper class for parsing the byte stream input and provide iteration on lines with
83-
'\n' separators.
84-
"""
115+
"""A helper class for parsing the byte Event Stream input to provide Line iteration."""
85116

86-
def __init__(self, stream):
87-
"""Initialises a Iterator object to help parse the byte stream input and
88-
provide iteration on lines with '\n' separators
117+
def __init__(self, event_stream):
118+
"""Initialises a LineIterator object
89119
90120
Args:
91-
stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
121+
event_stream: (botocore.eventstream.EventStream): Event Stream object to be iterated.
92122
"""
93-
super().__init__(stream)
94-
self.byte_iterator = iter(self.stream)
123+
super().__init__(event_stream)
124+
self.byte_iterator = iter(self.event_stream)
95125
self.buffer = io.BytesIO()
96126
self.read_pos = 0
97127

@@ -105,7 +135,8 @@ def __iter__(self):
105135
return self
106136

107137
def __next__(self):
108-
"""
138+
r"""Returns the next line from the event stream.
139+
109140
The output of the event stream will be in the following format:
110141
111142
```
@@ -146,8 +177,9 @@ def __next__(self):
146177
continue
147178
raise
148179
if "PayloadPart" not in chunk:
149-
# handle errors within API Response if any.
180+
# handle API response errors and force terminate.
150181
handle_stream_errors(chunk)
182+
# print and move on to next response byte
151183
print("Unknown event type:" + chunk)
152184
continue
153185
self.buffer.seek(0, io.SEEK_END)

tests/integ/test_predict_stream.py

+21
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,24 @@ def test_predict_stream(sagemaker_session, endpoint_name):
8686
response += resp.get("outputs")[0]
8787

8888
assert "AWS stands for Amazon Web Services." in response
89+
90+
data = {"inputs": "what does AWS stand for?", "parameters": {"max_new_tokens": 400}}
91+
initial_args = {"ContentType": "application/json"}
92+
predictor = Predictor(
93+
endpoint_name=endpoint_name,
94+
sagemaker_session=sagemaker_session,
95+
)
96+
97+
# Validate that no exception is raised when no iterator is specified.
98+
# uses the default `sagemaker.iterators.ByteIterator`
99+
stream_iterator = predictor.predict_stream(
100+
data=json.dumps(data),
101+
initial_args=initial_args,
102+
)
103+
104+
response = ""
105+
for line in stream_iterator:
106+
resp = json.loads(line)
107+
response += resp.get("outputs")[0]
108+
109+
assert "AWS stands for Amazon Web Services." in response

tests/unit/sagemaker/iterators/test_iterators.py

+70-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
115
import unittest
216
from unittest.mock import MagicMock
317

418
from sagemaker.exceptions import ModelStreamError, InternalStreamFailure
5-
from sagemaker.iterators import LineIterator
19+
from sagemaker.iterators import ByteIterator, LineIterator
620

721

8-
class TestLineIterator(unittest.TestCase):
22+
class TestByteIterator(unittest.TestCase):
923
def test_iteration_with_payload_parts(self):
1024
# Mocking the stream object
1125
self.stream = MagicMock()
@@ -14,6 +28,60 @@ def test_iteration_with_payload_parts(self):
1428
{"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}},
1529
{"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}},
1630
]
31+
self.iterator = ByteIterator(self.stream)
32+
33+
lines = list(self.iterator)
34+
expected_lines = [
35+
b'{"outputs": [" a"]}\n',
36+
b'{"outputs": [" challenging"]}\n',
37+
b'{"outputs": [" problem"]}\n',
38+
]
39+
self.assertEqual(lines, expected_lines)
40+
41+
def test_iteration_with_model_stream_error(self):
42+
# Mocking the stream object
43+
self.stream = MagicMock()
44+
self.stream.__iter__.return_value = [
45+
{"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
46+
{"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}},
47+
{"ModelStreamError": {"Message": "Error message", "ErrorCode": "500"}},
48+
{"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}},
49+
]
50+
self.iterator = ByteIterator(self.stream)
51+
52+
with self.assertRaises(ModelStreamError) as e:
53+
list(self.iterator)
54+
55+
self.assertEqual(str(e.exception.message), "Error message")
56+
self.assertEqual(str(e.exception.code), "500")
57+
58+
def test_iteration_with_internal_stream_failure(self):
59+
# Mocking the stream object
60+
self.stream = MagicMock()
61+
self.stream.__iter__.return_value = [
62+
{"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
63+
{"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}},
64+
{"InternalStreamFailure": {"Message": "Error internal stream failure"}},
65+
{"PayloadPart": {"Bytes": b'{"outputs": [" problem"]}\n'}},
66+
]
67+
self.iterator = ByteIterator(self.stream)
68+
69+
with self.assertRaises(InternalStreamFailure) as e:
70+
list(self.iterator)
71+
72+
self.assertEqual(str(e.exception.message), "Error internal stream failure")
73+
74+
75+
class TestLineIterator(unittest.TestCase):
76+
def test_iteration_with_payload_parts(self):
77+
# Mocking the stream object
78+
self.stream = MagicMock()
79+
self.stream.__iter__.return_value = [
80+
{"PayloadPart": {"Bytes": b'{"outputs": [" a"]}\n'}},
81+
{"PayloadPart": {"Bytes": b'{"outputs": [" challenging"]}\n'}},
82+
{"PayloadPart": {"Bytes": b'{"outputs": '}},
83+
{"PayloadPart": {"Bytes": b'[" problem"]}\n'}},
84+
]
1785
self.iterator = LineIterator(self.stream)
1886

1987
lines = list(self.iterator)

0 commit comments

Comments
 (0)