21
21
import json
22
22
23
23
import numpy as np
24
+ from six import with_metaclass
24
25
25
26
from sagemaker .utils import DeferredError
26
27
@@ -55,17 +56,44 @@ def ACCEPT(self):
55
56
"""The content types that are expected from the inference endpoint."""
56
57
57
58
58
- class StringDeserializer (BaseDeserializer ):
59
- """Deserialize data from an inference endpoint into a decoded string."""
59
+ class SimpleBaseDeserializer (with_metaclass (abc .ABCMeta , BaseDeserializer )):
60
+ """Abstract base class for creation of new deserializers.
61
+
62
+ This class extends the API of :class:~`sagemaker.deserializers.BaseDeserializer` with more
63
+ user-friendly options for setting the ACCEPT content type header, in situations where it can be
64
+ provided at init and freely updated.
65
+ """
66
+
67
+ def __init__ (self , accept = "*/*" ):
68
+ """Initialize a ``SimpleBaseDeserializer`` instance.
69
+
70
+ Args:
71
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
72
+ is expected from the inference endpoint (default: "*/*").
73
+ """
74
+ super (SimpleBaseDeserializer , self ).__init__ ()
75
+ self .accept = accept
76
+
77
+ @property
78
+ def ACCEPT (self ):
79
+ """The tuple of possible content types that are expected from the inference endpoint."""
80
+ if isinstance (self .accept , str ):
81
+ return (self .accept ,)
82
+ return self .accept
60
83
61
- ACCEPT = ("application/json" ,)
62
84
63
- def __init__ (self , encoding = "UTF-8" ):
64
- """Initialize the string encoding.
85
+ class StringDeserializer (SimpleBaseDeserializer ):
86
+ """Deserialize data from an inference endpoint into a decoded string."""
87
+
88
+ def __init__ (self , encoding = "UTF-8" , accept = "application/json" ):
89
+ """Initialize a ``StringDeserializer`` instance.
65
90
66
91
Args:
67
92
encoding (str): The string encoding to use (default: UTF-8).
93
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
94
+ is expected from the inference endpoint (default: "application/json").
68
95
"""
96
+ super (StringDeserializer , self ).__init__ (accept = accept )
69
97
self .encoding = encoding
70
98
71
99
def deserialize (self , stream , content_type ):
@@ -84,11 +112,9 @@ def deserialize(self, stream, content_type):
84
112
stream .close ()
85
113
86
114
87
- class BytesDeserializer (BaseDeserializer ):
115
+ class BytesDeserializer (SimpleBaseDeserializer ):
88
116
"""Deserialize a stream of bytes into a bytes object."""
89
117
90
- ACCEPT = ("*/*" ,)
91
-
92
118
def deserialize (self , stream , content_type ):
93
119
"""Read a stream of bytes returned from an inference endpoint.
94
120
@@ -105,17 +131,23 @@ def deserialize(self, stream, content_type):
105
131
stream .close ()
106
132
107
133
108
- class CSVDeserializer (BaseDeserializer ):
109
- """Deserialize a stream of bytes into a list of lists."""
134
+ class CSVDeserializer (SimpleBaseDeserializer ):
135
+ """Deserialize a stream of bytes into a list of lists.
110
136
111
- ACCEPT = ("text/csv" ,)
137
+ Consider using :class:~`sagemaker.deserializers.NumpyDeserializer` or
138
+ :class:~`sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
139
+ responses directly into other data types.
140
+ """
112
141
113
- def __init__ (self , encoding = "utf-8" ):
114
- """Initialize the string encoding .
142
+ def __init__ (self , encoding = "utf-8" , accept = "text/csv" ):
143
+ """Initialize a ``CSVDeserializer`` instance .
115
144
116
145
Args:
117
146
encoding (str): The string encoding to use (default: "utf-8").
147
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
148
+ is expected from the inference endpoint (default: "text/csv").
118
149
"""
150
+ super (CSVDeserializer , self ).__init__ (accept = accept )
119
151
self .encoding = encoding
120
152
121
153
def deserialize (self , stream , content_type ):
@@ -136,15 +168,13 @@ def deserialize(self, stream, content_type):
136
168
stream .close ()
137
169
138
170
139
- class StreamDeserializer (BaseDeserializer ):
140
- """Returns the data and content-type received from an inference endpoint.
171
+ class StreamDeserializer (SimpleBaseDeserializer ):
172
+ """Directly return the data and content-type received from an inference endpoint.
141
173
142
174
It is the user's responsibility to close the data stream once they're done
143
175
reading it.
144
176
"""
145
177
146
- ACCEPT = ("*/*" ,)
147
-
148
178
def deserialize (self , stream , content_type ):
149
179
"""Returns a stream of the response body and the MIME type of the data.
150
180
@@ -158,20 +188,20 @@ def deserialize(self, stream, content_type):
158
188
return stream , content_type
159
189
160
190
161
- class NumpyDeserializer (BaseDeserializer ):
162
- """Deserialize a stream of data in the .npy format."""
191
+ class NumpyDeserializer (SimpleBaseDeserializer ):
192
+ """Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array ."""
163
193
164
194
def __init__ (self , dtype = None , accept = "application/x-npy" , allow_pickle = True ):
165
- """Initialize the dtype and allow_pickle arguments .
195
+ """Initialize a ``NumpyDeserializer`` instance .
166
196
167
197
Args:
168
198
dtype (str): The dtype of the data (default: None).
169
- accept (str): The MIME type that is expected from the inference
170
- endpoint (default: "application/x-npy").
199
+ accept (union[ str, tuple[str]] ): The MIME type (or tuple of allowable MIME types) that
200
+ is expected from the inference endpoint (default: "application/x-npy").
171
201
allow_pickle (bool): Allow loading pickled object arrays (default: True).
172
202
"""
203
+ super (NumpyDeserializer , self ).__init__ (accept = accept )
173
204
self .dtype = dtype
174
- self .accept = accept
175
205
self .allow_pickle = allow_pickle
176
206
177
207
def deserialize (self , stream , content_type ):
@@ -198,21 +228,18 @@ def deserialize(self, stream, content_type):
198
228
199
229
raise ValueError ("%s cannot read content type %s." % (__class__ .__name__ , content_type ))
200
230
201
- @property
202
- def ACCEPT (self ):
203
- """The content types that are expected from the inference endpoint.
204
-
205
- To maintain backwards compatability with legacy images, the
206
- NumpyDeserializer supports sending only one content type in the Accept
207
- header.
208
- """
209
- return (self .accept ,)
210
-
211
231
212
- class JSONDeserializer (BaseDeserializer ):
232
+ class JSONDeserializer (SimpleBaseDeserializer ):
213
233
"""Deserialize JSON data from an inference endpoint into a Python object."""
214
234
215
- ACCEPT = ("application/json" ,)
235
+ def __init__ (self , accept = "application/json" ):
236
+ """Initialize a ``JSONDeserializer`` instance.
237
+
238
+ Args:
239
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
240
+ is expected from the inference endpoint (default: "application/json").
241
+ """
242
+ super (JSONDeserializer , self ).__init__ (accept = accept )
216
243
217
244
def deserialize (self , stream , content_type ):
218
245
"""Deserialize JSON data from an inference endpoint into a Python object.
@@ -230,10 +257,17 @@ def deserialize(self, stream, content_type):
230
257
stream .close ()
231
258
232
259
233
- class PandasDeserializer (BaseDeserializer ):
260
+ class PandasDeserializer (SimpleBaseDeserializer ):
234
261
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""
235
262
236
- ACCEPT = ("text/csv" , "application/json" )
263
+ def __init__ (self , accept = ("text/csv" , "application/json" )):
264
+ """Initialize a ``PandasDeserializer`` instance.
265
+
266
+ Args:
267
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
268
+ is expected from the inference endpoint (default: ("text/csv","application/json")).
269
+ """
270
+ super (PandasDeserializer , self ).__init__ (accept = accept )
237
271
238
272
def deserialize (self , stream , content_type ):
239
273
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
@@ -258,10 +292,17 @@ def deserialize(self, stream, content_type):
258
292
raise ValueError ("%s cannot read content type %s." % (__class__ .__name__ , content_type ))
259
293
260
294
261
- class JSONLinesDeserializer (BaseDeserializer ):
295
+ class JSONLinesDeserializer (SimpleBaseDeserializer ):
262
296
"""Deserialize JSON lines data from an inference endpoint."""
263
297
264
- ACCEPT = ("application/jsonlines" ,)
298
+ def __init__ (self , accept = "application/jsonlines" ):
299
+ """Initialize a ``JSONLinesDeserializer`` instance.
300
+
301
+ Args:
302
+ accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
303
+ is expected from the inference endpoint (default: ("text/csv","application/json")).
304
+ """
305
+ super (JSONLinesDeserializer , self ).__init__ (accept = accept )
265
306
266
307
def deserialize (self , stream , content_type ):
267
308
"""Deserialize JSON lines data from an inference endpoint.
0 commit comments