Skip to content

Commit 1487b22

Browse files
authored
feature: add BaseSerializer and BaseDeserializer (#1668)
1 parent 15f5358 commit 1487b22

File tree

3 files changed

+185
-15
lines changed

3 files changed

+185
-15
lines changed

src/sagemaker/deserializers.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implements methods for deserializing data returned from an inference endpoint."""
14+
from __future__ import absolute_import
15+
16+
import abc
17+
18+
19+
class BaseDeserializer(abc.ABC):
    """Abstract base class for creation of new deserializers.

    Provides a skeleton for customization requiring the overriding of the
    method ``deserialize`` and the class attribute ``ACCEPT``. A subclass that
    leaves either one undefined cannot be instantiated.
    """

    @abc.abstractmethod
    def deserialize(self, data, content_type):
        """Deserialize data received from an inference endpoint.

        Args:
            data (object): Data to be deserialized. When invoked from the
                predictor's ``_handle_response``, this is the response
                ``"Body"`` stream, handed over unread; the deserializer is
                responsible for reading it and for closing it.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """

    @property
    @abc.abstractmethod
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""

src/sagemaker/predictor.py

+104-15
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@
2121
import numpy as np
2222

2323
from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
24+
from sagemaker.deserializers import BaseDeserializer
2425
from sagemaker.model_monitor import DataCaptureConfig
26+
from sagemaker.serializers import BaseSerializer
2527
from sagemaker.session import production_variant, Session
2628
from sagemaker.utils import name_from_base
2729

@@ -59,27 +61,28 @@ def __init__(
5961
object, used for SageMaker interactions (default: None). If not
6062
specified, one is created using the default AWS configuration
6163
chain.
62-
serializer (callable): Accepts a single argument, the input data,
63-
and returns a sequence of bytes. It may provide a
64-
``content_type`` attribute that defines the endpoint request
65-
content type. If not specified, a sequence of bytes is expected
66-
for the data.
67-
deserializer (callable): Accepts two arguments, the result data and
68-
the response content type, and returns a sequence of bytes. It
69-
may provide a ``content_type`` attribute that defines the
70-
endpoint response's "Accept" content type. If not specified, a
71-
sequence of bytes is expected for the data.
64+
serializer (sagemaker.serializers.BaseSerializer): A serializer
65+
object, used to encode data for an inference endpoint
66+
(default: None).
67+
deserializer (sagemaker.deserializers.BaseDeserializer): A
68+
deserializer object, used to decode data from an inference
69+
endpoint (default: None).
7270
content_type (str): The invocation's "ContentType", overriding any
73-
``content_type`` from the serializer (default: None).
71+
``CONTENT_TYPE`` from the serializer (default: None).
7472
accept (str): The invocation's "Accept", overriding any accept from
7573
the deserializer (default: None).
7674
"""
75+
if serializer is not None and not isinstance(serializer, BaseSerializer):
76+
serializer = LegacySerializer(serializer)
77+
if deserializer is not None and not isinstance(deserializer, BaseDeserializer):
78+
deserializer = LegacyDeserializer(deserializer)
79+
7780
self.endpoint_name = endpoint_name
7881
self.sagemaker_session = sagemaker_session or Session()
7982
self.serializer = serializer
8083
self.deserializer = deserializer
81-
self.content_type = content_type or getattr(serializer, "content_type", None)
82-
self.accept = accept or getattr(deserializer, "accept", None)
84+
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None)
85+
self.accept = accept or getattr(deserializer, "ACCEPT", None)
8386
self._endpoint_config_name = self._get_endpoint_config_name()
8487
self._model_names = self._get_model_names()
8588

