Skip to content

change: Support multiple Accept types #1794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Aug 4, 2020
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/sagemaker/amazon/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def serialize(self, data):
class RecordDeserializer(BaseDeserializer):
"""Deserialize RecordIO Protobuf data from an inference endpoint."""

ACCEPT = "application/x-recordio-protobuf"
ACCEPT = ("application/x-recordio-protobuf",)

def deserialize(self, data, content_type):
"""Deserialize RecordIO Protobuf data from an inference endpoint.
Expand Down
33 changes: 22 additions & 11 deletions src/sagemaker/deserializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def deserialize(self, stream, content_type):
@property
@abc.abstractmethod
def ACCEPT(self):
"""The content type that is expected from the inference endpoint."""
"""The content types that are expected from the inference endpoint."""


class StringDeserializer(BaseDeserializer):
"""Deserialize data from an inference endpoint into a decoded string."""

ACCEPT = "application/json"
ACCEPT = ("application/json",)

def __init__(self, encoding="UTF-8"):
"""Initialize the string encoding.
Expand Down Expand Up @@ -87,7 +87,7 @@ def deserialize(self, stream, content_type):
class BytesDeserializer(BaseDeserializer):
"""Deserialize a stream of bytes into a bytes object."""

ACCEPT = "*/*"
ACCEPT = ("*/*",)

def deserialize(self, stream, content_type):
"""Read a stream of bytes returned from an inference endpoint.
Expand All @@ -108,7 +108,7 @@ def deserialize(self, stream, content_type):
class CSVDeserializer(BaseDeserializer):
"""Deserialize a stream of bytes into a list of lists."""

ACCEPT = "text/csv"
ACCEPT = ("text/csv",)

def __init__(self, encoding="utf-8"):
"""Initialize the string encoding.
Expand Down Expand Up @@ -143,7 +143,7 @@ class StreamDeserializer(BaseDeserializer):
reading it.
"""

ACCEPT = "*/*"
ACCEPT = ("*/*",)

def deserialize(self, stream, content_type):
"""Returns a stream of the response body and the MIME type of the data.
Expand All @@ -161,16 +161,17 @@ def deserialize(self, stream, content_type):
class NumpyDeserializer(BaseDeserializer):
"""Deserialize a stream of data in the .npy format."""

ACCEPT = "application/x-npy"

def __init__(self, dtype=None, allow_pickle=True):
def __init__(self, dtype=None, accept="application/x-npy", allow_pickle=True):
"""Initialize the dtype and allow_pickle arguments.

Args:
dtype (str): The dtype of the data (default: None).
accept (str): The MIME type that is expected from the inference
endpoint (default: "application/x-npy").
allow_pickle (bool): Allow loading pickled object arrays (default: True).
"""
self.dtype = dtype
self.accept = accept
self.allow_pickle = allow_pickle

def deserialize(self, stream, content_type):
Expand All @@ -197,11 +198,21 @@ def deserialize(self, stream, content_type):

raise ValueError("%s cannot read content type %s." % (__class__.__name__, content_type))

@property
def ACCEPT(self):
"""The content types that are expected from the inference endpoint.

To maintain backwards compatability with legacy images, the
NumpyDeserializer supports sending only one content type in the Accept
header.
"""
return (self.accept,)


class JSONDeserializer(BaseDeserializer):
"""Deserialize JSON data from an inference endpoint into a Python object."""

ACCEPT = "application/json"
ACCEPT = ("application/json",)

def deserialize(self, stream, content_type):
"""Deserialize JSON data from an inference endpoint into a Python object.
Expand All @@ -222,7 +233,7 @@ def deserialize(self, stream, content_type):
class PandasDeserializer(BaseDeserializer):
"""Deserialize CSV or JSON data from an inference endpoint into a pandas dataframe."""

ACCEPT = "text/csv"
ACCEPT = ("text/csv", "application/json")

def deserialize(self, stream, content_type):
"""Deserialize CSV or JSON data from an inference endpoint into a pandas
Expand Down Expand Up @@ -250,7 +261,7 @@ def deserialize(self, stream, content_type):
class JSONLinesDeserializer(BaseDeserializer):
"""Deserialize JSON lines data from an inference endpoint."""

ACCEPT = "application/jsonlines"
ACCEPT = ("application/jsonlines",)

def deserialize(self, stream, content_type):
"""Deserialize JSON lines data from an inference endpoint.
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
args["ContentType"] = self.content_type

if "Accept" not in args:
args["Accept"] = self.accept
args["Accept"] = ", ".join(self.accept)

if target_model:
args["TargetModel"] = target_model
Expand Down
37 changes: 30 additions & 7 deletions tests/unit/test_predictor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,13 @@
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import io
import json

import pytest
from mock import Mock, call, patch

from sagemaker.deserializers import CSVDeserializer, PandasDeserializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer, CSVSerializer

Expand Down Expand Up @@ -132,7 +134,7 @@ def json_sagemaker_session():
response_body.close = Mock("close", return_value=None)
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
name="invoke_endpoint",
return_value={"Body": response_body, "ContentType": DEFAULT_CONTENT_TYPE},
return_value={"Body": response_body, "ContentType": "application/json"},
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fix is out of scope for this PR, but since it's small I decided to keep it in.

)
return ims

Expand Down Expand Up @@ -168,9 +170,7 @@ def ret_csv_sagemaker_session():
ims.sagemaker_client.describe_endpoint = Mock(return_value=ENDPOINT_DESC)
ims.sagemaker_client.describe_endpoint_config = Mock(return_value=ENDPOINT_CONFIG_DESC)

response_body = Mock("body")
response_body.read = Mock("read", return_value=CSV_RETURN_VALUE)
response_body.close = Mock("close", return_value=None)
response_body = io.BytesIO(bytes(CSV_RETURN_VALUE, "utf-8"))
ims.sagemaker_runtime_client.invoke_endpoint = Mock(
name="invoke_endpoint",
return_value={"Body": response_body, "ContentType": CSV_CONTENT_TYPE},
Expand All @@ -180,23 +180,46 @@ def ret_csv_sagemaker_session():

def test_predict_call_with_csv():
sagemaker_session = ret_csv_sagemaker_session()
predictor = Predictor(ENDPOINT, sagemaker_session, serializer=CSVSerializer())
predictor = Predictor(
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=CSVDeserializer()
)

data = [1, 2]
result = predictor.predict(data)

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called

expected_request_args = {
"Accept": DEFAULT_ACCEPT,
"Accept": CSV_CONTENT_TYPE,
"Body": "1,2",
"ContentType": CSV_CONTENT_TYPE,
"EndpointName": ENDPOINT,
}
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
assert kwargs == expected_request_args

assert result == CSV_RETURN_VALUE
assert result == [["1", "2", "3"]]


def test_predict_call_with_multiple_accept_types():
sagemaker_session = ret_csv_sagemaker_session()
predictor = Predictor(
ENDPOINT, sagemaker_session, serializer=CSVSerializer(), deserializer=PandasDeserializer()
)

data = [1, 2]
predictor.predict(data)

assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called

expected_request_args = {
"Accept": "text/csv, application/json",
"Body": "1,2",
"ContentType": CSV_CONTENT_TYPE,
"EndpointName": ENDPOINT,
}
call_args, kwargs = sagemaker_session.sagemaker_runtime_client.invoke_endpoint.call_args
assert kwargs == expected_request_args


@patch("sagemaker.predictor.name_from_base")
Expand Down