Skip to content

Commit ea9fc4d

Browse files
authored
feature: Add v2 SerDe compatability (#1735)
1 parent 95671e0 commit ea9fc4d

File tree

15 files changed

+937
-21
lines changed

15 files changed

+937
-21
lines changed

doc/v2.rst

+55
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,30 @@ For more information, see `Upgrade from Legacy TensorFlow Support <frameworks/te
6464
The ``delete_endpoint()`` method for estimators and ``HyperparameterTuner`` has been deprecated.
6565
Please use :func:`sagemaker.predictor.Predictor.delete_endpoint` instead.
6666

67+
Pre-instantiated Serializer and Deserializer Objects
68+
----------------------------------------------------
69+
70+
The ``csv_serializer``, ``json_serializer``, ``npy_serializer``, ``csv_deserializer``,
71+
``json_deserializer``, and ``numpy_deserializer`` objects have been deprecated.
72+
73+
Please instantiate the objects instead.
74+
75+
+--------------------------------------------+------------------------------------------------+
76+
| v1.x | v2.0 and later |
77+
+============================================+================================================+
78+
| ``sagemaker.predictor.csv_serializer`` | ``sagemaker.deserializers.CSVSerializer()`` |
79+
+--------------------------------------------+------------------------------------------------+
80+
| ``sagemaker.predictor.json_serializer`` | ``sagemaker.serializers.JSONSerializer()`` |
81+
+--------------------------------------------+------------------------------------------------+
82+
| ``sagemaker.predictor.npy_serializer`` | ``sagemaker.deserializers.NumpySerializer()`` |
83+
+--------------------------------------------+------------------------------------------------+
84+
| ``sagemaker.predictor.csv_deserializer`` | ``sagemaker.deserializers.CSVDeserializer()`` |
85+
+--------------------------------------------+------------------------------------------------+
86+
| ``sagemaker.predictor.json_deserializer`` | ``sagemaker.deserializers.JSONDeserializer()`` |
87+
+--------------------------------------------+------------------------------------------------+
88+
| ``sagemaker.predictor.numpy_deserializer`` | ``sagemaker.serializers.NumpyDeserializer()`` |
89+
+--------------------------------------------+------------------------------------------------+
90+
6791
``update_endpoint`` in ``deploy()``
6892
-----------------------------------
6993

@@ -152,6 +176,37 @@ The following estimator parameters have been renamed:
152176
| ``train_volume_kms_key`` | ``volume_kms_key`` |
153177
+------------------------------+------------------------+
154178

179+
Serializer and Deserializer Classes
180+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
181+
182+
The follow serializer/deserializer classes have been renamed and/or moved:
183+
184+
+--------------------------------------------------------+-------------------------------------------------------+
185+
| v1.x | v2.0 and later |
186+
+========================================================+=======================================================+
187+
| ``sagemaker.predictor._CsvDeserializer`` | ``sagemaker.deserializers.CSVDeserializer`` |
188+
+--------------------------------------------------------+-------------------------------------------------------+
189+
| ``sagemaker.predictor._CsvSerializer`` | ``sagemaker.serializers.CSVSerializer`` |
190+
+--------------------------------------------------------+-------------------------------------------------------+
191+
| ``sagemaker.predictor.BytesDeserializer`` | ``sagemaker.deserializers.BytesDeserializers`` |
192+
+--------------------------------------------------------+-------------------------------------------------------+
193+
| ``sagemaker.predictor.StringDeserializer`` | ``sagemaker.deserializers.StringDeserializer`` |
194+
+--------------------------------------------------------+-------------------------------------------------------+
195+
| ``sagemaker.predictor.StreamDeserializer`` | ``sagemaker.deserializers.StreamDeserializer`` |
196+
+--------------------------------------------------------+-------------------------------------------------------+
197+
| ``sagemaker.predictor._JsonSerializer`` | ``sagemaker.serializers.JSONSerializer`` |
198+
+--------------------------------------------------------+-------------------------------------------------------+
199+
| ``sagemaker.predictor._NumpyDeserializer`` | ``sagemaker.deserializers.NumpyDeserializer`` |
200+
+--------------------------------------------------------+-------------------------------------------------------+
201+
| ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` |
202+
+--------------------------------------------------------+-------------------------------------------------------+
203+
| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.serializers.RecordSerializer`` |
204+
+--------------------------------------------------------+-------------------------------------------------------+
205+
| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.deserializers.RecordDeserializer`` |
206+
+--------------------------------------------------------+-------------------------------------------------------+
207+
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
208+
+--------------------------------------------------------+-------------------------------------------------------+
209+
155210
``distributions``
156211
~~~~~~~~~~~~~~~~~
157212

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+67-19
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,24 @@
3535
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
3636
modifiers.training_params.TrainPrefixRemover(),
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
38+
modifiers.serde.SerdeConstructorRenamer(),
3839
]
3940

