diff --git a/doc/frameworks/xgboost/using_xgboost.rst b/doc/frameworks/xgboost/using_xgboost.rst index 130fab1b9a..26c333a0bf 100644 --- a/doc/frameworks/xgboost/using_xgboost.rst +++ b/doc/frameworks/xgboost/using_xgboost.rst @@ -192,8 +192,7 @@ inference against your model. .. code:: - serializer = StringSerializer() - serializer.CONTENT_TYPE = "text/libsvm" + serializer = StringSerializer(content_type="text/libsvm") predictor = estimator.deploy( initial_instance_count=1, diff --git a/src/sagemaker/amazon/common.py b/src/sagemaker/amazon/common.py index b73d59dea9..17838ee13b 100644 --- a/src/sagemaker/amazon/common.py +++ b/src/sagemaker/amazon/common.py @@ -22,15 +22,22 @@ from sagemaker.amazon.record_pb2 import Record from sagemaker.deprecations import deprecated_class -from sagemaker.deserializers import BaseDeserializer -from sagemaker.serializers import BaseSerializer +from sagemaker.deserializers import SimpleBaseDeserializer +from sagemaker.serializers import SimpleBaseSerializer from sagemaker.utils import DeferredError -class RecordSerializer(BaseSerializer): +class RecordSerializer(SimpleBaseSerializer): """Serialize a NumPy array for an inference request.""" - CONTENT_TYPE = "application/x-recordio-protobuf" + def __init__(self, content_type="application/x-recordio-protobuf"): + """Initialize a ``RecordSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-recordio-protobuf"). + """ + super(RecordSerializer, self).__init__(content_type=content_type) def serialize(self, data): """Serialize a NumPy array into a buffer containing RecordIO records. @@ -56,10 +63,18 @@ def serialize(self, data): return buffer -class RecordDeserializer(BaseDeserializer): +class RecordDeserializer(SimpleBaseDeserializer): """Deserialize RecordIO Protobuf data from an inference endpoint.""" - ACCEPT = ("application/x-recordio-protobuf",) + def __init__(self, accept="application/x-recordio-protobuf"): + """Initialize a ``RecordDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: + "application/x-recordio-protobuf"). + """ + super(RecordDeserializer, self).__init__(accept=accept) def deserialize(self, data, content_type): """Deserialize RecordIO Protobuf data from an inference endpoint. diff --git a/src/sagemaker/deserializers.py b/src/sagemaker/deserializers.py index bffdfa8bb7..11b0557d9e 100644 --- a/src/sagemaker/deserializers.py +++ b/src/sagemaker/deserializers.py @@ -21,6 +21,7 @@ import json import numpy as np +from six import with_metaclass from sagemaker.utils import DeferredError @@ -55,17 +56,44 @@ def ACCEPT(self): """The content types that are expected from the inference endpoint.""" -class StringDeserializer(BaseDeserializer): - """Deserialize data from an inference endpoint into a decoded string.""" +class SimpleBaseDeserializer(with_metaclass(abc.ABCMeta, BaseDeserializer)): + """Abstract base class for creation of new deserializers. + + This class extends the API of :class:~`sagemaker.deserializers.BaseDeserializer` with more + user-friendly options for setting the ACCEPT content type header, in situations where it can be + provided at init and freely updated. + """ + + def __init__(self, accept="*/*"): + """Initialize a ``SimpleBaseDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: "*/*"). + """ + super(SimpleBaseDeserializer, self).__init__() + self.accept = accept + + @property + def ACCEPT(self): + """The tuple of possible content types that are expected from the inference endpoint.""" + if isinstance(self.accept, str): + return (self.accept,) + return self.accept - ACCEPT = ("application/json",) - def __init__(self, encoding="UTF-8"): - """Initialize the string encoding. +class StringDeserializer(SimpleBaseDeserializer): + """Deserialize data from an inference endpoint into a decoded string.""" + + def __init__(self, encoding="UTF-8", accept="application/json"): + """Initialize a ``StringDeserializer`` instance. Args: encoding (str): The string encoding to use (default: UTF-8). + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: "application/json"). """ + super(StringDeserializer, self).__init__(accept=accept) self.encoding = encoding def deserialize(self, stream, content_type): @@ -84,11 +112,9 @@ def deserialize(self, stream, content_type): stream.close() -class BytesDeserializer(BaseDeserializer): +class BytesDeserializer(SimpleBaseDeserializer): """Deserialize a stream of bytes into a bytes object.""" - ACCEPT = ("*/*",) - def deserialize(self, stream, content_type): """Read a stream of bytes returned from an inference endpoint. @@ -105,17 +131,23 @@ def deserialize(self, stream, content_type): stream.close() -class CSVDeserializer(BaseDeserializer): - """Deserialize a stream of bytes into a list of lists.""" +class CSVDeserializer(SimpleBaseDeserializer): + """Deserialize a stream of bytes into a list of lists. - ACCEPT = ("text/csv",) + Consider using :class:~`sagemaker.deserializers.NumpyDeserializer` or + :class:~`sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv + responses directly into other data types. + """ - def __init__(self, encoding="utf-8"): - """Initialize the string encoding. + def __init__(self, encoding="utf-8", accept="text/csv"): + """Initialize a ``CSVDeserializer`` instance. Args: encoding (str): The string encoding to use (default: "utf-8"). + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: "text/csv"). """ + super(CSVDeserializer, self).__init__(accept=accept) self.encoding = encoding def deserialize(self, stream, content_type): @@ -136,15 +168,13 @@ def deserialize(self, stream, content_type): stream.close() -class StreamDeserializer(BaseDeserializer): - """Returns the data and content-type received from an inference endpoint. +class StreamDeserializer(SimpleBaseDeserializer): + """Directly return the data and content-type received from an inference endpoint. It is the user's responsibility to close the data stream once they're done reading it. """ - ACCEPT = ("*/*",) - def deserialize(self, stream, content_type): """Returns a stream of the response body and the MIME type of the data. @@ -158,20 +188,20 @@ def deserialize(self, stream, content_type): return stream, content_type -class NumpyDeserializer(BaseDeserializer): - """Deserialize a stream of data in the .npy format.""" +class NumpyDeserializer(SimpleBaseDeserializer): + """Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array.""" def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True): - """Initialize the dtype and allow_pickle arguments. + """Initialize a ``NumpyDeserializer`` instance. Args: dtype (str): The dtype of the data (default: None). - accept (str): The MIME type that is expected from the inference - endpoint (default: "application/x-npy"). + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: "application/x-npy"). allow_pickle (bool): Allow loading pickled object arrays (default: True). """ + super(NumpyDeserializer, self).__init__(accept=accept) self.dtype = dtype - self.accept = accept self.allow_pickle = allow_pickle def deserialize(self, stream, content_type): @@ -198,21 +228,18 @@ def deserialize(self, stream, content_type): raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type)) - @property - def ACCEPT(self): - """The content types that are expected from the inference endpoint. - - To maintain backwards compatability with legacy images, the - NumpyDeserializer supports sending only one content type in the Accept - header. - """ - return (self.accept,) - -class JSONDeserializer(BaseDeserializer): +class JSONDeserializer(SimpleBaseDeserializer): """Deserialize JSON data from an inference endpoint into a Python object.""" - ACCEPT = ("application/json",) + def __init__(self, accept="application/json"): + """Initialize a ``JSONDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: "application/json"). + """ + super(JSONDeserializer, self).__init__(accept=accept) def deserialize(self, stream, content_type): """Deserialize JSON data from an inference endpoint into a Python object. @@ -230,10 +257,17 @@ def deserialize(self, stream, content_type): stream.close() -class PandasDeserializer(BaseDeserializer): +class PandasDeserializer(SimpleBaseDeserializer): """Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe.""" - ACCEPT = ("text/csv", "application/json") + def __init__(self, accept=("text/csv", "application/json")): + """Initialize a ``PandasDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: ("text/csv","application/json")). + """ + super(PandasDeserializer, self).__init__(accept=accept) def deserialize(self, stream, content_type): """Deserialize CSV or JSON data from an inference endpoint into a pandas @@ -258,10 +292,17 @@ def deserialize(self, stream, content_type): raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type)) -class JSONLinesDeserializer(BaseDeserializer): +class JSONLinesDeserializer(SimpleBaseDeserializer): """Deserialize JSON lines data from an inference endpoint.""" - ACCEPT = ("application/jsonlines",) + def __init__(self, accept="application/jsonlines"): + """Initialize a ``JSONLinesDeserializer`` instance. + + Args: + accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that + is expected from the inference endpoint (default: ("text/csv","application/json")). + """ + super(JSONLinesDeserializer, self).__init__(accept=accept) def deserialize(self, stream, content_type): """Deserialize JSON lines data from an inference endpoint. diff --git a/src/sagemaker/serializers.py b/src/sagemaker/serializers.py index 9fbf33b6e7..a5dae25c6a 100644 --- a/src/sagemaker/serializers.py +++ b/src/sagemaker/serializers.py @@ -20,6 +20,7 @@ import json import numpy as np +from six import with_metaclass from sagemaker.utils import DeferredError @@ -53,10 +54,46 @@ def CONTENT_TYPE(self): """The MIME type of the data sent to the inference endpoint.""" -class CSVSerializer(BaseSerializer): +class SimpleBaseSerializer(with_metaclass(abc.ABCMeta, BaseSerializer)): + """Abstract base class for creation of new serializers. + + This class extends the API of :class:~`sagemaker.serializers.BaseSerializer` with more + user-friendly options for setting the Content-Type header, in situations where it can be + provided at init and freely updated. + """ + + def __init__(self, content_type="application/json"): + """Initialize a ``SimpleBaseSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/json"). + """ + super(SimpleBaseSerializer, self).__init__() + if not isinstance(content_type, str): + raise ValueError( + "content_type must be a string specifying the MIME type of the data sent in " + "requests: e.g. 'application/json', 'text/csv', etc. Got %s" % content_type + ) + self.content_type = content_type + + @property + def CONTENT_TYPE(self): + """The data MIME type set in the Content-Type header on prediction endpoint requests.""" + return self.content_type + + +class CSVSerializer(SimpleBaseSerializer): """Serialize data of various formats to a CSV-formatted string.""" - CONTENT_TYPE = "text/csv" + def __init__(self, content_type="text/csv"): + """Initialize a ``CSVSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "text/csv"). + """ + super(CSVSerializer, self).__init__(content_type=content_type) def serialize(self, data): """Serialize data of various formats to a CSV-formatted string. @@ -109,17 +146,18 @@ def _is_sequence_like(self, data): return hasattr(data, "__iter__") and hasattr(data, "__getitem__") -class NumpySerializer(BaseSerializer): +class NumpySerializer(SimpleBaseSerializer): """Serialize data to a buffer using the .npy format.""" - CONTENT_TYPE = "application/x-npy" - - def __init__(self, dtype=None): - """Initialize the dtype. + def __init__(self, dtype=None, content_type="application/x-npy"): + """Initialize a ``NumpySerializer`` instance. Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-npy"). dtype (str): The dtype of the data. """ + super(NumpySerializer, self).__init__(content_type=content_type) self.dtype = dtype def serialize(self, data): @@ -162,11 +200,9 @@ def _serialize_array(self, array): return buffer.getvalue() -class JSONSerializer(BaseSerializer): +class JSONSerializer(SimpleBaseSerializer): """Serialize data to a JSON formatted string.""" - CONTENT_TYPE = "application/json" - def serialize(self, data): """Serialize data of various formats to a JSON formatted string. @@ -193,17 +229,21 @@ def serialize(self, data): return json.dumps(data) -class IdentitySerializer(BaseSerializer): - """Serialize data by returning data without modification.""" +class IdentitySerializer(SimpleBaseSerializer): + """Serialize data by returning data without modification. + + This serializer may be useful if, for example, you're sending raw bytes such as from an image + file's .read() method. + """ def __init__(self, content_type="application/octet-stream"): - """Initialize the ``content_type`` attribute. + """Initialize an ``IdentitySerializer`` instance. Args: - content_type (str): The MIME type of the serialized data (default: - "application/octet-stream"). + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/octet-stream"). """ - self.content_type = content_type + super(IdentitySerializer, self).__init__(content_type=content_type) def serialize(self, data): """Return data without modification. @@ -216,16 +256,18 @@ def serialize(self, data): """ return data - @property - def CONTENT_TYPE(self): - """The MIME type of the data sent to the inference endpoint.""" - return self.content_type - -class JSONLinesSerializer(BaseSerializer): +class JSONLinesSerializer(SimpleBaseSerializer): """Serialize data to a JSON Lines formatted string.""" - CONTENT_TYPE = "application/jsonlines" + def __init__(self, content_type="application/jsonlines"): + """Initialize a ``JSONLinesSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/jsonlines"). + """ + super(JSONLinesSerializer, self).__init__(content_type=content_type) def serialize(self, data): """Serialize data of various formats to a JSON Lines formatted string. @@ -250,10 +292,17 @@ def serialize(self, data): raise ValueError("Object of type %s is not JSON Lines serializable." % type(data)) -class SparseMatrixSerializer(BaseSerializer): +class SparseMatrixSerializer(SimpleBaseSerializer): """Serialize a sparse matrix to a buffer using the .npz format.""" - CONTENT_TYPE = "application/x-npz" + def __init__(self, content_type="application/x-npz"): + """Initialize a ``SparseMatrixSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "application/x-npz"). + """ + super(SparseMatrixSerializer, self).__init__(content_type=content_type) def serialize(self, data): """Serialize a sparse matrix to a buffer using the .npz format. @@ -272,7 +321,7 @@ def serialize(self, data): return buffer.getvalue() -class LibSVMSerializer(BaseSerializer): +class LibSVMSerializer(SimpleBaseSerializer): """Serialize data of various formats to a LibSVM-formatted string. The data must already be in LIBSVM file format: @@ -282,7 +331,14 @@ class LibSVMSerializer(BaseSerializer): features. """ - CONTENT_TYPE = "text/libsvm" + def __init__(self, content_type="text/libsvm"): + """Initialize a ``LibSVMSerializer`` instance. + + Args: + content_type (str): The MIME type to signal to the inference endpoint when sending + request data (default: "text/libsvm"). + """ + super(LibSVMSerializer, self).__init__(content_type=content_type) def serialize(self, data): """Serialize data of various formats to a LibSVM-formatted string. diff --git a/tests/integ/test_byo_estimator.py b/tests/integ/test_byo_estimator.py index 475c6a2cdc..0e51f4d0d1 100644 --- a/tests/integ/test_byo_estimator.py +++ b/tests/integ/test_byo_estimator.py @@ -20,7 +20,7 @@ import sagemaker from sagemaker import image_uris from sagemaker.estimator import Estimator -from sagemaker.serializers import BaseSerializer +from sagemaker.serializers import SimpleBaseSerializer from sagemaker.utils import unique_name_from_base from tests.integ import DATA_DIR, TRAINING_DEFAULT_TIMEOUT_MINUTES, datasets from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -36,9 +36,8 @@ def training_set(): return datasets.one_p_mnist() -class _FactorizationMachineSerializer(BaseSerializer): - - CONTENT_TYPE = "application/json" +class _FactorizationMachineSerializer(SimpleBaseSerializer): + # SimpleBaseSerializer already uses "application/json" CONTENT_TYPE by default def serialize(self, data): js = {"instances": []} diff --git a/tests/integ/test_neo_mxnet.py b/tests/integ/test_neo_mxnet.py index f3a7d0db84..e908ae45e3 100644 --- a/tests/integ/test_neo_mxnet.py +++ b/tests/integ/test_neo_mxnet.py @@ -75,8 +75,7 @@ def test_attach_deploy( output_path=estimator.output_path, ) - serializer = JSONSerializer() - serializer.CONTENT_TYPE = "application/vnd+python.numpy+binary" + serializer = JSONSerializer(content_type="application/vnd+python.numpy+binary") predictor = estimator.deploy( 1, @@ -118,8 +117,7 @@ def test_deploy_model( sagemaker_session=sagemaker_session, ) - serializer = JSONSerializer() - serializer.CONTENT_TYPE = "application/vnd+python.numpy+binary" + serializer = JSONSerializer(content_type="application/vnd+python.numpy+binary") model.compile( target_instance_family=cpu_instance_family, @@ -171,8 +169,7 @@ def test_inferentia_deploy_model( output_path="/".join(model_data.split("/")[:-1]), ) - serializer = JSONSerializer() - serializer.CONTENT_TYPE = "application/vnd+python.numpy+binary" + serializer = JSONSerializer(content_type="application/vnd+python.numpy+binary") predictor = model.deploy( 1, inf_instance_type, serializer=serializer, endpoint_name=endpoint_name diff --git a/tests/integ/test_tuner.py b/tests/integ/test_tuner.py index 02dd131e92..cc6a893567 100644 --- a/tests/integ/test_tuner.py +++ b/tests/integ/test_tuner.py @@ -28,7 +28,7 @@ from sagemaker.estimator import Estimator from sagemaker.mxnet.estimator import MXNet from sagemaker.pytorch import PyTorch -from sagemaker.serializers import BaseSerializer +from sagemaker.serializers import SimpleBaseSerializer from sagemaker.tensorflow import TensorFlow from sagemaker.tuner import ( IntegerParameter, @@ -884,9 +884,8 @@ def test_tuning_byo_estimator(sagemaker_session, cpu_instance_type): # Serializer for the Factorization Machines predictor (for BYO example) -class _FactorizationMachineSerializer(BaseSerializer): - - CONTENT_TYPE = "application/json" +class _FactorizationMachineSerializer(SimpleBaseSerializer): + # SimpleBaseSerializer already uses "application/json" CONTENT_TYPE by default def serialize(self, data): js = {"instances": []} diff --git a/tests/integ/test_tuner_multi_algo.py b/tests/integ/test_tuner_multi_algo.py index e6aa0c8e7f..02f2404b15 100644 --- a/tests/integ/test_tuner_multi_algo.py +++ b/tests/integ/test_tuner_multi_algo.py @@ -21,7 +21,7 @@ from sagemaker.analytics import HyperparameterTuningJobAnalytics from sagemaker.deserializers import JSONDeserializer from sagemaker.estimator import Estimator -from sagemaker.serializers import BaseSerializer +from sagemaker.serializers import SimpleBaseSerializer from sagemaker.tuner import ContinuousParameter, IntegerParameter, HyperparameterTuner from tests.integ import datasets, DATA_DIR, TUNING_DEFAULT_TIMEOUT_MINUTES from tests.integ.timeout import timeout, timeout_and_delete_endpoint_by_name @@ -219,9 +219,8 @@ def _create_training_inputs(sagemaker_session): return {"train": s3_train_data, "test": s3_train_data} -class PredictionDataSerializer(BaseSerializer): - - CONTENT_TYPE = "application/json" +class PredictionDataSerializer(SimpleBaseSerializer): + # SimpleBaseSerializer already uses "application/json" CONTENT_TYPE by default def serialize(self, data): js = {"instances": []}