Skip to content

Commit 84a5102

Browse files
athewseyajaykarpur
andauthored
feature: all de/serializers support content type (#1993)
* feature: all de/serializers support content type Add support for content_type constructor arg for all serializers, and accept constructor arg for all deserializers. This will make our de/serializers easier to re-purpose for models with specific header requirements but standard content formats. * fix: update broken integ tests And refactor some others to use SimpleBaseSerializer * feat: content_type for RecordIO de/serializers Extend content_type and accept args to the RecordIO de/serializers hiding over in the sagemaker.amazon package that were missed in the initial commit. Co-authored-by: Ajay Karpur <[email protected]>
1 parent a421487 commit 84a5102

File tree

8 files changed

+197
-92
lines changed

8 files changed

+197
-92
lines changed

doc/frameworks/xgboost/using_xgboost.rst

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,8 +192,7 @@ inference against your model.
192192

193193
.. code::
194194
195-
serializer = StringSerializer()
196-
serializer.CONTENT_TYPE = "text/libsvm"
195+
serializer = StringSerializer(content_type="text/libsvm")
197196
198197
predictor = estimator.deploy(
199198
initial_instance_count=1,

src/sagemaker/amazon/common.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,22 @@
2222

2323
from sagemaker.amazon.record_pb2 import Record
2424
from sagemaker.deprecations import deprecated_class
25-
from sagemaker.deserializers import BaseDeserializer
26-
from sagemaker.serializers import BaseSerializer
25+
from sagemaker.deserializers import SimpleBaseDeserializer
26+
from sagemaker.serializers import SimpleBaseSerializer
2727
from sagemaker.utils import DeferredError
2828

2929

30-
class RecordSerializer(BaseSerializer):
30+
class RecordSerializer(SimpleBaseSerializer):
3131
"""Serialize a NumPy array for an inference request."""
3232

33-
CONTENT_TYPE = "application/x-recordio-protobuf"
33+
def __init__(self, content_type="application/x-recordio-protobuf"):
34+
"""Initialize a ``RecordSerializer`` instance.
35+
36+
Args:
37+
content_type (str): The MIME type to signal to the inference endpoint when sending
38+
request data (default: "application/x-recordio-protobuf").
39+
"""
40+
super(RecordSerializer, self).__init__(content_type=content_type)
3441

3542
def serialize(self, data):
3643
"""Serialize a NumPy array into a buffer containing RecordIO records.
@@ -56,10 +63,18 @@ def serialize(self, data):
5663
return buffer
5764

5865

59-
class RecordDeserializer(BaseDeserializer):
66+
class RecordDeserializer(SimpleBaseDeserializer):
6067
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
6168

62-
ACCEPT = ("application/x-recordio-protobuf",)
69+
def __init__(self, accept="application/x-recordio-protobuf"):
70+
"""Initialize a ``RecordDeserializer`` instance.
71+
72+
Args:
73+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
74+
is expected from the inference endpoint (default:
75+
"application/x-recordio-protobuf").
76+
"""
77+
super(RecordDeserializer, self).__init__(accept=accept)
6378

6479
def deserialize(self, data, content_type):
6580
"""Deserialize RecordIO Protobuf data from an inference endpoint.

src/sagemaker/deserializers.py

Lines changed: 80 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import json
2222

2323
import numpy as np
24+
from six import with_metaclass
2425

2526
from sagemaker.utils import DeferredError
2627

@@ -55,17 +56,44 @@ def ACCEPT(self):
5556
"""The content types that are expected from the inference endpoint."""
5657

5758

58-
class StringDeserializer(BaseDeserializer):
59-
"""Deserialize data from an inference endpoint into a decoded string."""
59+
class SimpleBaseDeserializer(with_metaclass(abc.ABCMeta, BaseDeserializer)):
60+
"""Abstract base class for creation of new deserializers.
61+
62+
This class extends the API of :class:~`sagemaker.deserializers.BaseDeserializer` with more
63+
user-friendly options for setting the ACCEPT content type header, in situations where it can be
64+
provided at init and freely updated.
65+
"""
66+
67+
def __init__(self, accept="*/*"):
68+
"""Initialize a ``SimpleBaseDeserializer`` instance.
69+
70+
Args:
71+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
72+
is expected from the inference endpoint (default: "*/*").
73+
"""
74+
super(SimpleBaseDeserializer, self).__init__()
75+
self.accept = accept
76+
77+
@property
78+
def ACCEPT(self):
79+
"""The tuple of possible content types that are expected from the inference endpoint."""
80+
if isinstance(self.accept, str):
81+
return (self.accept,)
82+
return self.accept
6083

61-
ACCEPT = ("application/json",)
6284

63-
def __init__(self, encoding="UTF-8"):
64-
"""Initialize the string encoding.
85+
class StringDeserializer(SimpleBaseDeserializer):
86+
"""Deserialize data from an inference endpoint into a decoded string."""
87+
88+
def __init__(self, encoding="UTF-8", accept="application/json"):
89+
"""Initialize a ``StringDeserializer`` instance.
6590
6691
Args:
6792
encoding (str): The string encoding to use (default: UTF-8).
93+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
94+
is expected from the inference endpoint (default: "application/json").
6895
"""
96+
super(StringDeserializer, self).__init__(accept=accept)
6997
self.encoding = encoding
7098

7199
def deserialize(self, stream, content_type):
@@ -84,11 +112,9 @@ def deserialize(self, stream, content_type):
84112
stream.close()
85113

86114

87-
class BytesDeserializer(BaseDeserializer):
115+
class BytesDeserializer(SimpleBaseDeserializer):
88116
"""Deserialize a stream of bytes into a bytes object."""
89117

90-
ACCEPT = ("*/*",)
91-
92118
def deserialize(self, stream, content_type):
93119
"""Read a stream of bytes returned from an inference endpoint.
94120
@@ -105,17 +131,23 @@ def deserialize(self, stream, content_type):
105131
stream.close()
106132

107133

108-
class CSVDeserializer(BaseDeserializer):
109-
"""Deserialize a stream of bytes into a list of lists."""
134+
class CSVDeserializer(SimpleBaseDeserializer):
135+
"""Deserialize a stream of bytes into a list of lists.
110136
111-
ACCEPT = ("text/csv",)
137+
Consider using :class:~`sagemaker.deserializers.NumpyDeserializer` or
138+
:class:~`sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
139+
responses directly into other data types.
140+
"""
112141

113-
def __init__(self, encoding="utf-8"):
114-
"""Initialize the string encoding.
142+
def __init__(self, encoding="utf-8", accept="text/csv"):
143+
"""Initialize a ``CSVDeserializer`` instance.
115144
116145
Args:
117146
encoding (str): The string encoding to use (default: "utf-8").
147+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
148+
is expected from the inference endpoint (default: "text/csv").
118149
"""
150+
super(CSVDeserializer, self).__init__(accept=accept)
119151
self.encoding = encoding
120152

121153
def deserialize(self, stream, content_type):
@@ -136,15 +168,13 @@ def deserialize(self, stream, content_type):
136168
stream.close()
137169

138170

139-
class StreamDeserializer(BaseDeserializer):
140-
"""Returns the data and content-type received from an inference endpoint.
171+
class StreamDeserializer(SimpleBaseDeserializer):
172+
"""Directly return the data and content-type received from an inference endpoint.
141173
142174
It is the user's responsibility to close the data stream once they're done
143175
reading it.
144176
"""
145177

146-
ACCEPT = ("*/*",)
147-
148178
def deserialize(self, stream, content_type):
149179
"""Returns a stream of the response body and the MIME type of the data.
150180
@@ -158,20 +188,20 @@ def deserialize(self, stream, content_type):
158188
return stream, content_type
159189

160190

161-
class NumpyDeserializer(BaseDeserializer):
162-
"""Deserialize a stream of data in the .npy format."""
191+
class NumpyDeserializer(SimpleBaseDeserializer):
192+
"""Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array."""
163193

164194
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
165-
"""Initialize the dtype and allow_pickle arguments.
195+
"""Initialize a ``NumpyDeserializer`` instance.
166196
167197
Args:
168198
dtype (str): The dtype of the data (default: None).
169-
accept (str): The MIME type that is expected from the inference
170-
endpoint (default: "application/x-npy").
199+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
200+
is expected from the inference endpoint (default: "application/x-npy").
171201
allow_pickle (bool): Allow loading pickled object arrays (default: True).
172202
"""
203+
super(NumpyDeserializer, self).__init__(accept=accept)
173204
self.dtype = dtype
174-
self.accept = accept
175205
self.allow_pickle = allow_pickle
176206

177207
def deserialize(self, stream, content_type):
@@ -198,21 +228,18 @@ def deserialize(self, stream, content_type):
198228

199229
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
200230

201-
@property
202-
def ACCEPT(self):
203-
"""The content types that are expected from the inference endpoint.
204-
205-
To maintain backwards compatability with legacy images, the
206-
NumpyDeserializer supports sending only one content type in the Accept
207-
header.
208-
"""
209-
return (self.accept,)
210-
211231

212-
class JSONDeserializer(BaseDeserializer):
232+
class JSONDeserializer(SimpleBaseDeserializer):
213233
"""Deserialize JSON data from an inference endpoint into a Python object."""
214234

215-
ACCEPT = ("application/json",)
235+
def __init__(self, accept="application/json"):
236+
"""Initialize a ``JSONDeserializer`` instance.
237+
238+
Args:
239+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
240+
is expected from the inference endpoint (default: "application/json").
241+
"""
242+
super(JSONDeserializer, self).__init__(accept=accept)
216243

217244
def deserialize(self, stream, content_type):
218245
"""Deserialize JSON data from an inference endpoint into a Python object.
@@ -230,10 +257,17 @@ def deserialize(self, stream, content_type):
230257
stream.close()
231258

232259

233-
class PandasDeserializer(BaseDeserializer):
260+
class PandasDeserializer(SimpleBaseDeserializer):
234261
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
235262

236-
ACCEPT = ("text/csv", "application/json")
263+
def __init__(self, accept=("text/csv", "application/json")):
264+
"""Initialize a ``PandasDeserializer`` instance.
265+
266+
Args:
267+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
268+
is expected from the inference endpoint (default: ("text/csv","application/json")).
269+
"""
270+
super(PandasDeserializer, self).__init__(accept=accept)
237271

238272
def deserialize(self, stream, content_type):
239273
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
@@ -258,10 +292,17 @@ def deserialize(self, stream, content_type):
258292
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))
259293

260294

261-
class JSONLinesDeserializer(BaseDeserializer):
295+
class JSONLinesDeserializer(SimpleBaseDeserializer):
262296
"""Deserialize JSON lines data from an inference endpoint."""
263297

264-
ACCEPT = ("application/jsonlines",)
298+
def __init__(self, accept="application/jsonlines"):
299+
"""Initialize a ``JSONLinesDeserializer`` instance.
300+
301+
Args:
302+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
303+
is expected from the inference endpoint (default: ("text/csv","application/json")).
304+
"""
305+
super(JSONLinesDeserializer, self).__init__(accept=accept)
265306

266307
def deserialize(self, stream, content_type):
267308
"""Deserialize JSON lines data from an inference endpoint.

0 commit comments

Comments
 (0)