4041
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
4142

43+
NAME_MODIFIERS = [modifiers.serde.SerdeObjectRenamer()]
44+
45+
MODULE_MODIFIERS = [
46+
modifiers.serde.SerializerImportInserter(),
47+
modifiers.serde.DeserializerImportInserter(),
48+
]
49+
4250
IMPORT_FROM_MODIFIERS = [
4351
modifiers.predictors.PredictorImportFromRenamer(),
4452
modifiers.tfs.TensorFlowServingImportFromRenamer(),
4553
modifiers.training_input.TrainingInputImportFromRenamer(),
54+
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
55+
modifiers.serde.SerdeImportFromPredictorRenamer(),
4656
]
4757

4858

@@ -52,52 +62,90 @@ class ASTTransformer(ast.NodeTransformer):
5262
"""
5363

5464
def visit_Call(self, node):
55-
"""Visits an ``ast.Call`` node and returns a modified node, if needed.
65+
"""Visits an ``ast.Call`` node and returns a modified node or None.
66+
5667
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
5768
5869
Args:
5970
node (ast.Call): a node that represents a function call.
6071
6172
Returns:
62-
ast.Call: a node that represents a function call, which has
63-
potentially been modified from the original input.
73+
ast.AST: if the returned node is None, the original node is removed
74+
from its location. Otherwise, the original node is replaced with
75+
the returned node.
6476
"""
6577
for function_checker in FUNCTION_CALL_MODIFIERS:
66-
function_checker.check_and_modify_node(node)
78+
node = function_checker.check_and_modify_node(node)
79+
return ast.fix_missing_locations(node) if node else None
80+
81+
def visit_Name(self, node):
82+
"""Visits an ``ast.Name`` node and returns a modified node or None.
6783
68-
ast.fix_missing_locations(node)
69-
return node
84+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
85+
86+
Args:
87+
node (ast.Name): a node that represents an identifier.
88+
89+
Returns:
90+
ast.AST: if the returned node is None, the original node is removed
91+
from its location. Otherwise, the original node is replaced with
92+
the returned node.
93+
"""
94+
for name_checker in NAME_MODIFIERS:
95+
node = name_checker.check_and_modify_node(node)
96+
return ast.fix_missing_locations(node) if node else None
7097

7198
def visit_Import(self, node):
72-
"""Visits an ``ast.Import`` node and returns a modified node, if needed.
99+
"""Visits an ``ast.Import`` node and returns a modified node or None.
100+
73101
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
74102
75103
Args:
76104
node (ast.Import): a node that represents an import statement.
77105
78106
Returns:
79-
ast.Import: a node that represents an import statement, which has
80-
potentially been modified from the original input.
107+
ast.AST: if the returned node is None, the original node is removed
108+
from its location. Otherwise, the original node is replaced with
109+
the returned node.
81110
"""
82111
for import_checker in IMPORT_MODIFIERS:
83-
import_checker.check_and_modify_node(node)
112+
node = import_checker.check_and_modify_node(node)
113+
return ast.fix_missing_locations(node) if node else None
114+
115+
def visit_Module(self, node):
116+
"""Visits an ``ast.Module`` node and returns a modified node or None.
117+
118+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
119+
120+
The ``ast.NodeTransformer`` walks the abstract syntax tree and modifies
121+
all other nodes before modifying the ``ast.Module`` node.
84122
85-
ast.fix_missing_locations(node)
86-
return node
123+
Args:
124+
node (ast.Module): a node that represents a Python module.
125+
126+
Returns:
127+
ast.AST: if the returned node is None, the original node is removed
128+
from its location. Otherwise, the original node is replaced with
129+
the returned node.
130+
"""
131+
self.generic_visit(node)
132+
for module_checker in MODULE_MODIFIERS:
133+
node = module_checker.check_and_modify_node(node)
134+
return ast.fix_missing_locations(node) if node else None
87135

88136
def visit_ImportFrom(self, node):
89-
"""Visits an ``ast.ImportFrom`` node and returns a modified node, if needed.
137+
"""Visits an ``ast.ImportFrom`` node and returns a modified node or None.
138+
90139
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
91140
92141
Args:
93142
node (ast.ImportFrom): a node that represents an import statement.
94143
95144
Returns:
96-
ast.ImportFrom: a node that represents an import statement, which has
97-
potentially been modified from the original input.
145+
ast.AST: if the returned node is None, the original node is removed
146+
from its location. Otherwise, the original node is replaced with
147+
the returned node.
98148
"""
99149
for import_checker in IMPORT_FROM_MODIFIERS:
100-
import_checker.check_and_modify_node(node)
101-
102-
ast.fix_missing_locations(node)
103-
return node
150+
node = import_checker.check_and_modify_node(node)
151+
return ast.fix_missing_locations(node) if node else None

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
framework_version,
2020
predictors,
2121
renamed_params,
22+
serde,
2223
tf_legacy_mode,
2324
tfs,
2425
training_params,

