Skip to content

feature: all de/serializers support content type #1993

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions doc/frameworks/xgboost/using_xgboost.rst
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,7 @@ inference against your model.

.. code::

serializer = StringSerializer()
serializer.CONTENT_TYPE = "text/libsvm"
serializer = StringSerializer(content_type="text/libsvm")

predictor = estimator.deploy(
initial_instance_count=1,
Expand Down
27 changes: 21 additions & 6 deletions src/sagemaker/amazon/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,22 @@

from sagemaker.amazon.record_pb2 import Record
from sagemaker.deprecations import deprecated_class
from sagemaker.deserializers import BaseDeserializer
from sagemaker.serializers import BaseSerializer
from sagemaker.deserializers import SimpleBaseDeserializer
from sagemaker.serializers import SimpleBaseSerializer
from sagemaker.utils import DeferredError


class RecordSerializer(BaseSerializer):
class RecordSerializer(SimpleBaseSerializer):
"""Serialize a NumPy array for an inference request."""

CONTENT_TYPE = "application/x-recordio-protobuf"
def __init__(self, content_type="application/x-recordio-protobuf"):
"""Initialize a ``RecordSerializer`` instance.

Args:
content_type (str): The MIME type to signal to the inference endpoint when sending
request data (default: "application/x-recordio-protobuf").
"""
super(RecordSerializer, self).__init__(content_type=content_type)

def serialize(self, data):
"""Serialize a NumPy array into a buffer containing RecordIO records.
Expand All @@ -56,10 +63,18 @@ def serialize(self, data):
return buffer


class RecordDeserializer(BaseDeserializer):
class RecordDeserializer(SimpleBaseDeserializer):
"""Deserialize RecordIO Protobuf data from an inference endpoint."""

ACCEPT = ("application/x-recordio-protobuf",)
def __init__(self, accept="application/x-recordio-protobuf"):
"""Initialize a ``RecordDeserializer`` instance.

Args:
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default:
"application/x-recordio-protobuf").
"""
super(RecordDeserializer, self).__init__(accept=accept)

def deserialize(self, data, content_type):
"""Deserialize RecordIO Protobuf data from an inference endpoint.
Expand Down
119 changes: 80 additions & 39 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import json

import numpy as np
from six import with_metaclass

from sagemaker.utils import DeferredError

Expand Down Expand Up @@ -55,17 +56,44 @@ def ACCEPT(self):
"""The content types that are expected from the inference endpoint."""


class StringDeserializer(BaseDeserializer):
"""Deserialize data from an inference endpoint into a decoded string."""
class SimpleBaseDeserializer(with_metaclass(abc.ABCMeta, BaseDeserializer)):
"""Abstract base class for creation of new deserializers.

This class extends the API of :class:`~sagemaker.deserializers.BaseDeserializer` with more
user-friendly options for setting the ACCEPT content type header, in situations where it can be
provided at init and freely updated.
"""

def __init__(self, accept="*/*"):
"""Initialize a ``SimpleBaseDeserializer`` instance.

Args:
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "*/*").
"""
super(SimpleBaseDeserializer, self).__init__()
self.accept = accept

@property
def ACCEPT(self):
"""The tuple of possible content types that are expected from the inference endpoint."""
if isinstance(self.accept, str):
return (self.accept,)
return self.accept

ACCEPT = ("application/json",)

def __init__(self, encoding="UTF-8"):
"""Initialize the string encoding.
class StringDeserializer(SimpleBaseDeserializer):
"""Deserialize data from an inference endpoint into a decoded string."""

def __init__(self, encoding="UTF-8", accept="application/json"):
"""Initialize a ``StringDeserializer`` instance.

Args:
encoding (str): The string encoding to use (default: UTF-8).
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/json").
"""
super(StringDeserializer, self).__init__(accept=accept)
self.encoding = encoding

def deserialize(self, stream, content_type):
Expand All @@ -84,11 +112,9 @@ def deserialize(self, stream, content_type):
stream.close()


class BytesDeserializer(BaseDeserializer):
class BytesDeserializer(SimpleBaseDeserializer):
"""Deserialize a stream of bytes into a bytes object."""

ACCEPT = ("*/*",)

def deserialize(self, stream, content_type):
"""Read a stream of bytes returned from an inference endpoint.

Expand All @@ -105,17 +131,23 @@ def deserialize(self, stream, content_type):
stream.close()