@@ -120,8 +123,10 @@ def _handle_response(self, response):
120123
"""
121124
response_body = response["Body"]
122125
if self.deserializer is not None:
126+
if not isinstance(self.deserializer, BaseDeserializer):
127+
self.deserializer = LegacyDeserializer(self.deserializer)
123128
# It's the deserializer's responsibility to close the stream
124-
return self.deserializer(response_body, response["ContentType"])
129+
return self.deserializer.deserialize(response_body, response["ContentType"])
125130
data = response_body.read()
126131
response_body.close()
127132
return data
@@ -152,7 +157,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
152157
args["TargetVariant"] = target_variant
153158

154159
if self.serializer is not None:
155-
data = self.serializer(data)
160+
if not isinstance(self.serializer, BaseSerializer):
161+
self.serializer = LegacySerializer(self.serializer)
162+
data = self.serializer.serialize(data)
156163

157164
args["Body"] = data
158165
return args
@@ -406,6 +413,88 @@ def _get_model_names(self):
406413
return [d["ModelName"] for d in production_variants]
407414

408415

416+
class LegacySerializer(BaseSerializer):
    """Wrapper that makes legacy serializers forward compatible.

    Adapts a plain callable serializer to the ``BaseSerializer`` interface
    (``serialize`` / ``CONTENT_TYPE``) while staying callable itself, so code
    that still invokes the serializer directly keeps working.
    """

    def __init__(self, serializer):
        """Initialize a ``LegacySerializer``.

        Args:
            serializer (callable): A legacy serializer. It is called with the
                input data and may expose a ``content_type`` attribute.
        """
        self.serializer = serializer
        # Read once at construction; a legacy serializer without the
        # attribute yields ``None``.
        self.content_type = getattr(serializer, "content_type", None)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped legacy serializer directly.

        Args:
            *args: Positional arguments forwarded unchanged to the legacy
                serializer (conventionally just the data to be serialized).
            **kwargs: Keyword arguments forwarded unchanged to the legacy
                serializer.

        Returns:
            object: Whatever the legacy serializer returns.
        """
        return self.serializer(*args, **kwargs)

    def serialize(self, data):
        """Serialize data by calling the wrapped legacy serializer.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """
        return self.serializer(data)

    @property
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""
        return self.content_type
454+
455+
456+
class LegacyDeserializer(BaseDeserializer):
    """Wrapper that makes legacy deserializers forward compatible.

    Adapts a plain callable deserializer to the ``BaseDeserializer`` interface
    (``deserialize`` / ``ACCEPT``) while staying callable itself, so code that
    still invokes the deserializer directly keeps working.
    """

    def __init__(self, deserializer):
        """Initialize a ``LegacyDeserializer``.

        Args:
            deserializer (callable): A legacy deserializer. It is called with
                the response data and its content type, and may expose an
                ``accept`` attribute.
        """
        self.deserializer = deserializer
        # Read once at construction; a legacy deserializer without the
        # attribute yields ``None``.
        self.accept = getattr(deserializer, "accept", None)

    def __call__(self, *args, **kwargs):
        """Invoke the wrapped legacy deserializer directly.

        Args:
            *args: Positional arguments forwarded unchanged to the legacy
                deserializer (conventionally the data and its content type).
            **kwargs: Keyword arguments forwarded unchanged to the legacy
                deserializer.

        Returns:
            object: Whatever the legacy deserializer returns.
        """
        return self.deserializer(*args, **kwargs)

    def deserialize(self, data, content_type):
        """Deserialize data by calling the wrapped legacy deserializer.

        Args:
            data (object): Data to be deserialized. The wrapper neither reads
                nor closes it; that is left to the legacy deserializer.
            content_type (str): The MIME type of the data.

        Returns:
            object: The data deserialized into an object.
        """
        return self.deserializer(data, content_type)

    @property
    def ACCEPT(self):
        """The content type that is expected from the inference endpoint."""
        return self.accept
496+
497+
409498
class _CsvSerializer(object):
410499
"""Placeholder docstring"""
411500

src/sagemaker/serializers.py

+40
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2017-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Implements methods for serializing data for an inference endpoint."""
14+
from __future__ import absolute_import
15+
16+
import abc
17+
18+
19+
class BaseSerializer(abc.ABC):
    """Abstract base class for creation of new serializers.

    Provides a skeleton for customization requiring the overriding of the
    method ``serialize`` and the class attribute ``CONTENT_TYPE``. A subclass
    that leaves either one undefined cannot be instantiated.
    """

    @abc.abstractmethod
    def serialize(self, data):
        """Serialize data into the media type specified by CONTENT_TYPE.

        Args:
            data (object): Data to be serialized.

        Returns:
            object: Serialized data used for a request.
        """

    @property
    @abc.abstractmethod
    def CONTENT_TYPE(self):
        """The MIME type of the data sent to the inference endpoint."""

0 commit comments

Comments
 (0)