Skip to content

Commit 56e4de4

Browse files
authored
Merge branch 'zwei' into remove-content-types
2 parents 350588a + d4f7ce8 commit 56e4de4

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

src/sagemaker/deserializers.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ class BaseDeserializer(abc.ABC):
3131
"""
3232

3333
@abc.abstractmethod
34-
def deserialize(self, data, content_type):
34+
def deserialize(self, stream, content_type):
3535
"""Deserialize data received from an inference endpoint.
3636
3737
Args:
38-
data (object): Data to be deserialized.
38+
stream (botocore.response.StreamingBody): Data to be deserialized.
3939
content_type (str): The MIME type of the data.
4040
4141
Returns:
@@ -61,41 +61,41 @@ def __init__(self, encoding="UTF-8"):
6161
"""
6262
self.encoding = encoding
6363

64-
def deserialize(self, data, content_type):
64+
def deserialize(self, stream, content_type):
6565
"""Deserialize data from an inference endpoint into a decoded string.
6666
6767
Args:
68-
data (object): Data to be deserialized.
68+
stream (botocore.response.StreamingBody): Data to be deserialized.
6969
content_type (str): The MIME type of the data.
7070
7171
Returns:
7272
str: The data deserialized into a decoded string.
7373
"""
7474
try:
75-
return data.read().decode(self.encoding)
75+
return stream.read().decode(self.encoding)
7676
finally:
77-
data.close()
77+
stream.close()
7878

7979

8080
class BytesDeserializer(BaseDeserializer):
8181
"""Deserialize a stream of bytes into a bytes object."""
8282

8383
ACCEPT = "*/*"
8484

85-
def deserialize(self, data, content_type):
85+
def deserialize(self, stream, content_type):
8686
"""Read a stream of bytes returned from an inference endpoint.
8787
8888
Args:
89-
data (object): A stream of bytes.
89+
stream (botocore.response.StreamingBody): A stream of bytes.
9090
content_type (str): The MIME type of the data.
9191
9292
Returns:
9393
bytes: The bytes object read from the stream.
9494
"""
9595
try:
96-
return data.read()
96+
return stream.read()
9797
finally:
98-
data.close()
98+
stream.close()
9999

100100

101101
class CSVDeserializer(BaseDeserializer):
@@ -111,22 +111,22 @@ def __init__(self, encoding="utf-8"):
111111
"""
112112
self.encoding = encoding
113113

114-
def deserialize(self, data, content_type):
114+
def deserialize(self, stream, content_type):
115115
"""Deserialize data from an inference endpoint into a list of lists.
116116
117117
Args:
118-
data (botocore.response.StreamingBody): Data to be deserialized.
118+
stream (botocore.response.StreamingBody): Data to be deserialized.
119119
content_type (str): The MIME type of the data.
120120
121121
Returns:
122122
list: The data deserialized into a list of lists representing the
123123
contents of a CSV file.
124124
"""
125125
try:
126-
decoded_string = data.read().decode(self.encoding)
126+
decoded_string = stream.read().decode(self.encoding)
127127
return list(csv.reader(decoded_string.splitlines()))
128128
finally:
129-
data.close()
129+
stream.close()
130130

131131

132132
class StreamDeserializer(BaseDeserializer):
@@ -138,17 +138,17 @@ class StreamDeserializer(BaseDeserializer):
138138

139139
ACCEPT = "*/*"
140140

141-
def deserialize(self, data, content_type):
141+
def deserialize(self, stream, content_type):
142142
"""Returns a stream of the response body and the MIME type of the data.
143143
144144
Args:
145-
data (object): A stream of bytes.
145+
stream (botocore.response.StreamingBody): A stream of bytes.
146146
content_type (str): The MIME type of the data.
147147
148148
Returns:
149149
tuple: A two-tuple containing the stream and content-type.
150150
"""
151-
return data, content_type
151+
return stream, content_type
152152

153153

154154
class NumpyDeserializer(BaseDeserializer):
@@ -164,11 +164,11 @@ def __init__(self, dtype=None):
164164
"""
165165
self.dtype = dtype
166166

167-
def deserialize(self, data, content_type):
167+
def deserialize(self, stream, content_type):
168168
"""Deserialize data from an inference endpoint into a NumPy array.
169169
170170
Args:
171-
data (botocore.response.StreamingBody): Data to be deserialized.
171+
stream (botocore.response.StreamingBody): Data to be deserialized.
172172
content_type (str): The MIME type of the data.
173173
174174
Returns:
@@ -177,14 +177,14 @@ def deserialize(self, data, content_type):
177177
try:
178178
if content_type == "text/csv":
179179
return np.genfromtxt(
180-
codecs.getreader("utf-8")(data), delimiter=",", dtype=self.dtype
180+
codecs.getreader("utf-8")(stream), delimiter=",", dtype=self.dtype
181181
)
182182
if content_type == "application/json":
183-
return np.array(json.load(codecs.getreader("utf-8")(data)), dtype=self.dtype)
183+
return np.array(json.load(codecs.getreader("utf-8")(stream)), dtype=self.dtype)
184184
if content_type == "application/x-npy":
185-
return np.load(io.BytesIO(data.read()))
185+
return np.load(io.BytesIO(stream.read()))
186186
finally:
187-
data.close()
187+
stream.close()
188188

189189
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
190190

@@ -194,17 +194,17 @@ class JSONDeserializer(BaseDeserializer):
194194

195195
ACCEPT = "application/json"
196196

197-
def deserialize(self, data, content_type):
197+
def deserialize(self, stream, content_type):
198198
"""Deserialize JSON data from an inference endpoint into a Python object.
199199
200200
Args:
201-
data (botocore.response.StreamingBody): Data to be deserialized.
201+
stream (botocore.response.StreamingBody): Data to be deserialized.
202202
content_type (str): The MIME type of the data.
203203
204204
Returns:
205205
object: The JSON-formatted data deserialized into a Python object.
206206
"""
207207
try:
208-
return json.load(codecs.getreader("utf-8")(data))
208+
return json.load(codecs.getreader("utf-8")(stream))
209209
finally:
210-
data.close()
210+
stream.close()

0 commit comments

Comments
 (0)