diff --git a/src/sagemaker/amazon/factorization_machines.py b/src/sagemaker/amazon/factorization_machines.py index 7544a3d5c7..48bed830cb 100644 --- a/src/sagemaker/amazon/factorization_machines.py +++ b/src/sagemaker/amazon/factorization_machines.py @@ -277,13 +277,20 @@ class FactorizationMachinesPredictor(Predictor): to fit the model this Predictor performs inference on. :meth:`predict()` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The prediction is stored in the ``"score"`` key of the ``Record.label`` field. Please refer to the formats details described: https://docs.aws.amazon.com/sagemaker/latest/dg/fm-in-formats.html """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -292,12 +299,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(FactorizationMachinesPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/ipinsights.py b/src/sagemaker/amazon/ipinsights.py index 8f4c2964e0..34a0173e2f 100644 --- a/src/sagemaker/amazon/ipinsights.py +++ b/src/sagemaker/amazon/ipinsights.py @@ -191,7 +191,13 @@ class IPInsightsPredictor(Predictor): second column should contain the IPv4 address in dot notation. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=CSVSerializer(), + deserializer=JSONDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -200,12 +206,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to text/csv. + deserializer (callable): Optional. Default parses JSON responses + using ``json.load(...)``. """ super(IPInsightsPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=CSVSerializer(), - deserializer=JSONDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/kmeans.py b/src/sagemaker/amazon/kmeans.py index 2b405b6f62..59d6fe2116 100644 --- a/src/sagemaker/amazon/kmeans.py +++ b/src/sagemaker/amazon/kmeans.py @@ -210,12 +210,19 @@ class KMeansPredictor(Predictor): to fit the model this Predictor performs inference on. ``predict()`` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The nearest cluster is stored in the ``closest_cluster`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -224,12 +231,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(KMeansPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/knn.py b/src/sagemaker/amazon/knn.py index 1f1822c8f8..f5af87e377 100644 --- a/src/sagemaker/amazon/knn.py +++ b/src/sagemaker/amazon/knn.py @@ -199,12 +199,19 @@ class KNNPredictor(Predictor): to fit the model this Predictor performs inference on. :func:`predict` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The prediction is stored in the ``"predicted_label"`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -213,12 +220,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(KNNPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/lda.py b/src/sagemaker/amazon/lda.py index 4740f53085..8c459b3aac 100644 --- a/src/sagemaker/amazon/lda.py +++ b/src/sagemaker/amazon/lda.py @@ -183,12 +183,19 @@ class LDAPredictor(Predictor): to fit the model this Predictor performs inference on. :meth:`predict()` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -197,12 +204,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(LDAPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/linear_learner.py b/src/sagemaker/amazon/linear_learner.py index e1eec87a03..a7c798c188 100644 --- a/src/sagemaker/amazon/linear_learner.py +++ b/src/sagemaker/amazon/linear_learner.py @@ -444,12 +444,19 @@ class LinearLearnerPredictor(Predictor): to fit the model this Predictor performs inference on. :func:`predict` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The prediction is stored in the ``"predicted_label"`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -458,12 +465,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(LinearLearnerPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/ntm.py b/src/sagemaker/amazon/ntm.py index 9dcec3bfe5..d24a1aee2f 100644 --- a/src/sagemaker/amazon/ntm.py +++ b/src/sagemaker/amazon/ntm.py @@ -212,12 +212,19 @@ class NTMPredictor(Predictor): to fit the model this Predictor performs inference on. :meth:`predict()` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -226,12 +233,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(NTMPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/pca.py b/src/sagemaker/amazon/pca.py index 2d38dda3f7..7c11e05494 100644 --- a/src/sagemaker/amazon/pca.py +++ b/src/sagemaker/amazon/pca.py @@ -193,12 +193,19 @@ class PCAPredictor(Predictor): to fit the model this Predictor performs inference on. :meth:`predict()` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input ``ndarray``. The lower dimension vector result is stored in the ``projection`` key of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -207,12 +214,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(PCAPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/amazon/randomcutforest.py b/src/sagemaker/amazon/randomcutforest.py index ea7d824355..ce5b79178f 100644 --- a/src/sagemaker/amazon/randomcutforest.py +++ b/src/sagemaker/amazon/randomcutforest.py @@ -171,12 +171,19 @@ class RandomCutForestPredictor(Predictor): to fit the model this Predictor performs inference on. :meth:`predict()` returns a list of - :class:`~sagemaker.amazon.record_pb2.Record` objects, one for each row in + :class:`~sagemaker.amazon.record_pb2.Record` objects (assuming the default + recordio-protobuf ``deserializer`` is used), one for each row in the input. Each row's score is stored in the key ``score`` of the ``Record.label`` field. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=RecordSerializer(), + deserializer=RecordDeserializer(), + ): """ Args: endpoint_name (str): Name of the Amazon SageMaker endpoint to which @@ -185,12 +192,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): object, used for SageMaker interactions (default: None). If not specified, one is created using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to x-recordio-protobuf format. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses responses from x-recordio-protobuf format. """ super(RandomCutForestPredictor, self).__init__( endpoint_name, sagemaker_session, - serializer=RecordSerializer(), - deserializer=RecordDeserializer(), + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/chainer/model.py b/src/sagemaker/chainer/model.py index 48ca779257..d7f6debfbe 100644 --- a/src/sagemaker/chainer/model.py +++ b/src/sagemaker/chainer/model.py @@ -38,7 +38,13 @@ class ChainerPredictor(Predictor): multidimensional tensors for Chainer inference. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=NumpySerializer(), + deserializer=NumpyDeserializer(), + ): """Initialize an ``ChainerPredictor``. Args: @@ -48,9 +54,17 @@ def __init__(self, endpoint_name, sagemaker_session=None): manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to .npy format. Handles lists and numpy + arrays. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from .npy format to numpy array. """ super(ChainerPredictor, self).__init__( - endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer() + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/mxnet/model.py b/src/sagemaker/mxnet/model.py index d1317c59a1..624e5cfbe1 100644 --- a/src/sagemaker/mxnet/model.py +++ b/src/sagemaker/mxnet/model.py @@ -40,7 +40,13 @@ class MXNetPredictor(Predictor): multidimensional tensors for MXNet inference. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=JSONSerializer(), + deserializer=JSONDeserializer(), + ): """Initialize an ``MXNetPredictor``. Args: @@ -50,9 +56,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (callable): Optional. Default serializes input data to + json. Handles dicts, lists, and numpy arrays. + deserializer (callable): Optional. Default parses the response using + ``json.load(...)``. """ super(MXNetPredictor, self).__init__( - endpoint_name, sagemaker_session, JSONSerializer(), JSONDeserializer() + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/pytorch/model.py b/src/sagemaker/pytorch/model.py index b7ba240597..0b752ae9a3 100644 --- a/src/sagemaker/pytorch/model.py +++ b/src/sagemaker/pytorch/model.py @@ -39,7 +39,13 @@ class PyTorchPredictor(Predictor): multidimensional tensors for PyTorch inference. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=NumpySerializer(), + deserializer=NumpyDeserializer(), + ): """Initialize an ``PyTorchPredictor``. Args: @@ -49,9 +55,17 @@ def __init__(self, endpoint_name, sagemaker_session=None): manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to .npy format. Handles lists and numpy + arrays. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from .npy format to numpy array. """ super(PyTorchPredictor, self).__init__( - endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer() + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/sklearn/model.py b/src/sagemaker/sklearn/model.py index 2a505fe924..06d28ae14c 100644 --- a/src/sagemaker/sklearn/model.py +++ b/src/sagemaker/sklearn/model.py @@ -34,7 +34,13 @@ class SKLearnPredictor(Predictor): multidimensional tensors for Scikit-learn inference. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=NumpySerializer(), + deserializer=NumpyDeserializer(), + ): """Initialize an ``SKLearnPredictor``. Args: @@ -44,9 +50,17 @@ def __init__(self, endpoint_name, sagemaker_session=None): manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to .npy format. Handles lists and numpy + arrays. + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from .npy format to numpy array. """ super(SKLearnPredictor, self).__init__( - endpoint_name, sagemaker_session, NumpySerializer(), NumpyDeserializer() + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, ) diff --git a/src/sagemaker/sparkml/model.py b/src/sagemaker/sparkml/model.py index 9661d091fd..482f391713 100644 --- a/src/sagemaker/sparkml/model.py +++ b/src/sagemaker/sparkml/model.py @@ -31,7 +31,13 @@ class SparkMLPredictor(Predictor): list. """ - def __init__(self, endpoint_name, sagemaker_session=None, **kwargs): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=CSVSerializer(), + **kwargs, + ): """Initializes a SparkMLPredictor which should be used with SparkMLModel to perform predictions against SparkML models serialized via MLeap. The response is returned in text/csv format which is the default response @@ -43,12 +49,14 @@ def __init__(self, endpoint_name, sagemaker_session=None, **kwargs): manages interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to text/csv. """ sagemaker_session = sagemaker_session or Session() super(SparkMLPredictor, self).__init__( endpoint_name=endpoint_name, sagemaker_session=sagemaker_session, - serializer=CSVSerializer(), + serializer=serializer, **kwargs, ) diff --git a/src/sagemaker/xgboost/model.py b/src/sagemaker/xgboost/model.py index 400f3e6bcc..536bf75020 100644 --- a/src/sagemaker/xgboost/model.py +++ b/src/sagemaker/xgboost/model.py @@ -35,7 +35,13 @@ class XGBoostPredictor(Predictor): for XGBoost inference. """ - def __init__(self, endpoint_name, sagemaker_session=None): + def __init__( + self, + endpoint_name, + sagemaker_session=None, + serializer=LibSVMSerializer(), + deserializer=CSVDeserializer(), + ): """Initialize an ``XGBoostPredictor``. Args: @@ -44,9 +50,16 @@ def __init__(self, endpoint_name, sagemaker_session=None): interactions with Amazon SageMaker APIs and any other AWS services needed. If not specified, the estimator creates one using the default AWS configuration chain. + serializer (sagemaker.serializers.BaseSerializer): Optional. Default + serializes input data to LibSVM format + deserializer (sagemaker.deserializers.BaseDeserializer): Optional. + Default parses the response from text/csv to a Python list. """ super(XGBoostPredictor, self).__init__( - endpoint_name, sagemaker_session, LibSVMSerializer(), CSVDeserializer() + endpoint_name, + sagemaker_session, + serializer=serializer, + deserializer=deserializer, ) diff --git a/tests/unit/test_chainer.py b/tests/unit/test_chainer.py index 5679bb5284..d2aaaec8d9 100644 --- a/tests/unit/test_chainer.py +++ b/tests/unit/test_chainer.py @@ -383,6 +383,29 @@ def test_model(sagemaker_session, chainer_version, chainer_py_version): assert isinstance(predictor, ChainerPredictor) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_model_custom_serialization(sagemaker_session, chainer_version, chainer_py_version): + model = ChainerModel( + "s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + framework_version=chainer_version, + py_version=chainer_py_version, + ) + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + CPU, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + assert isinstance(predictor, ChainerPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer + + @patch("sagemaker.fw_utils.tar_and_upload_dir", MagicMock()) def test_model_prepare_container_def_accelerator_error( sagemaker_session, chainer_version, chainer_py_version diff --git a/tests/unit/test_fm.py b/tests/unit/test_fm.py index b2f96b1c41..26b3bba0cd 100644 --- a/tests/unit/test_fm.py +++ b/tests/unit/test_fm.py @@ -330,3 +330,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, FactorizationMachinesPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + fm = FactorizationMachines(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + fm.fit(data, MINI_BATCH_SIZE) + model = fm.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, FactorizationMachinesPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_ipinsights.py b/tests/unit/test_ipinsights.py index 0673d375b4..a1a3959baf 100644 --- a/tests/unit/test_ipinsights.py +++ b/tests/unit/test_ipinsights.py @@ -305,3 +305,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, IPInsightsPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + ipinsights = IPInsights(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + ipinsights.fit(data, MINI_BATCH_SIZE) + model = ipinsights.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, IPInsightsPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_kmeans.py b/tests/unit/test_kmeans.py index ec5a8476a7..89c0580a95 100644 --- a/tests/unit/test_kmeans.py +++ b/tests/unit/test_kmeans.py @@ -272,3 +272,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, KMeansPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + kmeans = KMeans(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + kmeans.fit(data, MINI_BATCH_SIZE) + model = kmeans.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, KMeansPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_knn.py b/tests/unit/test_knn.py index 999de92f5c..e2546ba3b1 100644 --- a/tests/unit/test_knn.py +++ b/tests/unit/test_knn.py @@ -296,3 +296,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, KNNPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + knn = KNN(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + knn.fit(data, MINI_BATCH_SIZE) + model = knn.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, KNNPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_lda.py b/tests/unit/test_lda.py index 5e8579c595..0b4f3a8a65 100644 --- a/tests/unit/test_lda.py +++ b/tests/unit/test_lda.py @@ -232,3 +232,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, LDAPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + lda = LDA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + lda.fit(data, MINI_BATCH_SZIE) + model = lda.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, LDAPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_linear_learner.py b/tests/unit/test_linear_learner.py index dc6e082827..6f8d2caf69 100644 --- a/tests/unit/test_linear_learner.py +++ b/tests/unit/test_linear_learner.py @@ -433,3 +433,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, LinearLearnerPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + lr = LinearLearner(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + lr.fit(data) + model = lr.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, LinearLearnerPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_mxnet.py b/tests/unit/test_mxnet.py index 091ce3497a..31fc38f265 100644 --- a/tests/unit/test_mxnet.py +++ b/tests/unit/test_mxnet.py @@ -581,6 +581,31 @@ def test_model_register_all_args( ) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +def test_model_custom_serialization( + sagemaker_session, mxnet_inference_version, mxnet_inference_py_version, skip_if_mms_version +): + model = MXNetModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=mxnet_inference_version, + py_version=mxnet_inference_py_version, + sagemaker_session=sagemaker_session, + ) + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + CPU, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + assert isinstance(predictor, MXNetPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer + + @patch("sagemaker.utils.repack_model") def test_model_mms_version( repack_model, diff --git a/tests/unit/test_ntm.py b/tests/unit/test_ntm.py index 3805d640a9..6aa07c469c 100644 --- a/tests/unit/test_ntm.py +++ b/tests/unit/test_ntm.py @@ -301,3 +301,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, NTMPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + ntm = NTM(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + ntm.fit(data, MINI_BATCH_SIZE) + model = ntm.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, NTMPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_pca.py b/tests/unit/test_pca.py index 4861b57c5a..79934b3a61 100644 --- a/tests/unit/test_pca.py +++ b/tests/unit/test_pca.py @@ -252,3 +252,27 @@ def test_predictor_type(sagemaker_session): predictor = model.deploy(1, INSTANCE_TYPE) assert isinstance(predictor, PCAPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + pca = PCA(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS) + data = RecordSet( + "s3://{}/{}".format(BUCKET_NAME, PREFIX), + num_records=1, + feature_dim=FEATURE_DIM, + channel="train", + ) + pca.fit(data, MINI_BATCH_SIZE) + model = pca.create_model() + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + INSTANCE_TYPE, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + + assert isinstance(predictor, PCAPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer diff --git a/tests/unit/test_pytorch.py b/tests/unit/test_pytorch.py index 1d33fb25dc..5fc4ddfca7 100644 --- a/tests/unit/test_pytorch.py +++ b/tests/unit/test_pytorch.py @@ -408,6 +408,34 @@ def test_model_image_accelerator(sagemaker_session): assert "Unsupported Python version: py2." in str(error) +@patch("sagemaker.utils.create_tar_file", MagicMock()) +@patch("sagemaker.utils.repack_model", MagicMock()) +def test_model_custom_serialization( + sagemaker_session, + pytorch_inference_version, + pytorch_inference_py_version, +): + model = PyTorchModel( + MODEL_DATA, + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=pytorch_inference_version, + py_version=pytorch_inference_py_version, + sagemaker_session=sagemaker_session, + ) + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + GPU, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + assert isinstance(predictor, PyTorchPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer + + def test_model_prepare_container_def_no_instance_type_or_image(): model = PyTorchModel( MODEL_DATA, diff --git a/tests/unit/test_sklearn.py b/tests/unit/test_sklearn.py index dc37d599b2..aa3fd491b8 100644 --- a/tests/unit/test_sklearn.py +++ b/tests/unit/test_sklearn.py @@ -20,8 +20,8 @@ from mock import Mock from mock import patch -from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor from sagemaker.fw_utils import UploadedCode +from sagemaker.sklearn import SKLearn, SKLearnModel, SKLearnPredictor DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") SCRIPT_PATH = os.path.join(DATA_DIR, "dummy_script.py") @@ -408,6 +408,27 @@ def test_model(sagemaker_session, sklearn_version): assert isinstance(predictor, SKLearnPredictor) +def test_model_custom_serialization(sagemaker_session, sklearn_version): + model = SKLearnModel( + "s3://some/data.tar.gz", + role=ROLE, + entry_point=SCRIPT_PATH, + framework_version=sklearn_version, + sagemaker_session=sagemaker_session, + ) + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + CPU, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + assert isinstance(predictor, SKLearnPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer + + def test_attach(sagemaker_session, sklearn_version): training_image = "1.dkr.ecr.us-west-2.amazonaws.com/sagemaker-scikit-learn:{}-cpu-{}".format( sklearn_version, PYTHON_VERSION diff --git a/tests/unit/test_sparkml_serving.py b/tests/unit/test_sparkml_serving.py index 4888639537..1ea9b7e3da 100644 --- a/tests/unit/test_sparkml_serving.py +++ b/tests/unit/test_sparkml_serving.py @@ -57,3 +57,12 @@ def test_predictor_type(sagemaker_session): predictor = sparkml.deploy(1, TRAIN_INSTANCE_TYPE) assert isinstance(predictor, SparkMLPredictor) + + +def test_predictor_custom_serialization(sagemaker_session): + sparkml = SparkMLModel(sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE) + custom_serializer = Mock() + predictor = sparkml.deploy(1, TRAIN_INSTANCE_TYPE, serializer=custom_serializer) + + assert isinstance(predictor, SparkMLPredictor) + assert predictor.serializer is custom_serializer diff --git a/tests/unit/test_xgboost.py b/tests/unit/test_xgboost.py index dd2b80c76e..0f155df987 100644 --- a/tests/unit/test_xgboost.py +++ b/tests/unit/test_xgboost.py @@ -429,6 +429,27 @@ def test_model(sagemaker_session, xgboost_framework_version): assert isinstance(predictor, XGBoostPredictor) +def test_model_custom_serialization(sagemaker_session, xgboost_framework_version): + model = XGBoostModel( + "s3://some/data.tar.gz", + role=ROLE, + framework_version=xgboost_framework_version, + entry_point=SCRIPT_PATH, + sagemaker_session=sagemaker_session, + ) + custom_serializer = Mock() + custom_deserializer = Mock() + predictor = model.deploy( + 1, + CPU, + serializer=custom_serializer, + deserializer=custom_deserializer, + ) + assert isinstance(predictor, XGBoostPredictor) + assert predictor.serializer is custom_serializer + assert predictor.deserializer is custom_deserializer + + def test_training_image_uri(sagemaker_session, xgboost_framework_version): xgboost = XGBoost( entry_point=SCRIPT_PATH,