Skip to content

Commit a249739

Browse files
author
Balaji Veeramani
committed
Update deserializers.py
1 parent 23b818b commit a249739

File tree

2 files changed

+16
-5
lines changed

2 files changed

+16
-5
lines changed

src/sagemaker/deserializers.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def ACCEPT(self):
5858
class StringDeserializer(BaseDeserializer):
5959
"""Deserialize data from an inference endpoint into a decoded string."""
6060

61-
ACCEPT = ("application/json", "text/csv")
61+
ACCEPT = ("application/json",)
6262

6363
def __init__(self, encoding="UTF-8"):
6464
"""Initialize the string encoding.
@@ -161,16 +161,17 @@ def deserialize(self, stream, content_type):
161161
class NumpyDeserializer(BaseDeserializer):
162162
"""Deserialize a stream of data in the .npy format."""
163163

164-
ACCEPT = ("application/x-npy", "text/csv", "application/json")
165-
166-
def __init__(self, dtype=None, allow_pickle=True):
164+
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
167165
"""Initialize the dtype and allow_pickle arguments.
168166
169167
Args:
170168
dtype (str): The dtype of the data (default: None).
169+
accept (str): The MIME type that is expected from the inference
170+
endpoint (default: "application/x-npy").
171171
allow_pickle (bool): Allow loading pickled object arrays (default: True).
172172
"""
173173
self.dtype = dtype
174+
self.accept = accept
174175
self.allow_pickle = allow_pickle
175176

176177
def deserialize(self, stream, content_type):
@@ -197,6 +198,16 @@ def deserialize(self, stream, content_type):
197198

198199
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
199200

201+
@property
202+
def ACCEPT(self):
203+
"""The content types that are expected from the inference endpoint.
204+
205+
To maintain backwards compatability with legacy images, the
206+
NumpyDeserializer supports sending only one content type in the Accept
207+
header.
208+
"""
209+
return (self.accept,)
210+
200211

201212
class JSONDeserializer(BaseDeserializer):
202213
"""Deserialize JSON data from an inference endpoint into a Python object."""

tests/unit/test_predictor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,7 @@ def test_predict_call_with_multiple_accept_types():
214214
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
215215

216216
expected_request_args = {
217-
"Accept": "application/json, text/csv",
217+
"Accept": "application/json",
218218
"Body": "1,2",
219219
"ContentType": CSV_CONTENT_TYPE,
220220
"EndpointName": ENDPOINT,

0 commit comments

Comments
 (0)