Skip to content

Commit 8dd3247

Browse files
athewseyajaykarpur
andauthored
feature: all predictors support serializer/deserializer overrides (#1997)
* feat: framework predictor de/serial override args All framework *Predictor classes accept constructor arguments to override the default `serializer` & `deserializer` logic (like TensorFlowPredictor did already). * feat: Amz algo predictor de/serial override args All Amazon algorithm *Predictor classes accept constructor arguments to override the default `serializer` & `deserializer` logic, in case users need to specify alternative inference formats. * change: add unit tests for de/ser overrides Co-authored-by: Ajay Karpur <[email protected]>
1 parent 85de4b9 commit 8dd3247

29 files changed

+541
-48
lines changed

src/sagemaker/amazon/factorization_machines.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -277,13 +277,20 @@ class FactorizationMachinesPredictor(Predictor):
277277
to fit the model this Predictor performs inference on.
278278
279279
:meth:`predict()` returns a list of
280-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
280+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
281+
recordio-protobuf ``deserializer`` is used), one for each row in
281282
the input ``ndarray``. The prediction is stored in the ``"score"`` key of
282283
the ``Record.label`` field. Please refer to the formats details described:
283284
https://docs.aws.amazon.com/sagemaker/latest/dg/fm-in-formats.html
284285
"""
285286

286-
def __init__(self, endpoint_name, sagemaker_session=None):
287+
def __init__(
288+
self,
289+
endpoint_name,
290+
sagemaker_session=None,
291+
serializer=RecordSerializer(),
292+
deserializer=RecordDeserializer(),
293+
):
287294
"""
288295
Args:
289296
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -292,12 +299,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
292299
object, used for SageMaker interactions (default: None). If not
293300
specified, one is created using the default AWS configuration
294301
chain.
302+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
303+
serializes input data to x-recordio-protobuf format.
304+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
305+
Default parses responses from x-recordio-protobuf format.
295306
"""
296307
super(FactorizationMachinesPredictor, self).__init__(
297308
endpoint_name,
298309
sagemaker_session,
299-
serializer=RecordSerializer(),
300-
deserializer=RecordDeserializer(),
310+
serializer=serializer,
311+
deserializer=deserializer,
301312
)
302313

303314

src/sagemaker/amazon/ipinsights.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,13 @@ class IPInsightsPredictor(Predictor):
191191
second column should contain the IPv4 address in dot notation.
192192
"""
193193

194-
def __init__(self, endpoint_name, sagemaker_session=None):
194+
def __init__(
195+
self,
196+
endpoint_name,
197+
sagemaker_session=None,
198+
serializer=CSVSerializer(),
199+
deserializer=JSONDeserializer(),
200+
):
195201
"""
196202
Args:
197203
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -200,12 +206,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
200206
object, used for SageMaker interactions (default: None). If not
201207
specified, one is created using the default AWS configuration
202208
chain.
209+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
210+
serializes input data to text/csv.
211+
deserializer (callable): Optional. Default parses JSON responses
212+
using ``json.load(...)``.
203213
"""
204214
super(IPInsightsPredictor, self).__init__(
205215
endpoint_name,
206216
sagemaker_session,
207-
serializer=CSVSerializer(),
208-
deserializer=JSONDeserializer(),
217+
serializer=serializer,
218+
deserializer=deserializer,
209219
)
210220

211221

src/sagemaker/amazon/kmeans.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,12 +210,19 @@ class KMeansPredictor(Predictor):
210210
to fit the model this Predictor performs inference on.
211211
212212
``predict()`` returns a list of
213-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
213+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
214+
recordio-protobuf ``deserializer`` is used), one for each row in
214215
the input ``ndarray``. The nearest cluster is stored in the
215216
``closest_cluster`` key of the ``Record.label`` field.
216217
"""
217218

218-
def __init__(self, endpoint_name, sagemaker_session=None):
219+
def __init__(
220+
self,
221+
endpoint_name,
222+
sagemaker_session=None,
223+
serializer=RecordSerializer(),
224+
deserializer=RecordDeserializer(),
225+
):
219226
"""
220227
Args:
221228
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -224,12 +231,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
224231
object, used for SageMaker interactions (default: None). If not
225232
specified, one is created using the default AWS configuration
226233
chain.
234+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
235+
serializes input data to x-recordio-protobuf format.
236+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
237+
Default parses responses from x-recordio-protobuf format.
227238
"""
228239
super(KMeansPredictor, self).__init__(
229240
endpoint_name,
230241
sagemaker_session,
231-
serializer=RecordSerializer(),
232-
deserializer=RecordDeserializer(),
242+
serializer=serializer,
243+
deserializer=deserializer,
233244
)
234245

235246

src/sagemaker/amazon/knn.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,12 +199,19 @@ class KNNPredictor(Predictor):
199199
to fit the model this Predictor performs inference on.
200200
201201
:func:`predict` returns a list of
202-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
202+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
203+
recordio-protobuf ``deserializer`` is used), one for each row in
203204
the input ``ndarray``. The prediction is stored in the ``"predicted_label"``
204205
key of the ``Record.label`` field.
205206
"""
206207