class CSVDeserializer(BaseDeserializer):
"""Deserialize a stream of bytes into a list of lists."""
class CSVDeserializer(SimpleBaseDeserializer):
"""Deserialize a stream of bytes into a list of lists.

ACCEPT = ("text/csv",)
Consider using :class:`~sagemaker.deserializers.NumpyDeserializer` or
:class:`~sagemaker.deserializers.PandasDeserializer` instead, if you'd like to convert text/csv
responses directly into other data types.
"""

def __init__(self, encoding="utf-8"):
"""Initialize the string encoding.
def __init__(self, encoding="utf-8", accept="text/csv"):
"""Initialize a ``CSVDeserializer`` instance.

Args:
encoding (str): The string encoding to use (default: "utf-8").
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "text/csv").
"""
super(CSVDeserializer, self).__init__(accept=accept)
self.encoding = encoding

def deserialize(self, stream, content_type):
Expand All @@ -136,15 +168,13 @@ def deserialize(self, stream, content_type):
stream.close()


class StreamDeserializer(BaseDeserializer):
"""Returns the data and content-type received from an inference endpoint.
class StreamDeserializer(SimpleBaseDeserializer):
"""Directly return the data and content-type received from an inference endpoint.

It is the user's responsibility to close the data stream once they're done
reading it.
"""

ACCEPT = ("*/*",)

def deserialize(self, stream, content_type):
"""Returns a stream of the response body and the MIME type of the data.

Expand All @@ -158,20 +188,20 @@ def deserialize(self, stream, content_type):
return stream, content_type


class NumpyDeserializer(BaseDeserializer):
"""Deserialize a stream of data in the .npy format."""
class NumpyDeserializer(SimpleBaseDeserializer):
"""Deserialize a stream of data in .npy or UTF-8 CSV/JSON format to a numpy array."""

def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
"""Initialize the dtype and allow_pickle arguments.
"""Initialize a ``NumpyDeserializer`` instance.

Args:
dtype (str): The dtype of the data (default: None).
accept (str): The MIME type that is expected from the inference
endpoint (default: "application/x-npy").
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/x-npy").
allow_pickle (bool): Allow loading pickled object arrays (default: True).
"""
super(NumpyDeserializer, self).__init__(accept=accept)
self.dtype = dtype
self.accept = accept
self.allow_pickle = allow_pickle

def deserialize(self, stream, content_type):
Expand All @@ -198,21 +228,18 @@ def deserialize(self, stream, content_type):

raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

@property
def ACCEPT(self):
"""The content types that are expected from the inference endpoint.

To maintain backwards compatibility with legacy images, the
NumpyDeserializer supports sending only one content type in the Accept
header.
"""
return (self.accept,)


class JSONDeserializer(BaseDeserializer):
class JSONDeserializer(SimpleBaseDeserializer):
"""Deserialize JSON data from an inference endpoint into a Python object."""

ACCEPT = ("application/json",)
def __init__(self, accept="application/json"):
"""Initialize a ``JSONDeserializer`` instance.

Args:
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/json").
"""
super(JSONDeserializer, self).__init__(accept=accept)

def deserialize(self, stream, content_type):
"""Deserialize JSON data from an inference endpoint into a Python object.
Expand All @@ -230,10 +257,17 @@ def deserialize(self, stream, content_type):
stream.close()


class PandasDeserializer(BaseDeserializer):
class PandasDeserializer(SimpleBaseDeserializer):
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""

ACCEPT = ("text/csv", "application/json")
def __init__(self, accept=("text/csv", "application/json")):
"""Initialize a ``PandasDeserializer`` instance.

Args:
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: ("text/csv","application/json")).
"""
super(PandasDeserializer, self).__init__(accept=accept)

def deserialize(self, stream, content_type):
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
Expand All @@ -258,10 +292,17 @@ def deserialize(self, stream, content_type):
raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))


class JSONLinesDeserializer(BaseDeserializer):
class JSONLinesDeserializer(SimpleBaseDeserializer):
"""Deserialize JSON lines data from an inference endpoint."""

ACCEPT = ("application/jsonlines",)
def __init__(self, accept="application/jsonlines"):
"""Initialize a ``JSONLinesDeserializer`` instance.

Args:
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
is expected from the inference endpoint (default: "application/jsonlines").
"""
super(JSONLinesDeserializer, self).__init__(accept=accept)

def deserialize(self, stream, content_type):
"""Deserialize JSON lines data from an inference endpoint.
Expand Down
Loading