src/sagemaker/cli/compatibility/v2/modifiers/airflow.py

+4
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,13 @@ def modify_node(self, node):
5858
Args:
5959
node (ast.Call): a node that represents either a ``model_config`` call or
6060
a ``model_config_from_estimator`` call.
61+
62+
Returns:
63+
ast.AST: the original node, which has been potentially modified.
6164
"""
6265
instance_type = node.args.pop(0)
6366
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))
67+
return node
6468

6569

6670
class ModelConfigImageURIRenamer(renamed_params.ParamRenamer):

src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py

+4
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,11 @@ def modify_node(self, node):
5454
5555
Args:
5656
node (ast.Call): a node that represents a TensorFlow constructor.
57+
58+
Returns:
59+
ast.AST: the original node, which has been potentially modified.
5760
"""
5861
for kw in node.keywords:
5962
if kw.arg == "script_mode":
6063
node.keywords.remove(kw)
64+
return node

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+4
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ def modify_node(self, node):
9494
9595
Args:
9696
node (ast.Call): a node that represents the constructor of a framework class.
97+
98+
Returns:
99+
ast.AST: the original node, which has been potentially modified.
97100
"""
98101
framework, is_model = _framework_from_node(node)
99102

@@ -109,6 +112,7 @@ def modify_node(self, node):
109112
py_version = _py_version_defaults(framework, framework_version, is_model)
110113
if py_version:
111114
node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))
115+
return node
112116

113117

114118
def _py_version_defaults(framework, framework_version, is_model=False):

src/sagemaker/cli/compatibility/v2/modifiers/modifier.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ class Modifier(object):
2222
"""
2323

2424
def check_and_modify_node(self, node):
25-
"""Check an AST node, and modify it if applicable."""
25+
"""Check an AST node, and modify, replace, or remove it if applicable."""
2626
if self.node_should_be_modified(node):
27-
self.modify_node(node)
27+
node = self.modify_node(node)
28+
return node
2829

2930
@abstractmethod
3031
def node_should_be_modified(self, node):

src/sagemaker/cli/compatibility/v2/modifiers/predictors.py

+8
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,13 @@ def modify_node(self, node):
6262
6363
Args:
6464
node (ast.Call): a node that represents a *Predictor constructor.
65+
66+
Returns:
67+
ast.AST: the original node, which has been potentially modified.
6568
"""
6669
_rename_class(node)
6770
_rename_endpoint(node)
71+
return node
6872

6973

7074
def _rename_class(node):
@@ -106,7 +110,11 @@ def modify_node(self, node):
106110
Args:
107111
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
108112
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
113+
114+
Returns:
115+
ast.AST: the original node, which has been potentially modified.
109116
"""
110117
for name in node.names:
111118
if name.name == BASE_PREDICTOR:
112119
name.name = "Predictor"
120+
return node

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+4
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,13 @@ def modify_node(self, node):
6161
6262
Args:
6363
node (ast.Call): a node that represents the relevant function call.
64+
65+
Returns:
66+
ast.AST: the original node, which has been potentially modified.
6467
"""
6568
keyword = parsing.arg_from_keywords(node, self.old_param_name)
6669
keyword.arg = self.new_param_name
70+
return node
6771

6872

6973
class MethodParamRenamer(ParamRenamer):

0 commit comments

Comments
 (0)