207-
def __init__(self, endpoint_name, sagemaker_session=None):
208+
def __init__(
209+
self,
210+
endpoint_name,
211+
sagemaker_session=None,
212+
serializer=RecordSerializer(),
213+
deserializer=RecordDeserializer(),
214+
):
208215
"""
209216
Args:
210217
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -213,12 +220,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
213220
object, used for SageMaker interactions (default: None). If not
214221
specified, one is created using the default AWS configuration
215222
chain.
223+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
224+
serializes input data to x-recordio-protobuf format.
225+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
226+
Default parses responses from x-recordio-protobuf format.
216227
"""
217228
super(KNNPredictor, self).__init__(
218229
endpoint_name,
219230
sagemaker_session,
220-
serializer=RecordSerializer(),
221-
deserializer=RecordDeserializer(),
231+
serializer=serializer,
232+
deserializer=deserializer,
222233
)
223234

224235

src/sagemaker/amazon/lda.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,19 @@ class LDAPredictor(Predictor):
183183
to fit the model this Predictor performs inference on.
184184
185185
:meth:`predict()` returns a list of
186-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
186+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
187+
recordio-protobuf ``deserializer`` is used), one for each row in
187188
the input ``ndarray``. The lower dimension vector result is stored in the
188189
``projection`` key of the ``Record.label`` field.
189190
"""
190191

191-
def __init__(self, endpoint_name, sagemaker_session=None):
192+
def __init__(
193+
self,
194+
endpoint_name,
195+
sagemaker_session=None,
196+
serializer=RecordSerializer(),
197+
deserializer=RecordDeserializer(),
198+
):
192199
"""
193200
Args:
194201
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -197,12 +204,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
197204
object, used for SageMaker interactions (default: None). If not
198205
specified, one is created using the default AWS configuration
199206
chain.
207+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
208+
serializes input data to x-recordio-protobuf format.
209+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
210+
Default parses responses from x-recordio-protobuf format.
200211
"""
201212
super(LDAPredictor, self).__init__(
202213
endpoint_name,
203214
sagemaker_session,
204-
serializer=RecordSerializer(),
205-
deserializer=RecordDeserializer(),
215+
serializer=serializer,
216+
deserializer=deserializer,
206217
)
207218

208219

src/sagemaker/amazon/linear_learner.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -444,12 +444,19 @@ class LinearLearnerPredictor(Predictor):
444444
to fit the model this Predictor performs inference on.
445445
446446
:func:`predict` returns a list of
447-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
447+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
448+
recordio-protobuf ``deserializer`` is used), one for each row in
448449
the input ``ndarray``. The prediction is stored in the ``"predicted_label"``
449450
key of the ``Record.label`` field.
450451
"""
451452

452-
def __init__(self, endpoint_name, sagemaker_session=None):
453+
def __init__(
454+
self,
455+
endpoint_name,
456+
sagemaker_session=None,
457+
serializer=RecordSerializer(),
458+
deserializer=RecordDeserializer(),
459+
):
453460
"""
454461
Args:
455462
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -458,12 +465,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
458465
object, used for SageMaker interactions (default: None). If not
459466
specified, one is created using the default AWS configuration
460467
chain.
468+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
469+
serializes input data to x-recordio-protobuf format.
470+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
471+
Default parses responses from x-recordio-protobuf format.
461472
"""
462473
super(LinearLearnerPredictor, self).__init__(
463474
endpoint_name,
464475
sagemaker_session,
465-
serializer=RecordSerializer(),
466-
deserializer=RecordDeserializer(),
476+
serializer=serializer,
477+
deserializer=deserializer,
467478
)
468479

469480

