Skip to content

Commit 9e44f84

Browse files
pintaoz-awspintaoz
and
pintaoz
authored
Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserialzers (#5037)
* Move RecordSerializer and RecordDeserializer to sagemaker.serializers and sagemaker.deserializers * fix codestyle * fix test --------- Co-authored-by: pintaoz <[email protected]>
1 parent 18897d7 commit 9e44f84

File tree

17 files changed

+114
-110
lines changed

17 files changed

+114
-110
lines changed

doc/v2.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,9 @@ The follow serializer/deserializer classes have been renamed and/or moved:
324324
+--------------------------------------------------------+-------------------------------------------------------+
325325
| ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` |
326326
+--------------------------------------------------------+-------------------------------------------------------+
327-
| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.common.RecordSerializer`` |
327+
| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.serializers.RecordSerializer`` |
328328
+--------------------------------------------------------+-------------------------------------------------------+
329-
| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.common.RecordDeserializer`` |
329+
| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.deserializers.RecordDeserializer`` |
330330
+--------------------------------------------------------+-------------------------------------------------------+
331331
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
332332
+--------------------------------------------------------+-------------------------------------------------------+

src/sagemaker/amazon/common.py

Lines changed: 0 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -13,84 +13,16 @@
1313
"""Placeholder docstring"""
1414
from __future__ import absolute_import
1515

16-
import io
1716
import logging
1817
import struct
1918
import sys
2019

2120
import numpy as np
2221

2322
from sagemaker.amazon.record_pb2 import Record
24-
from sagemaker.deprecations import deprecated_class
25-
from sagemaker.deserializers import SimpleBaseDeserializer
26-
from sagemaker.serializers import SimpleBaseSerializer
2723
from sagemaker.utils import DeferredError
2824

2925

30-
class RecordSerializer(SimpleBaseSerializer):
31-
"""Serialize a NumPy array for an inference request."""
32-
33-
def __init__(self, content_type="application/x-recordio-protobuf"):
34-
"""Initialize a ``RecordSerializer`` instance.
35-
36-
Args:
37-
content_type (str): The MIME type to signal to the inference endpoint when sending
38-
request data (default: "application/x-recordio-protobuf").
39-
"""
40-
super(RecordSerializer, self).__init__(content_type=content_type)
41-
42-
def serialize(self, data):
43-
"""Serialize a NumPy array into a buffer containing RecordIO records.
44-
45-
Args:
46-
data (numpy.ndarray): The data to serialize.
47-
48-
Returns:
49-
io.BytesIO: A buffer containing the data serialized as records.
50-
"""
51-
if len(data.shape) == 1:
52-
data = data.reshape(1, data.shape[0])
53-
54-
if len(data.shape) != 2:
55-
raise ValueError(
56-
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
57-
)
58-
59-
buffer = io.BytesIO()
60-
write_numpy_to_dense_tensor(buffer, data)
61-
buffer.seek(0)
62-
63-
return buffer
64-
65-
66-
class RecordDeserializer(SimpleBaseDeserializer):
67-
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
68-
69-
def __init__(self, accept="application/x-recordio-protobuf"):
70-
"""Initialize a ``RecordDeserializer`` instance.
71-
72-
Args:
73-
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
74-
is expected from the inference endpoint (default:
75-
"application/x-recordio-protobuf").
76-
"""
77-
super(RecordDeserializer, self).__init__(accept=accept)
78-
79-
def deserialize(self, data, content_type):
80-
"""Deserialize RecordIO Protobuf data from an inference endpoint.
81-
82-
Args:
83-
data (object): The protobuf message to deserialize.
84-
content_type (str): The MIME type of the data.
85-
Returns:
86-
list: A list of records.
87-
"""
88-
try:
89-
return read_records(data)
90-
finally:
91-
data.close()
92-
93-
9426
def _write_feature_tensor(resolved_type, record, vector):
9527
"""Placeholder Docstring"""
9628
if resolved_type == "Int32":
@@ -288,7 +220,3 @@ def _resolve_type(dtype):
288220
if dtype == np.dtype("float32"):
289221
return "Float32"
290222
raise ValueError("Unsupported dtype {} on array".format(dtype))
291-
292-
293-
numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer")
294-
record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer")

src/sagemaker/amazon/factorization_machines.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import gt, isin, ge
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/kmeans.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import gt, isin, ge, le
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/knn.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import ge, isin
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/lda.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818

1919
from sagemaker import image_uris
2020
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
21-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
21+
from sagemaker.deserializers import RecordDeserializer
2222
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2323
from sagemaker.amazon.validation import gt
2424
from sagemaker.predictor import Predictor
2525
from sagemaker.model import Model
26+
from sagemaker.serializers import RecordSerializer
2627
from sagemaker.session import Session
2728
from sagemaker.utils import pop_out_unused_kwarg
2829
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/linear_learner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,12 @@
1818

1919
from sagemaker import image_uris
2020
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
21-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
21+
from sagemaker.deserializers import RecordDeserializer
2222
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2323
from sagemaker.amazon.validation import isin, gt, lt, ge, le
2424
from sagemaker.predictor import Predictor
2525
from sagemaker.model import Model
26+
from sagemaker.serializers import RecordSerializer
2627
from sagemaker.session import Session
2728
from sagemaker.utils import pop_out_unused_kwarg
2829
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/ntm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import ge, le, isin
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/pca.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import gt, isin
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/amazon/randomcutforest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,12 @@
1717

1818
from sagemaker import image_uris
1919
from sagemaker.amazon.amazon_estimator import AmazonAlgorithmEstimatorBase
20-
from sagemaker.amazon.common import RecordSerializer, RecordDeserializer
2120
from sagemaker.amazon.hyperparameter import Hyperparameter as hp # noqa
2221
from sagemaker.amazon.validation import ge, le
22+
from sagemaker.deserializers import RecordDeserializer
2323
from sagemaker.predictor import Predictor
2424
from sagemaker.model import Model
25+
from sagemaker.serializers import RecordSerializer
2526
from sagemaker.session import Session
2627
from sagemaker.utils import pop_out_unused_kwarg
2728
from sagemaker.vpc_utils import VPC_CONFIG_DEFAULT

src/sagemaker/base_deserializers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import numpy as np
2424
from six import with_metaclass
2525

26+
from sagemaker.amazon.common import read_records
2627
from sagemaker.utils import DeferredError
2728

2829
try:
@@ -388,3 +389,31 @@ def deserialize(self, stream, content_type="tensor/pt"):
388389
"Unable to deserialize your data to torch.Tensor.\
389390
Please provide custom deserializer in InferenceSpec."
390391
)
392+
393+
394+
class RecordDeserializer(SimpleBaseDeserializer):
395+
"""Deserialize RecordIO Protobuf data from an inference endpoint."""
396+
397+
def __init__(self, accept="application/x-recordio-protobuf"):
398+
"""Initialize a ``RecordDeserializer`` instance.
399+
400+
Args:
401+
accept (union[str, tuple[str]]): The MIME type (or tuple of allowable MIME types) that
402+
is expected from the inference endpoint (default:
403+
"application/x-recordio-protobuf").
404+
"""
405+
super(RecordDeserializer, self).__init__(accept=accept)
406+
407+
def deserialize(self, data, content_type):
408+
"""Deserialize RecordIO Protobuf data from an inference endpoint.
409+
410+
Args:
411+
data (object): The protobuf message to deserialize.
412+
content_type (str): The MIME type of the data.
413+
Returns:
414+
list: A list of records.
415+
"""
416+
try:
417+
return read_records(data)
418+
finally:
419+
data.close()

src/sagemaker/base_serializers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from pandas import DataFrame
2323
from six import with_metaclass
2424

25+
from sagemaker.amazon.common import write_numpy_to_dense_tensor
2526
from sagemaker.utils import DeferredError
2627

2728
try:
@@ -466,3 +467,39 @@ def serialize(self, data):
466467
)
467468

468469
raise ValueError("Object of type %s is not a torch.Tensor" % type(data))
470+
471+
472+
class RecordSerializer(SimpleBaseSerializer):
473+
"""Serialize a NumPy array for an inference request."""
474+
475+
def __init__(self, content_type="application/x-recordio-protobuf"):
476+
"""Initialize a ``RecordSerializer`` instance.
477+
478+
Args:
479+
content_type (str): The MIME type to signal to the inference endpoint when sending
480+
request data (default: "application/x-recordio-protobuf").
481+
"""
482+
super(RecordSerializer, self).__init__(content_type=content_type)
483+
484+
def serialize(self, data):
485+
"""Serialize a NumPy array into a buffer containing RecordIO records.
486+
487+
Args:
488+
data (numpy.ndarray): The data to serialize.
489+
490+
Returns:
491+
io.BytesIO: A buffer containing the data serialized as records.
492+
"""
493+
if len(data.shape) == 1:
494+
data = data.reshape(1, data.shape[0])
495+
496+
if len(data.shape) != 2:
497+
raise ValueError(
498+
"Expected a 1D or 2D array, but got a %dD array instead." % len(data.shape)
499+
)
500+
501+
buffer = io.BytesIO()
502+
write_numpy_to_dense_tensor(buffer, data)
503+
buffer.seek(0)
504+
505+
return buffer

src/sagemaker/cli/compatibility/v2/modifiers/serde.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@
5151
"StreamDeserializer": ("sagemaker.deserializers",),
5252
"NumpyDeserializer": ("sagemaker.deserializers",),
5353
"JSONDeserializer": ("sagemaker.deserializers",),
54-
"RecordSerializer ": ("sagemaker.amazon.common",),
55-
"RecordDeserializer": ("sagemaker.amazon.common",),
54+
"RecordSerializer ": ("sagemaker.serializers",),
55+
"RecordDeserializer": ("sagemaker.deserializers",),
5656
}
5757

5858
OLD_CLASS_NAME_TO_NEW_CLASS_NAME = {
@@ -101,8 +101,8 @@ def node_should_be_modified(self, node):
101101
- ``sagemaker.predictor.StreamDeserializer``
102102
- ``sagemaker.predictor._NumpyDeserializer``
103103
- ``sagemaker.predictor._JsonDeserializer``
104-
- ``sagemaker.amazon.common.numpy_to_record_serializer``
105-
- ``sagemaker.amazon.common.record_deserializer``
104+
- ``sagemaker.serializers.numpy_to_record_serializer``
105+
- ``sagemaker.deserializers.record_deserializer``
106106
107107
Args:
108108
node (ast.Call): a node that represents a function call. For more,
@@ -128,8 +128,8 @@ def modify_node(self, node):
128128
- ``sagemaker.deserializers.StreamDeserializer``
129129
- ``sagemaker.deserializers.NumpyDeserializer``
130130
- ``sagemaker.deserializers._JsonDeserializer``
131-
- ``sagemaker.amazon.common.RecordSerializer``
132-
- ``sagemaker.amazon.common.RecordDeserializer``
131+
- ``sagemaker.serializers.RecordSerializer``
132+
- ``sagemaker.deserializers.RecordDeserializer``
133133
134134
Args:
135135
node (ast.Call): a node that represents a SerDe constructor.
@@ -303,8 +303,8 @@ def node_should_be_modified(self, node):
303303
"""Checks if the import statement imports a SerDe from the ``sagemaker.amazon.common``.
304304
305305
This checks for:
306-
- ``sagemaker.amazon.common.numpy_to_record_serializer``
307-
- ``sagemaker.amazon.common.record_deserializer``
306+
- ``sagemaker.serializers.numpy_to_record_serializer``
307+
- ``sagemaker.deserializers.record_deserializer``
308308
309309
Args:
310310
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
@@ -322,8 +322,8 @@ def modify_node(self, node):
322322
"""Upgrades the ``numpy_to_record_serializer`` and ``record_deserializer`` imports.
323323
324324
This upgrades the classes to (if applicable):
325-
- ``sagemaker.amazon.common.RecordSerializer``
326-
- ``sagemaker.amazon.common.RecordDeserializer``
325+
- ``sagemaker.serializers.RecordSerializer``
326+
- ``sagemaker.deserializers.RecordDeserializer``
327327
328328
Args:
329329
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.

src/sagemaker/deserializers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@
3131
StreamDeserializer,
3232
StringDeserializer,
3333
TorchTensorDeserializer,
34+
RecordDeserializer,
3435
)
3536

37+
from sagemaker.deprecations import deprecated_class
3638
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
3739
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
3840
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -150,3 +152,6 @@ def retrieve_default(
150152
model_type=model_type,
151153
config_name=config_name,
152154
)
155+
156+
157+
record_deserializer = deprecated_class(RecordDeserializer, "record_deserializer")

src/sagemaker/serializers.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,10 @@
3030
SparseMatrixSerializer,
3131
TorchTensorSerializer,
3232
StringSerializer,
33+
RecordSerializer,
3334
)
3435

36+
from sagemaker.deprecations import deprecated_class
3537
from sagemaker.jumpstart import artifacts, utils as jumpstart_utils
3638
from sagemaker.jumpstart.constants import DEFAULT_JUMPSTART_SAGEMAKER_SESSION
3739
from sagemaker.jumpstart.enums import JumpStartModelType
@@ -152,3 +154,6 @@ def retrieve_default(
152154
model_type=model_type,
153155
config_name=config_name,
154156
)
157+
158+
159+
numpy_to_record_serializer = deprecated_class(RecordSerializer, "numpy_to_record_serializer")

0 commit comments

Comments
 (0)