Skip to content

feature: Add v2 SerDe compatability #1735

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jul 28, 2020
Merged
55 changes: 55 additions & 0 deletions doc/v2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,30 @@ For more information, see `Upgrade from Legacy TensorFlow Support <frameworks/te
The ``delete_endpoint()`` method for estimators and ``HyperparameterTuner`` has been deprecated.
Please use :func:`sagemaker.predictor.Predictor.delete_endpoint` instead.

Pre-instantiated Serializer and Deserializer Objects
----------------------------------------------------

The ``csv_serializer``, ``json_serializer``, ``npy_serializer``, ``csv_deserializer``,
``json_deserializer``, and ``numpy_deserializer`` objects have been deprecated.

Please instantiate the objects instead.

+--------------------------------------------+------------------------------------------------+
| v1.x | v2.0 and later |
+============================================+================================================+
| ``sagemaker.predictor.csv_serializer`` | ``sagemaker.deserializers.CSVSerializer()`` |
+--------------------------------------------+------------------------------------------------+
| ``sagemaker.predictor.json_serializer`` | ``sagemaker.serializers.JSONSerializer()`` |
+--------------------------------------------+------------------------------------------------+
| ``sagemaker.predictor.npy_serializer`` | ``sagemaker.deserializers.NumpySerializer()`` |
+--------------------------------------------+------------------------------------------------+
| ``sagemaker.predictor.csv_deserializer`` | ``sagemaker.deserializers.CSVDeserializer()`` |
+--------------------------------------------+------------------------------------------------+
| ``sagemaker.predictor.json_deserializer`` | ``sagemaker.deserializers.JSONDeserializer()`` |
+--------------------------------------------+------------------------------------------------+
| ``sagemaker.predictor.numpy_deserializer`` | ``sagemaker.serializers.NumpyDeserializer()`` |
+--------------------------------------------+------------------------------------------------+

``update_endpoint`` in ``deploy()``
-----------------------------------

Expand Down Expand Up @@ -152,6 +176,37 @@ The following estimator parameters have been renamed:
| ``train_volume_kms_key`` | ``volume_kms_key`` |
+------------------------------+------------------------+

Serializer and Deserializer Classes
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The follow serializer/deserializer classes have been renamed and/or moved:

+--------------------------------------------------------+-------------------------------------------------------+
| v1.x | v2.0 and later |
+========================================================+=======================================================+
| ``sagemaker.predictor._CsvDeserializer`` | ``sagemaker.deserializers.CSVDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor._CsvSerializer`` | ``sagemaker.serializers.CSVSerializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor.BytesDeserializer`` | ``sagemaker.deserializers.BytesDeserializers`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor.StringDeserializer`` | ``sagemaker.deserializers.StringDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor.StreamDeserializer`` | ``sagemaker.deserializers.StreamDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor._JsonSerializer`` | ``sagemaker.serializers.JSONSerializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor._NumpyDeserializer`` | ``sagemaker.deserializers.NumpyDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor._NPYSerializer`` | ``sagemaker.serializers.NumpySerializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.amazon.common.numpy_to_record_serializer`` | ``sagemaker.amazon.serializers.RecordSerializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.amazon.common.record_deserializer`` | ``sagemaker.amazon.deserializers.RecordDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+
| ``sagemaker.predictor._JsonDeserializer`` | ``sagemaker.deserializers.JSONDeserializer`` |
+--------------------------------------------------------+-------------------------------------------------------+

``distributions``
~~~~~~~~~~~~~~~~~

Expand Down
81 changes: 62 additions & 19 deletions src/sagemaker/cli/compatibility/v2/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,24 @@
modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(),
modifiers.training_params.TrainPrefixRemover(),
modifiers.training_input.TrainingInputConstructorRefactor(),
modifiers.serde.SerdeConstructorRenamer(),
]

IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]

NAME_MODIFIERS = [modifiers.serde.SerdeObjectRenamer()]

MODULE_MODIFIERS = [
modifiers.serde.SerializerImportInserter(),
modifiers.serde.DeserializerImportInserter(),
]

IMPORT_FROM_MODIFIERS = [
modifiers.predictors.PredictorImportFromRenamer(),
modifiers.tfs.TensorFlowServingImportFromRenamer(),
modifiers.training_input.TrainingInputImportFromRenamer(),
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
modifiers.serde.SerdeImportFromPredictorRenamer(),
]


Expand All @@ -52,52 +62,85 @@ class ASTTransformer(ast.NodeTransformer):
"""

def visit_Call(self, node):
"""Visits an ``ast.Call`` node and returns a modified node, if needed.
"""Visits an ``ast.Call`` node and returns a modified node or None.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

Args:
node (ast.Call): a node that represents a function call.

Returns:
ast.Call: a node that represents a function call, which has
potentially been modified from the original input.
ast.AST: if the returned node is None, the original node is removed
from its location. Otherwise, the original node is replaced with the
returned node.
"""
for function_checker in FUNCTION_CALL_MODIFIERS:
function_checker.check_and_modify_node(node)
node = function_checker.check_and_modify_node(node)
return ast.fix_missing_locations(node) if node else None