src/sagemaker/amazon/ntm.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -212,12 +212,19 @@ class NTMPredictor(Predictor):
212212
to fit the model this Predictor performs inference on.
213213
214214
:meth:`predict()` returns a list of
215-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
215+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
216+
recordio-protobuf ``deserializer`` is used), one for each row in
216217
the input ``ndarray``. The lower dimension vector result is stored in the
217218
``projection`` key of the ``Record.label`` field.
218219
"""
219220

220-
def __init__(self, endpoint_name, sagemaker_session=None):
221+
def __init__(
222+
self,
223+
endpoint_name,
224+
sagemaker_session=None,
225+
serializer=RecordSerializer(),
226+
deserializer=RecordDeserializer(),
227+
):
221228
"""
222229
Args:
223230
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -226,12 +233,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
226233
object, used for SageMaker interactions (default: None). If not
227234
specified, one is created using the default AWS configuration
228235
chain.
236+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
237+
serializes input data to x-recordio-protobuf format.
238+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
239+
Default parses responses from x-recordio-protobuf format.
229240
"""
230241
super(NTMPredictor, self).__init__(
231242
endpoint_name,
232243
sagemaker_session,
233-
serializer=RecordSerializer(),
234-
deserializer=RecordDeserializer(),
244+
serializer=serializer,
245+
deserializer=deserializer,
235246
)
236247

237248

src/sagemaker/amazon/pca.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -193,12 +193,19 @@ class PCAPredictor(Predictor):
193193
to fit the model this Predictor performs inference on.
194194
195195
:meth:`predict()` returns a list of
196-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
196+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
197+
recordio-protobuf ``deserializer`` is used), one for each row in
197198
the input ``ndarray``. The lower dimension vector result is stored in the
198199
``projection`` key of the ``Record.label`` field.
199200
"""
200201

201-
def __init__(self, endpoint_name, sagemaker_session=None):
202+
def __init__(
203+
self,
204+
endpoint_name,
205+
sagemaker_session=None,
206+
serializer=RecordSerializer(),
207+
deserializer=RecordDeserializer(),
208+
):
202209
"""
203210
Args:
204211
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -207,12 +214,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
207214
object, used for SageMaker interactions (default: None). If not
208215
specified, one is created using the default AWS configuration
209216
chain.
217+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
218+
serializes input data to x-recordio-protobuf format.
219+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
220+
Default parses responses from x-recordio-protobuf format.
210221
"""
211222
super(PCAPredictor, self).__init__(
212223
endpoint_name,
213224
sagemaker_session,
214-
serializer=RecordSerializer(),
215-
deserializer=RecordDeserializer(),
225+
serializer=serializer,
226+
deserializer=deserializer,
216227
)
217228

218229

src/sagemaker/amazon/randomcutforest.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -171,12 +171,19 @@ class RandomCutForestPredictor(Predictor):
171171
to fit the model this Predictor performs inference on.
172172
173173
:meth:`predict()` returns a list of
174-
:class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in
174+
:class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default
175+
recordio-protobuf ``deserializer`` is used), one for each row in
175176
the input. Each row's score is stored in the key ``score`` of the
176177
``Record.label`` field.
177178
"""
178179

179-
def __init__(self, endpoint_name, sagemaker_session=None):
180+
def __init__(
181+
self,
182+
endpoint_name,
183+
sagemaker_session=None,
184+
serializer=RecordSerializer(),
185+
deserializer=RecordDeserializer(),
186+
):
180187
"""
181188
Args:
182189
endpoint_name (str): Name of the Amazon SageMaker endpoint to which
@@ -185,12 +192,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
185192
object, used for SageMaker interactions (default: None). If not
186193
specified, one is created using the default AWS configuration
187194
chain.
195+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
196+
serializes input data to x-recordio-protobuf format.
197+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
198+
Default parses responses from x-recordio-protobuf format.
188199
"""
189200
super(RandomCutForestPredictor, self).__init__(
190201
endpoint_name,
191202
sagemaker_session,
192-
serializer=RecordSerializer(),
193-
deserializer=RecordDeserializer(),
203+
serializer=serializer,
204+
deserializer=deserializer,
194205
)
195206

196207

src/sagemaker/chainer/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ class ChainerPredictor(Predictor):
3838
multidimensional tensors for Chainer inference.
3939
"""
4040

41-
def __init__(self, endpoint_name, sagemaker_session=None):
41+
def __init__(
42+
self,
43+
endpoint_name,
44+
sagemaker_session=None,
45+
serializer=NumpySerializer(),
46+
deserializer=NumpyDeserializer(),
47+
):
4248
"""Initialize an ``ChainerPredictor``.
4349
4450
Args:
@@ -48,9 +54,17 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4854
manages interactions with Amazon SageMaker APIs and any other
4955
AWS services needed. If not specified, the estimator creates one
5056
using the default AWS configuration chain.
57+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
58+
serializes input data to .npy format. Handles lists and numpy
59+
arrays.
60+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
61+
Default parses the response from .npy format to numpy array.
5162
"""
5263
super(ChainerPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
64+
endpoint_name,
65+
sagemaker_session,
66+
serializer=serializer,
67+
deserializer=deserializer,
5468
)
5569

5670

0 commit comments

Comments
 (0)