Skip to content

Commit 25ad8b5

Browse files
committed
chore: remove functional import
1 parent 343ddd2 commit 25ad8b5

File tree

9 files changed

+764
-722
lines changed

9 files changed

+764
-722
lines changed

src/sagemaker/base_deserializers.py

+324
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,324 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implements base methods for deserializing data returned from an inference endpoint."""
14+
from __future__ import absolute_import
15+
16+
import csv
17+
18+
import abc
19+
import codecs
20+
import io
21+
import json
22+
23+
import numpy as np
24+
from six import with_metaclass
25+
26+
from sagemaker.utils import DeferredError
27+
28+
try:
29+
import pandas
30+
except ImportError as e:
31+
pandas = DeferredError(e)
32+
33+
34+
class BaseDeserializer(abc.ABC):
35+
"""Abstract base class for creation of new deserializers.
36+
37+
Provides a skeleton for customization requiring the overriding of the method
38+
deserialize and the class attribute ACCEPT.
39+
"""
40+
41+
@abc.abstractmethod
42+
def deserialize(self, stream, content_type):
43+
"""Deserialize data received from an inference endpoint.
44+
45+
Args:
46+
stream (botocore.response.StreamingBody): Data to be deserialized.
47+
content_type (str): The MIME type of the data.
48+
49+
Returns:
50+
object: The data deserialized into an object.
51+
"""
52+
53+
@property
54+
@abc.abstractmethod
55+
def ACCEPT(self):
56+
"""The content types that are expected from the inference endpoint."""
57+
58+
59+
class SimpleBaseDeserializer(with_metaclass(abc.ABCMeta, BaseDeserializer)):
60+
"""Abstract base class for creation of new deserializers.
61+
62+
This class extends the API of :class:~`sagemaker.deserializers.BaseDeserializer` with more
63+
user-friendly options for setting the ACCEPT content type header, in situations where it can be
64+
provided at init and freely updated.
65+
"""
66+
67+
def __init__(self, accept="*/*"):
68+
"""Initialize a ``SimpleBaseDeserializer`` instance.
69+
70+
Args:
71+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
72+
is expected from the inference endpoint (default: "*/*").
73+
"""
74+
super(SimpleBaseDeserializer, self).__init__()
75+
self.accept = accept
76+
77+
@property
78+
def ACCEPT(self):
79+
"""The tuple of possible content types that are expected from the inference endpoint."""
80+
if isinstance(self.accept, str):
81+
return (self.accept,)
82+
return self.accept
83+
84+
85+
class StringDeserializer(SimpleBaseDeserializer):
86+
"""Deserialize data from an inference endpoint into a decoded string."""
87+
88+
def __init__(self, encoding="UTF-8", accept="application/json"):
89+
"""Initialize a ``StringDeserializer`` instance.
90+
91+
Args:
92+
encoding (str): The string encoding to use (default: UTF-8).
93+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
94+
is expected from the inference endpoint (default: "application/json").
95+
"""
96+
super(StringDeserializer, self).__init__(accept=accept)
97+
self.encoding = encoding
98+
99+
def deserialize(self, stream, content_type):
100+
"""Deserialize data from an inference endpoint into a decoded string.
101+
102+
Args:
103+
stream (botocore.response.StreamingBody): Data to be deserialized.
104+
content_type (str): The MIME type of the data.
105+
106+
Returns:
107+
str: The data deserialized into a decoded string.
108+
"""
109+
try:
110+
return stream.read().decode(self.encoding)
111+
finally:
112+
stream.close()
113+
114+
115+
class BytesDeserializer(SimpleBaseDeserializer):
116+
"""Deserialize a stream of bytes into a bytes object."""
117+
118+
def deserialize(self, stream, content_type):
119+
"""Read a stream of bytes returned from an inference endpoint.
120+
121+
Args:
122+
stream (botocore.response.StreamingBody): A stream of bytes.
123+
content_type (str): The MIME type of the data.
124+
125+
Returns:
126+
bytes: The bytes object read from the stream.
127+
"""
128+
try:
129+
return stream.read()
130+
finally:
131+
stream.close()
132+
133+
134+
class CSVDeserializer(SimpleBaseDeserializer):
135+
"""Deserialize a stream of bytes into a list of lists.
136+
137+
Consider using :class:~`sagemaker.deserializers.NumpyDeserializer` or
138+
:class:~`sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
139+
responses directly into other data types.
140+
"""
141+
142+
def __init__(self, encoding="utf-8", accept="text/csv"):
143+
"""Initialize a ``CSVDeserializer`` instance.
144+
145+
Args:
146+
encoding (str): The string encoding to use (default: "utf-8").
147+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
148+
is expected from the inference endpoint (default: "text/csv").
149+
"""
150+
super(CSVDeserializer, self).__init__(accept=accept)
151+
self.encoding = encoding
152+
153+
def deserialize(self, stream, content_type):
154+
"""Deserialize data from an inference endpoint into a list of lists.
155+
156+
Args:
157+
stream (botocore.response.StreamingBody): Data to be deserialized.
158+
content_type (str): The MIME type of the data.
159+
160+
Returns:
161+
list: The data deserialized into a list of lists representing the
162+
contents of a CSV file.
163+
"""
164+
try:
165+
decoded_string = stream.read().decode(self.encoding)
166+
return list(csv.reader(decoded_string.splitlines()))
167+
finally:
168+
stream.close()
169+
170+
171+
class StreamDeserializer(SimpleBaseDeserializer):
172+
"""Directly return the data and content-type received from an inference endpoint.
173+
174+
It is the user's responsibility to close the data stream once they're done
175+
reading it.
176+
"""
177+
178+
def deserialize(self, stream, content_type):
179+
"""Returns a stream of the response body and the MIME type of the data.
180+
181+
Args:
182+
stream (botocore.response.StreamingBody): A stream of bytes.
183+
content_type (str): The MIME type of the data.
184+
185+
Returns:
186+
tuple: A two-tuple containing the stream and content-type.
187+
"""
188+
return stream, content_type
189+
190+
191+
class NumpyDeserializer(SimpleBaseDeserializer):
192+
"""Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array."""
193+
194+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
195+
"""Initialize a ``NumpyDeserializer`` instance.
196+
197+
Args:
198+
dtype (str): The dtype of the data (default: None).
199+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
200+
is expected from the inference endpoint (default: "application/x-npy").
201+
allow_pickle (bool): Allow loading pickled object arrays (default: True).
202+
"""
203+
super(NumpyDeserializer, self).__init__(accept=accept)
204+
self.dtype = dtype
205+
self.allow_pickle = allow_pickle
206+
207+
def deserialize(self, stream, content_type):
208+
"""Deserialize data from an inference endpoint into a NumPy array.
209+
210+
Args:
211+
stream (botocore.response.StreamingBody): Data to be deserialized.
212+
content_type (str): The MIME type of the data.
213+
214+
Returns:
215+
numpy.ndarray: The data deserialized into a NumPy array.
216+
"""
217+
try:
218+
if content_type == "text/csv":
219+
return np.genfromtxt(
220+
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
221+
)
222+
if content_type == "application/json":
223+
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
224+
if content_type == "application/x-npy":
225+
return np.load(io.BytesIO(stream.read()), allow_pickle=self.allow_pickle)
226+
finally:
227+
stream.close()
228+
229+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
230+
231+
232+
class JSONDeserializer(SimpleBaseDeserializer):
233+
"""Deserialize JSON data from an inference endpoint into a Python object."""
234+
235+
def __init__(self, accept="application/json"):
236+
"""Initialize a ``JSONDeserializer`` instance.
237+
238+
Args:
239+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
240+
is expected from the inference endpoint (default: "application/json").
241+
"""
242+
super(JSONDeserializer, self).__init__(accept=accept)
243+
244+
def deserialize(self, stream, content_type):
245+
"""Deserialize JSON data from an inference endpoint into a Python object.
246+
247+
Args:
248+
stream (botocore.response.StreamingBody): Data to be deserialized.
249+
content_type (str): The MIME type of the data.
250+
251+
Returns:
252+
object: The JSON-formatted data deserialized into a Python object.
253+
"""
254+
try:
255+
return json.load(codecs.getreader("utf-8")(stream))
256+
finally:
257+
stream.close()
258+
259+
260+
class PandasDeserializer(SimpleBaseDeserializer):
261+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
262+
263+
def __init__(self, accept=("text/csv", "application/json")):
264+
"""Initialize a ``PandasDeserializer`` instance.
265+
266+
Args:
267+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
268+
is expected from the inference endpoint (default: ("text/csv","application/json")).
269+
"""
270+
super(PandasDeserializer, self).__init__(accept=accept)
271+
272+
def deserialize(self, stream, content_type):
273+
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.
274+
275+
If the data is JSON, the data should be formatted in the 'columns' orient.
276+
See https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_json.html
277+
278+
Args:
279+
stream (botocore.response.StreamingBody): Data to be deserialized.
280+
content_type (str): The MIME type of the data.
281+
282+
Returns:
283+
pandas.DataFrame: The data deserialized into a pandas DataFrame.
284+
"""
285+
if content_type == "text/csv":
286+
return pandas.read_csv(stream)
287+
288+
if content_type == "application/json":
289+
return pandas.read_json(stream)
290+
291+
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
292+
293+
294+
class JSONLinesDeserializer(SimpleBaseDeserializer):
295+
"""Deserialize JSON lines data from an inference endpoint."""
296+
297+
def __init__(self, accept="application/jsonlines"):
298+
"""Initialize a ``JSONLinesDeserializer`` instance.
299+
300+
Args:
301+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
302+
is expected from the inference endpoint (default: ("text/csv","application/json")).
303+
"""
304+
super(JSONLinesDeserializer, self).__init__(accept=accept)
305+
306+
def deserialize(self, stream, content_type):
307+
"""Deserialize JSON lines data from an inference endpoint.
308+
309+
See https://docs.python.org/3/library/json.html#py-to-json-table to
310+
understand how JSON values are converted to Python objects.
311+
312+
Args:
313+
stream (botocore.response.StreamingBody): Data to be deserialized.
314+
content_type (str): The MIME type of the data.
315+
316+
Returns:
317+
list: A list of JSON serializable objects.
318+
"""
319+
try:
320+
body = stream.read().decode("utf-8")
321+
lines = body.rstrip().split("\n")
322+
return [json.loads(line) for line in lines]
323+
finally:
324+
stream.close()

0 commit comments

Comments
 (0)