|
21 | 21 | import numpy as np
|
22 | 22 |
|
23 | 23 | from sagemaker.content_types import CONTENT_TYPE_JSON, CONTENT_TYPE_CSV, CONTENT_TYPE_NPY
|
| 24 | +from sagemaker.deserializers import BaseDeserializer |
24 | 25 | from sagemaker.model_monitor import DataCaptureConfig
|
| 26 | +from sagemaker.serializers import BaseSerializer |
25 | 27 | from sagemaker.session import production_variant, Session
|
26 | 28 | from sagemaker.utils import name_from_base
|
27 | 29 |
|
@@ -59,27 +61,28 @@ def __init__(
|
59 | 61 | object, used for SageMaker interactions (default: None). If not
|
60 | 62 | specified, one is created using the default AWS configuration
|
61 | 63 | chain.
|
62 |
| - serializer (callable): Accepts a single argument, the input data, |
63 |
| - and returns a sequence of bytes. It may provide a |
64 |
| - ``content_type`` attribute that defines the endpoint request |
65 |
| - content type. If not specified, a sequence of bytes is expected |
66 |
| - for the data. |
67 |
| - deserializer (callable): Accepts two arguments, the result data and |
68 |
| - the response content type, and returns a sequence of bytes. It |
69 |
| - may provide a ``content_type`` attribute that defines the |
70 |
| - endpoint response's "Accept" content type. If not specified, a |
71 |
| - sequence of bytes is expected for the data. |
| 64 | + serializer (sagemaker.serializers.BaseSerializer): A serializer |
| 65 | + object, used to encode data for an inference endpoint |
| 66 | + (default: None). |
| 67 | + deserializer (sagemaker.deserializers.BaseDeserializer): A |
| 68 | + deserializer object, used to decode data from an inference |
| 69 | + endpoint (default: None). |
72 | 70 | content_type (str): The invocation's "ContentType", overriding any
|
73 |
| - ``content_type`` from the serializer (default: None). |
| 71 | + ``CONTENT_TYPE`` from the serializer (default: None). |
74 | 72 | accept (str): The invocation's "Accept", overriding any accept from
|
75 | 73 | the deserializer (default: None).
|
76 | 74 | """
|
| 75 | + if serializer is not None and not isinstance(serializer, BaseSerializer): |
| 76 | + serializer = LegacySerializer(serializer) |
| 77 | + if deserializer is not None and not isinstance(deserializer, BaseDeserializer): |
| 78 | + deserializer = LegacyDeserializer(deserializer) |
| 79 | + |
77 | 80 | self.endpoint_name = endpoint_name
|
78 | 81 | self.sagemaker_session = sagemaker_session or Session()
|
79 | 82 | self.serializer = serializer
|
80 | 83 | self.deserializer = deserializer
|
81 |
| - self.content_type = content_type or getattr(serializer, "content_type", None) |
82 |
| - self.accept = accept or getattr(deserializer, "accept", None) |
| 84 | + self.content_type = content_type or getattr(serializer, "CONTENT_TYPE", None) |
| 85 | + self.accept = accept or getattr(deserializer, "ACCEPT", None) |
83 | 86 | self._endpoint_config_name = self._get_endpoint_config_name()
|
84 | 87 | self._model_names = self._get_model_names()
|
85 | 88 |
|
@@ -120,8 +123,10 @@ def _handle_response(self, response):
|
120 | 123 | """
|
121 | 124 | response_body = response["Body"]
|
122 | 125 | if self.deserializer is not None:
|
| 126 | + if not isinstance(self.deserializer, BaseDeserializer): |
| 127 | + self.deserializer = LegacyDeserializer(self.deserializer) |
123 | 128 | # It's the deserializer's responsibility to close the stream
|
124 |
| - return self.deserializer(response_body, response["ContentType"]) |
| 129 | + return self.deserializer.deserialize(response_body, response["ContentType"]) |
125 | 130 | data = response_body.read()
|
126 | 131 | response_body.close()
|
127 | 132 | return data
|
@@ -152,7 +157,9 @@ def _create_request_args(self, data, initial_args=None, target_model=None, targe
|
152 | 157 | args["TargetVariant"] = target_variant
|
153 | 158 |
|
154 | 159 | if self.serializer is not None:
|
155 |
| - data = self.serializer(data) |
| 160 | + if not isinstance(self.serializer, BaseSerializer): |
| 161 | + self.serializer = LegacySerializer(self.serializer) |
| 162 | + data = self.serializer.serialize(data) |
156 | 163 |
|
157 | 164 | args["Body"] = data
|
158 | 165 | return args
|
@@ -406,6 +413,88 @@ def _get_model_names(self):
|
406 | 413 | return [d["ModelName"] for d in production_variants]
|
407 | 414 |
|
408 | 415 |
|
| 416 | +class LegacySerializer(BaseSerializer): |
| 417 | + """Wrapper that makes legacy serializers forward compatibile.""" |
| 418 | + |
| 419 | + def __init__(self, serializer): |
| 420 | + """Initialize a ``LegacySerializer``. |
| 421 | +
|
| 422 | + Args: |
| 423 | + serializer (callable): A legacy serializer. |
| 424 | + """ |
| 425 | + self.serializer = serializer |
| 426 | + self.content_type = getattr(serializer, "content_type", None) |
| 427 | + |
| 428 | + def __call__(self, *args, **kwargs): |
| 429 | + """Wraps the call method of the legacy serializer. |
| 430 | +
|
| 431 | + Args: |
| 432 | + data (object): Data to be serialized. |
| 433 | +
|
| 434 | + Returns: |
| 435 | + object: Serialized data used for a request. |
| 436 | + """ |
| 437 | + return self.serializer(*args, **kwargs) |
| 438 | + |
| 439 | + def serialize(self, data): |
| 440 | + """Wraps the call method of the legacy serializer. |
| 441 | +
|
| 442 | + Args: |
| 443 | + data (object): Data to be serialized. |
| 444 | +
|
| 445 | + Returns: |
| 446 | + object: Serialized data used for a request. |
| 447 | + """ |
| 448 | + return self.serializer(data) |
| 449 | + |
| 450 | + @property |
| 451 | + def CONTENT_TYPE(self): |
| 452 | + """The MIME type of the data sent to the inference endpoint.""" |
| 453 | + return self.content_type |
| 454 | + |
| 455 | + |
| 456 | +class LegacyDeserializer(BaseDeserializer): |
| 457 | + """Wrapper that makes legacy deserializers forward compatibile.""" |
| 458 | + |
| 459 | + def __init__(self, deserializer): |
| 460 | + """Initialize a ``LegacyDeserializer``. |
| 461 | +
|
| 462 | + Args: |
| 463 | + deserializer (callable): A legacy deserializer. |
| 464 | + """ |
| 465 | + self.deserializer = deserializer |
| 466 | + self.accept = getattr(deserializer, "accept", None) |
| 467 | + |
| 468 | + def __call__(self, *args, **kwargs): |
| 469 | + """Wraps the call method of the legacy deserializer. |
| 470 | +
|
| 471 | + Args: |
| 472 | + data (object): Data to be deserialized. |
| 473 | + content_type (str): The MIME type of the data. |
| 474 | +
|
| 475 | + Returns: |
| 476 | + object: The data deserialized into an object. |
| 477 | + """ |
| 478 | + return self.deserializer(*args, **kwargs) |
| 479 | + |
| 480 | + def deserialize(self, data, content_type): |
| 481 | + """Wraps the call method of the legacy deserializer. |
| 482 | +
|
| 483 | + Args: |
| 484 | + data (object): Data to be deserialized. |
| 485 | + content_type (str): The MIME type of the data. |
| 486 | +
|
| 487 | + Returns: |
| 488 | + object: The data deserialized into an object. |
| 489 | + """ |
| 490 | + return self.deserializer(data, content_type) |
| 491 | + |
| 492 | + @property |
| 493 | + def ACCEPT(self): |
| 494 | + """The content type that is expected from the inference endpoint.""" |
| 495 | + return self.accept |
| 496 | + |
| 497 | + |
409 | 498 | class _CsvSerializer(object):
|
410 | 499 | """Placeholder docstring"""
|
411 | 500 |
|
|
0 commit comments