Skip to content

Commit 666910c

Browse files
authored
breaking: Remove content_type and accept parameters from Predictor (aws#1751)
1 parent 284eddc commit 666910c

File tree

15 files changed

+130
-132
lines changed

15 files changed

+130
-132
lines changed

doc/v2.rst

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,18 @@ Please instantiate the objects instead.
9494
The ``update_endpoint`` argument in ``deploy()`` methods for estimators and models has been deprecated.
9595
Please use :func:`sagemaker.predictor.Predictor.update_endpoint` instead.
9696

97+
``content_type`` and ``accept`` in the Predictor Constructor
98+
------------------------------------------------------------
99+
100+
The ``content_type`` and ``accept`` parameters have been removed from the
101+
following classes and methods:
102+
- ``sagemaker.predictor.Predictor``
103+
- ``sagemaker.estimator.Estimator.create_model``
104+
- ``sagemaker.algorithms.AlgorithmEstimator.create_model``
105+
- ``sagemaker.tensorflow.model.TensorFlowPredictor``
106+
107+
Please specify the request content type in a serializer class and the response accept type in a deserializer class instead.
108+
97109
``sagemaker.content_types``
98110
---------------------------
99111

@@ -115,7 +127,6 @@ write MIME types as a string,
115127
| ``CONTENT_TYPE_NPY`` | ``"application/x-npy"`` |
116128
+-------------------------------+--------------------------------+
117129

118-
119130
Require ``framework_version`` and ``py_version`` for Frameworks
120131
===============================================================
121132

src/sagemaker/algorithm.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@
1616
import sagemaker
1717
import sagemaker.parameter
1818
from sagemaker import vpc_utils
19+
from sagemaker.deserializers import BytesDeserializer
1920
from sagemaker.estimator import EstimatorBase
21+
from sagemaker.serializers import IdentitySerializer
2022
from sagemaker.transformer import Transformer
2123
from sagemaker.predictor import Predictor
2224

@@ -251,37 +253,29 @@ def create_model(
251253
self,
252254
role=None,
253255
predictor_cls=None,
254-
serializer=None,
255-
deserializer=None,
256-
content_type=None,
257-
accept=None,
256+
serializer=IdentitySerializer(),
257+
deserializer=BytesDeserializer(),
258258
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
259259
**kwargs
260260
):
261261
"""Create a model to deploy.
262262
263-
The serializer, deserializer, content_type, and accept arguments are
264-
only used to define a default Predictor. They are ignored if an
265-
explicit predictor class is passed in. Other arguments are passed
266-
through to the Model class.
263+
The serializer and deserializer are only used to define a default
264+
Predictor. They are ignored if an explicit predictor class is passed in.
265+
Other arguments are passed through to the Model class.
267266
268267
Args:
269268
role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``,
270269
which is also used during transform jobs. If not specified, the
271270
role from the Estimator will be used.
272271
predictor_cls (Predictor): The predictor class to use when
273272
deploying the model.
274-
serializer (callable): Should accept a single argument, the input
275-
data, and return a sequence of bytes. May provide a content_type
276-
attribute that defines the endpoint request content type
277-
deserializer (callable): Should accept two arguments, the result
278-
data and the response content type, and return a sequence of
279-
bytes. May provide a content_type attribute that defines the
280-
endpoint response Accept content type.
281-
content_type (str): The invocation ContentType, overriding any
282-
content_type from the serializer
283-
accept (str): The invocation Accept, overriding any accept from the
284-
deserializer.
273+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
274+
serializer object, used to encode data for an inference endpoint
275+
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
276+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
277+
deserializer object, used to decode data from an inference
278+
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
285279
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
286280
the model. Default: use subnets and security groups from this Estimator.
287281
* 'Subnets' (list[str]): List of subnet ids.
@@ -300,7 +294,7 @@ def create_model(
300294
if predictor_cls is None:
301295

302296
def predict_wrapper(endpoint, session):
303-
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
297+
return Predictor(endpoint, session, serializer, deserializer)
304298

305299
predictor_cls = predict_wrapper
306300

src/sagemaker/estimator.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@
2929
from sagemaker.debugger import DebuggerHookConfig
3030
from sagemaker.debugger import TensorBoardOutputConfig # noqa: F401 # pylint: disable=unused-import
3131
from sagemaker.debugger import get_rule_container_image_uri
32+
from sagemaker.deserializers import BytesDeserializer
3233
from sagemaker.s3 import S3Uploader, parse_s3_url
34+
from sagemaker.serializers import IdentitySerializer
3335

3436
from sagemaker.fw_utils import (
3537
tar_and_upload_dir,
@@ -1340,16 +1342,14 @@ def create_model(
13401342
role=None,
13411343
image_uri=None,
13421344
predictor_cls=None,
1343-
serializer=None,
1344-
deserializer=None,
1345-
content_type=None,
1346-
accept=None,
1345+
serializer=IdentitySerializer(),
1346+
deserializer=BytesDeserializer(),
13471347
vpc_config_override=vpc_utils.VPC_CONFIG_DEFAULT,
13481348
**kwargs
13491349
):
13501350
"""Create a model to deploy.
13511351
1352-
The serializer, deserializer, content_type, and accept arguments are only used to define a
1352+
The serializer and deserializer arguments are only used to define a
13531353
default Predictor. They are ignored if an explicit predictor class is passed in.
13541354
Other arguments are passed through to the Model class.
13551355
@@ -1361,17 +1361,12 @@ def create_model(
13611361
Defaults to the image used for training.
13621362
predictor_cls (Predictor): The predictor class to use when
13631363
deploying the model.
1364-
serializer (callable): Should accept a single argument, the input
1365-
data, and return a sequence of bytes. May provide a content_type
1366-
attribute that defines the endpoint request content type
1367-
deserializer (callable): Should accept two arguments, the result
1368-
data and the response content type, and return a sequence of
1369-
bytes. May provide a content_type attribute that defines the
1370-
endpoint response Accept content type.
1371-
content_type (str): The invocation ContentType, overriding any
1372-
content_type from the serializer
1373-
accept (str): The invocation Accept, overriding any accept from the
1374-
deserializer.
1364+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
1365+
serializer object, used to encode data for an inference endpoint
1366+
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
1367+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
1368+
deserializer object, used to decode data from an inference
1369+
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
13751370
vpc_config_override (dict[str, list[str]]): Optional override for VpcConfig set on
13761371
the model.
13771372
Default: use subnets and security groups from this Estimator.
@@ -1390,7 +1385,7 @@ def create_model(
13901385
if predictor_cls is None:
13911386

13921387
def predict_wrapper(endpoint, session):
1393-
return Predictor(endpoint, session, serializer, deserializer, content_type, accept)
1388+
return Predictor(endpoint, session, serializer, deserializer)
13941389

13951390
predictor_cls = predict_wrapper
13961391

src/sagemaker/predictor.py

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
"""Placeholder docstring"""
1414
from __future__ import print_function, absolute_import
1515

16+
from sagemaker.deserializers import BytesDeserializer
1617
from sagemaker.model_monitor import DataCaptureConfig
18+
from sagemaker.serializers import IdentitySerializer
1719
from sagemaker.session import production_variant, Session
1820
from sagemaker.utils import name_from_base
1921

@@ -31,10 +33,8 @@ def __init__(
3133
self,
3234
endpoint_name,
3335
sagemaker_session=None,
34-
serializer=None,
35-
deserializer=None,
36-
content_type=None,
37-
accept=None,
36+
serializer=IdentitySerializer(),
37+
deserializer=BytesDeserializer(),
3838
):
3939
"""Initialize a ``Predictor``.
4040
@@ -51,23 +51,19 @@ def __init__(
5151
object, used for SageMaker interactions (default: None). If not
5252
specified, one is created using the default AWS configuration
5353
chain.
54-
serializer (sagemaker.serializers.BaseSerializer): A serializer
55-
object, used to encode data for an inference endpoint
56-
(default: None).
57-
deserializer (sagemaker.deserializers.BaseDeserializer): A
54+
serializer (:class:`~sagemaker.serializers.BaseSerializer`): A
55+
serializer object, used to encode data for an inference endpoint
56+
(default: :class:`~sagemaker.serializers.IdentitySerializer`).
57+
deserializer (:class:`~sagemaker.deserializers.BaseDeserializer`): A
5858
deserializer object, used to decode data from an inference
59-
endpoint (default: None).
60-
content_type (str): The invocation's "ContentType", overriding any
61-
``CONTENT_TYPE`` from the serializer (default: None).
62-
accept (str): The invocation's "Accept", overriding any accept from
63-
the deserializer (default: None).
59+
endpoint (default: :class:`~sagemaker.deserializers.BytesDeserializer`).
6460
"""
6561
self.endpoint_name = endpoint_name
6662
self.sagemaker_session = sagemaker_session or Session()
6763
self.serializer = serializer
6864
self.deserializer = deserializer
69-
self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None)
70-
self.accept = accept or getattr(deserializer, "ACCEPT", None)
65+
self.content_type = serializer.CONTENT_TYPE
66+
self.accept = deserializer.ACCEPT
7167
self._endpoint_config_name = self._get_endpoint_config_name()
7268
self._model_names = self._get_model_names()
7369

@@ -107,12 +103,8 @@ def _handle_response(self, response):
107103
response:
108104
"""
109105
response_body = response["Body"]
110-
if self.deserializer is not None:
111-
# It's the deserializer's responsibility to close the stream
112-
return self.deserializer.deserialize(response_body, response["ContentType"])
113-
data = response_body.read()
114-
response_body.close()
115-
return data
106+
content_type = response.get("ContentType", "application/octet-stream")
107+
return self.deserializer.deserialize(response_body, content_type)
116108

117109
def _create_request_args(self, data, initial_args=None, target_model=None, target_variant=None):
118110
"""
@@ -127,10 +119,10 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
127119
if "EndpointName" not in args:
128120
args["EndpointName"] = self.endpoint_name
129121

130-
if self.content_type and "ContentType" not in args:
122+
if "ContentType" not in args:
131123
args["ContentType"] = self.content_type
132124

133-
if self.accept and "Accept" not in args:
125+
if "Accept" not in args:
134126
args["Accept"] = self.accept
135127

136128
if target_model:
@@ -139,8 +131,7 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
139131
if target_variant:
140132
args["TargetVariant"] = target_variant
141133

142-
if self.serializer is not None:
143-
data = self.serializer.serialize(data)
134+
data = self.serializer.serialize(data)
144135

145136
args["Body"] = data
146137
return args

src/sagemaker/serializers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,35 @@ def serialize(self, data):
193193
return json.dumps(data)
194194

195195

196+
class IdentitySerializer(BaseSerializer):
197+
"""Serialize data by returning data without modification."""
198+
199+
def __init__(self, content_type="application/octet-stream"):
200+
"""Initialize the ``content_type`` attribute.
201+
202+
Args:
203+
content_type (str): The MIME type of the serialized data (default:
204+
"application/octet-stream").
205+
"""
206+
self.content_type = content_type
207+
208+
def serialize(self, data):
209+
"""Return data without modification.
210+
211+
Args:
212+
data (object): Data to be serialized.
213+
214+
Returns:
215+
object: The unmodified data.
216+
"""
217+
return data
218+
219+
@property
220+
def CONTENT_TYPE(self):
221+
"""The MIME type of the data sent to the inference endpoint."""
222+
return self.content_type
223+
224+
196225
class JSONLinesSerializer(BaseSerializer):
197226
"""Serialize data to a JSON Lines formatted string."""
198227

src/sagemaker/sparkml/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ def __init__(self, endpoint_name, sagemaker_session=None):
4949
endpoint_name=endpoint_name,
5050
sagemaker_session=sagemaker_session,
5151
serializer=CSVSerializer(),
52-
content_type="text/csv",
5352
)
5453

5554

src/sagemaker/tensorflow/model.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def __init__(
3333
sagemaker_session=None,
3434
serializer=JSONSerializer(),
3535
deserializer=JSONDeserializer(),
36-
content_type=None,
3736
model_name=None,
3837
model_version=None,
3938
):
@@ -51,9 +50,6 @@ def __init__(
5150
json. Handles dicts, lists, and numpy arrays.
5251
deserializer (callable): Optional. Default parses the response using
5352
``json.load(...)``.
54-
content_type (str): Optional. The "ContentType" for invocation
55-
requests. If specified, overrides the ``content_type`` from the
56-
serializer (default: None).
5753
model_name (str): Optional. The name of the SavedModel model that
5854
should handle the request. If not specified, the endpoint's
5955
default model will handle the request.
@@ -62,7 +58,7 @@ def __init__(
6258
version of the model will be used.
6359
"""
6460
super(TensorFlowPredictor, self).__init__(
65-
endpoint_name, sagemaker_session, serializer, deserializer, content_type
61+
endpoint_name, sagemaker_session, serializer, deserializer,
6662
)
6763

6864
attributes = []

tests/integ/test_byo_estimator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ def test_byo_estimator(sagemaker_session, region, cpu_instance_type, training_se
9191
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
9292
predictor.serializer = _FactorizationMachineSerializer()
9393
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
94+
predictor.content_type = "application/json"
9495

9596
result = predictor.predict(training_set[0][:10])
9697

@@ -136,6 +137,7 @@ def test_async_byo_estimator(sagemaker_session, region, cpu_instance_type, train
136137
predictor = model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
137138
predictor.serializer = _FactorizationMachineSerializer()
138139
predictor.deserializer = sagemaker.deserializers.JSONDeserializer()
140+
predictor.content_type = "application/json"
139141

140142
result = predictor.predict(training_set[0][:10])
141143

tests/integ/test_multi_variant_endpoint.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import scipy.stats as st
2121

2222
from sagemaker import image_uris
23+
from sagemaker.deserializers import CSVDeserializer
2324
from sagemaker.s3 import S3Uploader
2425
from sagemaker.session import production_variant
2526
from sagemaker.sparkml import SparkMLModel
@@ -173,8 +174,6 @@ def test_predict_invocation_with_target_variant(sagemaker_session, multi_variant
173174
endpoint_name=multi_variant_endpoint.endpoint_name,
174175
sagemaker_session=sagemaker_session,
175176
serializer=CSVSerializer(),
176-
content_type="text/csv",
177-
accept="text/csv",
178177
)
179178

180179
# Validate that no exception is raised when the target_variant is specified.
@@ -301,8 +300,7 @@ def test_predict_invocation_with_target_variant_local_mode(
301300
endpoint_name=multi_variant_endpoint.endpoint_name,
302301
sagemaker_session=sagemaker_session,
303302
serializer=CSVSerializer(),
304-
content_type="text/csv",
305-
accept="text/csv",
303+
deserializer=CSVDeserializer(),
306304
)
307305

308306
# Validate that no exception is raised when the target_variant is specified.

0 commit comments

Comments
 (0)