Skip to content

Commit 55303b8

Browse files
committed
feat: framework predictor de/serial override args
All framework *Predictor classes accept constructor arguments to override the default `serializer` & `deserializer` logic (like TensorFlowPredictor did already).
1 parent e8d16f8 commit 55303b8

File tree

6 files changed

+88
-12
lines changed

6 files changed

+88
-12
lines changed

src/sagemaker/chainer/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,13 @@ class ChainerPredictor(Predictor):
3838
multidimensional tensors for Chainer inference.
3939
"""
4040

41-
def __init__(self, endpoint_name, sagemaker_session=None):
41+
def __init__(
42+
self,
43+
endpoint_name,
44+
sagemaker_session=None,
45+
serializer=NumpySerializer(),
46+
deserializer=NumpyDeserializer(),
47+
):
4248
"""Initialize an ``ChainerPredictor``.
4349
4450
Args:
@@ -48,9 +54,17 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4854
manages interactions with Amazon SageMaker APIs and any other
4955
AWS services needed. If not specified, the estimator creates one
5056
using the default AWS configuration chain.
57+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
58+
serializes input data to .npy format. Handles lists and numpy
59+
arrays.
60+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
61+
Default parses the response from .npy format to numpy array.
5162
"""
5263
super(ChainerPredictor, self).__init__(
53-
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
64+
endpoint_name,
65+
sagemaker_session,
66+
serializer=serializer,
67+
deserializer=deserializer,
5468
)
5569

5670

src/sagemaker/mxnet/model.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,13 @@ class MXNetPredictor(Predictor):
4040
multidimensional tensors for MXNet inference.
4141
"""
4242

43-
def __init__(self, endpoint_name, sagemaker_session=None):
43+
def __init__(
44+
self,
45+
endpoint_name,
46+
sagemaker_session=None,
47+
serializer=JSONSerializer(),
48+
deserializer=JSONDeserializer(),
49+
):
4450
"""Initialize an ``MXNetPredictor``.
4551
4652
Args:
@@ -50,9 +56,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
5056
manages interactions with Amazon SageMaker APIs and any other
5157
AWS services needed. If not specified, the estimator creates one
5258
using the default AWS configuration chain.
59+
serializer (callable): Optional. Default serializes input data to
60+
json. Handles dicts, lists, and numpy arrays.
61+
deserializer (callable): Optional. Default parses the response using
62+
``json.load(...)``.
5363
"""
5464
super(MXNetPredictor, self).__init__(
55-
endpoint_name, sagemaker_session, JSONSerializer(), JSONDeserializer()
65+
endpoint_name,
66+
sagemaker_session,
67+
serializer=serializer,
68+
deserializer=deserializer,
5669
)
5770

5871

src/sagemaker/pytorch/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,13 @@ class PyTorchPredictor(Predictor):
3939
multidimensional tensors for PyTorch inference.
4040
"""
4141

42-
def __init__(self, endpoint_name, sagemaker_session=None):
42+
def __init__(
43+
self,
44+
endpoint_name,
45+
sagemaker_session=None,
46+
serializer=NumpySerializer(),
47+
deserializer=NumpyDeserializer(),
48+
):
4349
"""Initialize an ``PyTorchPredictor``.
4450
4551
Args:
@@ -49,9 +55,17 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4955
manages interactions with Amazon SageMaker APIs and any other
5056
AWS services needed. If not specified, the estimator creates one
5157
using the default AWS configuration chain.
58+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
59+
serializes input data to .npy format. Handles lists and numpy
60+
arrays.
61+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
62+
Default parses the response from .npy format to numpy array.
5263
"""
5364
super(PyTorchPredictor, self).__init__(
54-
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
65+
endpoint_name,
66+
sagemaker_session,
67+
serializer=serializer,
68+
deserializer=deserializer,
5569
)
5670

5771

src/sagemaker/sklearn/model.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ class SKLearnPredictor(Predictor):
3434
multidimensional tensors for Scikit-learn inference.
3535
"""
3636

37-
def __init__(self, endpoint_name, sagemaker_session=None):
37+
def __init__(
38+
self,
39+
endpoint_name,
40+
sagemaker_session=None,
41+
serializer=NumpySerializer(),
42+
deserializer=NumpyDeserializer(),
43+
):
3844
"""Initialize an ``SKLearnPredictor``.
3945
4046
Args:
@@ -44,9 +50,17 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4450
manages interactions with Amazon SageMaker APIs and any other
4551
AWS services needed. If not specified, the estimator creates one
4652
using the default AWS configuration chain.
53+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
54+
serializes input data to .npy format. Handles lists and numpy
55+
arrays.
56+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
57+
Default parses the response from .npy format to numpy array.
4758
"""
4859
super(SKLearnPredictor, self).__init__(
49-
endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer()
60+
endpoint_name,
61+
sagemaker_session,
62+
serializer=serializer,
63+
deserializer=deserializer,
5064
)
5165

5266

src/sagemaker/sparkml/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,13 @@ class SparkMLPredictor(Predictor):
3131
list.
3232
"""
3333

34-
def __init__(self, endpoint_name, sagemaker_session=None, **kwargs):
34+
def __init__(
35+
self,
36+
endpoint_name,
37+
sagemaker_session=None,
38+
serializer=CSVSerializer(),
39+
**kwargs,
40+
):
3541
"""Initializes a SparkMLPredictor which should be used with SparkMLModel
3642
to perform predictions against SparkML models serialized via MLeap. The
3743
response is returned in text/csv format which is the default response
@@ -43,12 +49,14 @@ def __init__(self, endpoint_name, sagemaker_session=None, **kwargs):
4349
manages interactions with Amazon SageMaker APIs and any other
4450
AWS services needed. If not specified, the estimator creates one
4551
using the default AWS configuration chain.
52+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
53+
serializes input data to text/csv.
4654
"""
4755
sagemaker_session = sagemaker_session or Session()
4856
super(SparkMLPredictor, self).__init__(
4957
endpoint_name=endpoint_name,
5058
sagemaker_session=sagemaker_session,
51-
serializer=CSVSerializer(),
59+
serializer=serializer,
5260
**kwargs,
5361
)
5462

src/sagemaker/xgboost/model.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ class XGBoostPredictor(Predictor):
3535
for XGBoost inference.
3636
"""
3737

38-
def __init__(self, endpoint_name, sagemaker_session=None):
38+
def __init__(
39+
self,
40+
endpoint_name,
41+
sagemaker_session=None,
42+
serializer=LibSVMSerializer(),
43+
deserializer=CSVDeserializer(),
44+
):
3945
"""Initialize an ``XGBoostPredictor``.
4046
4147
Args:
@@ -44,9 +50,16 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4450
interactions with Amazon SageMaker APIs and any other AWS services needed.
4551
If not specified, the estimator creates one using the default AWS configuration
4652
chain.
53+
serializer (sagemaker.serializers.BaseSerializer): Optional. Default
54+
serializes input data to LibSVM format
55+
deserializer (sagemaker.deserializers.BaseDeserializer): Optional.
56+
Default parses the response from text/csv to a Python list.
4757
"""
4858
super(XGBoostPredictor, self).__init__(
49-
endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer()
59+
endpoint_name,
60+
sagemaker_session,
61+
serializer=serializer,
62+
deserializer=deserializer,
5063
)
5164

5265

0 commit comments

Comments
 (0)