def visit_Name(self, node):
"""Visits an ``ast.Name`` node and returns a modified node or None.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

ast.fix_missing_locations(node)
return node
Args:
node (ast.Name): a node that represents an identifier.

Returns:
ast.AST: if the returned node is None, the original node is removed
from its location. Otherwise, the original node is replaced with the
returned node.
"""
for name_checker in NAME_MODIFIERS:
node = name_checker.check_and_modify_node(node)
return ast.fix_missing_locations(node) if node else None

def visit_Import(self, node):
"""Visits an ``ast.Import`` node and returns a modified node, if needed.
"""Visits an ``ast.Import`` node and returns a modified node or None.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

Args:
node (ast.Import): a node that represents an import statement.

Returns:
ast.Import: a node that represents an import statement, which has
potentially been modified from the original input.
ast.AST: if the returned node is None, the original node is removed
from its location. Otherwise, the original node is replaced with the
returned node.
"""
for import_checker in IMPORT_MODIFIERS:
import_checker.check_and_modify_node(node)
node = import_checker.check_and_modify_node(node)
return ast.fix_missing_locations(node) if node else None

ast.fix_missing_locations(node)
return node
def visit_Module(self, node):
"""Visits an ``ast.Module`` node and returns a modified node or None.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

The ``ast.NodeTransformer`` walks the abstract syntax tree and modifies
all other nodes before modifying the ``ast.Module`` node.

Args:
node (ast.Module): a node that represents a Python module.

Returns:
ast.AST: if the returned node is None, the original node is removed
from its location. Otherwise, the original node is replaced with the
returned node.
"""
self.generic_visit(node)
for module_checker in MODULE_MODIFIERS:
node = module_checker.check_and_modify_node(node)
return ast.fix_missing_locations(node) if node else None

def visit_ImportFrom(self, node):
"""Visits an ``ast.ImportFrom`` node and returns a modified node, if needed.
"""Visits an ``ast.ImportFrom`` node and returns a modified node or None.
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

Args:
node (ast.ImportFrom): a node that represents an import statement.

Returns:
ast.ImportFrom: a node that represents an import statement, which has
potentially been modified from the original input.
ast.AST: if the returned node is None, the original node is removed
from its location. Otherwise, the original node is replaced with the
returned node.
"""
for import_checker in IMPORT_FROM_MODIFIERS:
import_checker.check_and_modify_node(node)

ast.fix_missing_locations(node)
return node
node = import_checker.check_and_modify_node(node)
return ast.fix_missing_locations(node) if node else None
1 change: 1 addition & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
framework_version,
predictors,
renamed_params,
serde,
tf_legacy_mode,
tfs,
training_params,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/airflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,13 @@ def modify_node(self, node):
Args:
node (ast.Call): a node that represents either a ``model_config`` call or
a ``model_config_from_estimator`` call.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
instance_type = node.args.pop(0)
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))
return node


class ModelConfigImageURIRenamer(renamed_params.ParamRenamer):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def modify_node(self, node):

Args:
node (ast.Call): a node that represents a TensorFlow constructor.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
for kw in node.keywords:
if kw.arg == "script_mode":
node.keywords.remove(kw)
return node
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ def modify_node(self, node):

Args:
node (ast.Call): a node that represents the constructor of a framework class.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
framework, is_model = _framework_from_node(node)

Expand All @@ -109,6 +112,7 @@ def modify_node(self, node):
py_version = _py_version_defaults(framework, framework_version, is_model)
if py_version:
node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))
return node


def _py_version_defaults(framework, framework_version, is_model=False):
Expand Down
5 changes: 3 additions & 2 deletions src/sagemaker/cli/compatibility/v2/modifiers/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ class Modifier(object):
"""

def check_and_modify_node(self, node):
"""Check an AST node, and modify it if applicable."""
"""Check an AST node, and modify, replace, or remove it if applicable."""
if self.node_should_be_modified(node):
self.modify_node(node)
node = self.modify_node(node)
return node

@abstractmethod
def node_should_be_modified(self, node):
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/cli/compatibility/v2/modifiers/predictors.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,13 @@ def modify_node(self, node):

Args:
node (ast.Call): a node that represents a *Predictor constructor.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
_rename_class(node)
_rename_endpoint(node)
return node


def _rename_class(node):
Expand Down Expand Up @@ -106,7 +110,11 @@ def modify_node(self, node):
Args:
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
for name in node.names:
if name.name == BASE_PREDICTOR:
name.name = "Predictor"
return node
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ def modify_node(self, node):

Args:
node (ast.Call): a node that represents the relevant function call.

Returns:
ast.AST: the original node, which has been potentially modified.
"""
keyword = parsing.arg_from_keywords(node, self.old_param_name)
keyword.arg = self.new_param_name
return node


class MethodParamRenamer(ParamRenamer):
Expand Down
Loading