From cbcc62267054d877e4416eaa59f9cfc7eb5173c6 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Tue, 21 Jul 2020 19:50:09 -0500 Subject: [PATCH 01/11] Add SerDe compatability --- .../cli/compatibility/v2/ast_transformer.py | 82 +++- .../compatibility/v2/modifiers/__init__.py | 1 + .../cli/compatibility/v2/modifiers/airflow.py | 4 + .../v2/modifiers/deprecated_params.py | 4 + .../v2/modifiers/framework_version.py | 4 + .../compatibility/v2/modifiers/modifier.py | 5 +- .../compatibility/v2/modifiers/predictors.py | 8 + .../v2/modifiers/renamed_params.py | 4 + .../cli/compatibility/v2/modifiers/serde.py | 413 ++++++++++++++++++ .../v2/modifiers/tf_legacy_mode.py | 8 + .../cli/compatibility/v2/modifiers/tfs.py | 12 + .../v2/modifiers/training_input.py | 5 + .../v2/modifiers/training_params.py | 4 + .../compatibility/v2/modifiers/test_serde.py | 347 +++++++++++++++ 14 files changed, 881 insertions(+), 20 deletions(-) create mode 100644 src/sagemaker/cli/compatibility/v2/modifiers/serde.py create mode 100644 tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 6e172f9e32..3e12f21a9a 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -35,14 +35,24 @@ modifiers.renamed_params.SessionCreateEndpointImageURIRenamer(), modifiers.training_params.TrainPrefixRemover(), modifiers.training_input.TrainingInputConstructorRefactor(), + modifiers.serde.SerdeConstructorRenamer(), ] IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()] +NAME_MODIFIERS = [modifiers.serde.SerdeObjectRenamer()] + +MODULE_MODIFIERS = [ + modifiers.serde.SerializerImportInserter(), + modifiers.serde.DeserializerImportInserter(), +] + IMPORT_FROM_MODIFIERS = [ modifiers.predictors.PredictorImportFromRenamer(), modifiers.tfs.TensorFlowServingImportFromRenamer(), 
modifiers.training_input.TrainingInputImportFromRenamer(), + modifiers.serde.SerdeImportFromAmazonCommonRenamer(), + modifiers.serde.SerdeImportFromPredictorRenamer(), ] @@ -52,52 +62,88 @@ class ASTTransformer(ast.NodeTransformer): """ def visit_Call(self, node): - """Visits an ``ast.Call`` node and returns a modified node, if needed. + """Visits an ``ast.Call`` node and returns a modified node or None. See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: node (ast.Call): a node that represents a function call. Returns: - ast.Call: a node that represents a function call, which has - potentially been modified from the original input. + ast.AST: if the returned node is None, the original node is removed + from its location. Otherwise, the original node is replaced with the + returned node. """ for function_checker in FUNCTION_CALL_MODIFIERS: - function_checker.check_and_modify_node(node) + node = function_checker.check_and_modify_node(node) + return ast.fix_missing_locations(node) if node else None + + def visit_Name(self, node): + """Visits an ``ast.Name`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. - ast.fix_missing_locations(node) + Args: + node (ast.Name): a node that represents an identifier. + + Returns: + ast.AST: if the returned node is None, the original node is removed + from its location. Otherwise, the original node is replaced with the + returned node. + """ + for name_checker in NAME_MODIFIERS: + node = name_checker.check_and_modify_node(node) + if node is None: + return None + node = ast.fix_missing_locations(node) return node def visit_Import(self, node): - """Visits an ``ast.Import`` node and returns a modified node, if needed. + """Visits an ``ast.Import`` node and returns a modified node or None. See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: node (ast.Import): a node that represents an import statement. 
Returns: - ast.Import: a node that represents an import statement, which has - potentially been modified from the original input. + ast.AST: if the returned node is None, the original node is removed + from its location. Otherwise, the original node is replaced with the + returned node. """ for import_checker in IMPORT_MODIFIERS: - import_checker.check_and_modify_node(node) + node = import_checker.check_and_modify_node(node) + return ast.fix_missing_locations(node) if node else None - ast.fix_missing_locations(node) - return node + def visit_Module(self, node): + """Visits an ``ast.Module`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. + + The ``ast.NodeTransformer`` walks the abstract syntax tree and modifies + all other nodes before modifying the ``ast.Module`` node. + + Args: + node (ast.Module): a node that represents a Python module. + + Returns: + ast.AST: if the returned node is None, the original node is removed + from its location. Otherwise, the original node is replaced with the + returned node. + """ + self.generic_visit(node) + for module_checker in MODULE_MODIFIERS: + node = module_checker.check_and_modify_node(node) + return ast.fix_missing_locations(node) if node else None def visit_ImportFrom(self, node): - """Visits an ``ast.ImportFrom`` node and returns a modified node, if needed. + """Visits an ``ast.ImportFrom`` node and returns a modified node or None. See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: node (ast.ImportFrom): a node that represents an import statement. Returns: - ast.ImportFrom: a node that represents an import statement, which has - potentially been modified from the original input. + ast.AST: if the returned node is None, the original node is removed + from its location. Otherwise, the original node is replaced with the + returned node. 
""" for import_checker in IMPORT_FROM_MODIFIERS: - import_checker.check_and_modify_node(node) - - ast.fix_missing_locations(node) - return node + node = import_checker.check_and_modify_node(node) + return ast.fix_missing_locations(node) if node else None diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py index 3926fec6f0..f6e1ead061 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/__init__.py @@ -19,6 +19,7 @@ framework_version, predictors, renamed_params, + serde, tf_legacy_mode, tfs, training_params, diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py index f69f519468..13b06b9230 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/airflow.py @@ -58,9 +58,13 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents either a ``model_config`` call or a ``model_config_from_estimator`` call. + + Returns: + ast.AST: the original node, which has been potentially modified. """ instance_type = node.args.pop(0) node.keywords.append(ast.keyword(arg="instance_type", value=instance_type)) + return node class ModelConfigImageURIRenamer(renamed_params.ParamRenamer): diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py index 662f9d1e80..f209b7a45c 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py @@ -54,7 +54,11 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents a TensorFlow constructor. + + Returns: + ast.AST: the original node, which has been potentially modified. 
""" for kw in node.keywords: if kw.arg == "script_mode": node.keywords.remove(kw) + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py index 8689c21d48..f1da388361 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py @@ -94,6 +94,9 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents the constructor of a framework class. + + Returns: + ast.AST: the original node, which has been potentially modified. """ framework, is_model = _framework_from_node(node) @@ -109,6 +112,7 @@ def modify_node(self, node): py_version = _py_version_defaults(framework, framework_version, is_model) if py_version: node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version))) + return node def _py_version_defaults(framework, framework_version, is_model=False): diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/modifier.py b/src/sagemaker/cli/compatibility/v2/modifiers/modifier.py index c1d53dfc85..3b5d47a412 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/modifier.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/modifier.py @@ -22,9 +22,10 @@ class Modifier(object): """ def check_and_modify_node(self, node): - """Check an AST node, and modify it if applicable.""" + """Check an AST node, and modify, replace, or remove it if applicable.""" if self.node_should_be_modified(node): - self.modify_node(node) + node = self.modify_node(node) + return node @abstractmethod def node_should_be_modified(self, node): diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py b/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py index 663a9fe394..90ade836a5 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/predictors.py @@ -62,9 +62,13 @@ def 
modify_node(self, node): Args: node (ast.Call): a node that represents a *Predictor constructor. + + Returns: + ast.AST: the original node, which has been potentially modified. """ _rename_class(node) _rename_endpoint(node) + return node def _rename_class(node): @@ -106,7 +110,11 @@ def modify_node(self, node): Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.AST: the original node, which has been potentially modified. """ for name in node.names: if name.name == BASE_PREDICTOR: name.name = "Predictor" + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py index 7570740d9f..7a19bc6597 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py @@ -61,9 +61,13 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents the relevant function call. + + Returns: + ast.AST: the original node, which has been potentially modified. """ keyword = parsing.arg_from_keywords(node, self.old_param_name) keyword.arg = self.new_param_name + return node class MethodParamRenamer(ParamRenamer): diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py new file mode 100644 index 0000000000..94b25153ec --- /dev/null +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -0,0 +1,413 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Classes to modify serializer and deserializer code to be compatible with +version 2.0 and later of the SageMaker Python SDK. +""" +from __future__ import absolute_import + +import ast + +import pasta + +from sagemaker.cli.compatibility.v2.modifiers import matching +from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier + +# The values are tuples so that the object can be passed to matching.matches_any. +OLD_CLASS_NAME_TO_NAMESPACES = { + "_CsvSerializer": ("sagemaker.predictor",), + "_JsonSerializer": ("sagemaker.predictor",), + "_NpySerializer": ("sagemaker.predictor",), + "_CsvDeserializer": ("sagemaker.predictor",), + "BytesDeserializer": ("sagemaker.predictor",), + "StringDeserializer": ("sagemaker.predictor",), + "StreamDeserializer": ("sagemaker.predictor",), + "_NumpyDeserializer": ("sagemaker.predictor",), + "_JsonDeserializer": ("sagemaker.predictor",), + "numpy_to_record_serializer": ("sagemaker.amazon.common",), + "record_deserializer": ("sagemaker.amazon.common",), +} + +# The values are tuples so that the object can be passed to matching.matches_any. 
+NEW_CLASS_NAME_TO_NAMESPACES = { + "CSVSerializer": ("sagemaker.serializers",), + "JSONSerializer": ("sagemaker.serializers",), + "NumpySerializer": ("sagemaker.serializers",), + "CSVDeserializer": ("sagemaker.deserializers",), + "BytesDeserializer": ("sagemaker.deserializers",), + "StringDeserializer": ("sagemaker.deserializers",), + "StreamDeserializer": ("sagemaker.deserializers",), + "NumpyDeserializer": ("sagemaker.deserializers",), + "JSONDeserializer": ("sagemaker.deserializers",), + "RecordSerializer": ("sagemaker.amazon.common",), + "RecordDeserializer": ("sagemaker.amazon.common",), +} + +OLD_CLASS_NAME_TO_NEW_CLASS_NAME = { + "_CsvSerializer": "CSVSerializer", + "_JsonSerializer": "JSONSerializer", + "_NpySerializer": "NumpySerializer", + "_CsvDeserializer": "CSVDeserializer", + "BytesDeserializer": "BytesDeserializer", + "StringDeserializer": "StringDeserializer", + "StreamDeserializer": "StreamDeserializer", + "_NumpyDeserializer": "NumpyDeserializer", + "_JsonDeserializer": "JSONDeserializer", + "numpy_to_record_serializer": "RecordSerializer", + "record_deserializer": "RecordDeserializer", +} + +OLD_OBJECT_NAME_TO_NEW_CLASS_NAME = { + "csv_serializer": "CSVSerializer", + "json_serializer": "JSONSerializer", + "npy_serializer": "NumpySerializer", + "csv_deserializer": "CSVDeserializer", + "json_deserializer": "JSONDeserializer", + "numpy_deserializer": "NumpyDeserializer", +} + +OLD_AMAZON_CLASS_NAMES = set( + { + class_name + for class_name, namespaces in OLD_CLASS_NAME_TO_NAMESPACES.items() + if "sagemaker.amazon.common" in namespaces + } +) +NEW_AMAZON_CLASS_NAMES = set( + { + class_name + for class_name, namespaces in NEW_CLASS_NAME_TO_NAMESPACES.items() + if "sagemaker.amazon.common" in namespaces + } +) + +NEW_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.values()) +OLD_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.keys()) + +OLD_OBJECT_NAMES = set(OLD_OBJECT_NAME_TO_NEW_CLASS_NAME.keys()) + + +class
SerdeConstructorRenamer(Modifier): + """A class to rename SerDe classes.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Call`` node instantiates a SerDe class. + + This looks for the following calls: + + - ``sagemaker.predictor._CsvSerializer`` + - ``sagemaker.predictor._JsonSerializer`` + - ``sagemaker.predictor._NpySerializer`` + - ``sagemaker.predictor._CsvDeserializer`` + - ``sagemaker.predictor.BytesDeserializer`` + - ``sagemaker.predictor.StringDeserializer`` + - ``sagemaker.predictor.StreamDeserializer`` + - ``sagemaker.predictor._NumpyDeserializer`` + - ``sagemaker.predictor._JsonDeserializer`` + - ``sagemaker.amazon.common.numpy_to_record_serializer`` + - ``sagemaker.amazon.common.record_deserializer`` + + Args: + node (ast.Call): a node that represents a function call. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Call`` instantiates a SerDe class. + """ + return matching.matches_any(node, OLD_CLASS_NAME_TO_NAMESPACES) + + def modify_node(self, node): + """Modifies the ``ast.Call`` node to use the classes for SerDe + available in version 2.0 and later of the Python SDK: + + - ``sagemaker.serializers.CSVSerializer`` + - ``sagemaker.serializers.JSONSerializer`` + - ``sagemaker.serializers.NumpySerializer`` + - ``sagemaker.deserializers.CSVDeserializer`` + - ``sagemaker.deserializers.BytesDeserializer`` + - ``sagemaker.deserializers.StringDeserializer`` + - ``sagemaker.deserializers.StreamDeserializer`` + - ``sagemaker.deserializers.NumpyDeserializer`` + - ``sagemaker.deserializers.JSONDeserializer`` + - ``sagemaker.amazon.common.RecordSerializer`` + - ``sagemaker.amazon.common.RecordDeserializer`` + + Args: + node (ast.Call): a node that represents a SerDe constructor. + + Returns: + ast.Call: a node that represents the instantiation of a SerDe object.
+ """ + class_name = node.func.id if isinstance(node.func, ast.Name) else node.func.attr + new_class_name = OLD_CLASS_NAME_TO_NEW_CLASS_NAME[class_name] + + # We don't change the namespace for Amazon SerDe. + if class_name in OLD_AMAZON_CLASS_NAMES: + if isinstance(node.func, ast.Name): + node.func.id = new_class_name + elif isinstance(node.func, ast.Attribute): + node.func.attr = new_class_name + return node + + namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] + subpackage_name = namespace_name[namespace_name.find(".") + 1 :] + assert subpackage_name in {"serializers", "deserializers"} + return pasta.parse("%s.%s()" % (subpackage_name, new_class_name)).body[0].value + + +class SerdeObjectRenamer(Modifier): + """A class to rename SerDe objects imported from ``sagemaker.predictor``.""" + + def node_should_be_modified(self, node): + """Checks if the ``ast.Name`` node identifies a SerDe object. + + This looks for the following objects: + + - ``sagemaker.predictor.csv_serializer`` + - ``sagemaker.predictor.json_serializer`` + - ``sagemaker.predictor.npy_serializer`` + - ``sagemaker.predictor.csv_deserializer`` + - ``sagemaker.predictor.json_deserializer`` + - ``sagemaker.predictor.numpy_deserializer`` + + Args: + node (ast.Name): a node that represents an identifier. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Name`` identifies a SerDe object.
+ """ + name = node.id if isinstance(node, ast.Name) else node.attr + return name in OLD_OBJECT_NAMES + + def modify_node(self, node): + """Replaces the ``ast.Name`` node with a ``ast.Call`` node that + instantiates a class available in version 2.0 and later of the Python SDK: + + - ``sagemaker.serializers.CSVSerializer()`` + - ``sagemaker.serializers.JSONSerializer()`` + - ``sagemaker.serializers.NumpySerializer()`` + - ``sagemaker.deserializers.CSVDeserializer()`` + - ``sagemaker.deserializers.JSONDeserializer()`` + - ``sagemaker.deserializers.NumpyDeserializer()`` + + The ``sagemaker`` prefix is excluded from the returned node. + + Args: + node (ast.Name): a node that represents a Python identifier. + + Returns: + ast.Call: a node that represents the instantiation of a SerDe object. + """ + object_name = node.id if isinstance(node, ast.Name) else node.attr + new_class_name = OLD_OBJECT_NAME_TO_NEW_CLASS_NAME[object_name] + namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] + subpackage_name = namespace_name[namespace_name.find(".") + 1 :] + assert subpackage_name in {"serializers", "deserializers"} + return pasta.parse("%s.%s()" % (subpackage_name, new_class_name)).body[0].value + + +class SerdeImportFromPredictorRenamer(Modifier): + """A class to update import statements starting with ``from sagemaker.predictor``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports a SerDe from the + ``sagemaker.predictor`` module. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: True if and only if the ``ast.ImportFrom`` imports a SerDe + from the ``sagemaker.predictor`` module. 
+ """ + return node.module == "sagemaker.predictor" and any( + [name.name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES) for name in node.names] + ) + + def modify_node(self, node): + """Removes the imported SerDe classes, as applicable. + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.ImportFrom: a node that represents a import statement, which has + been modified to remove imported serializers. If nothing is + imported, None is returned. + """ + i = 0 + while i < len(node.names): + name = node.names[i].name + if name in OLD_CLASS_NAMES | OLD_OBJECT_NAMES: + node.names.pop(i) + else: + i += 1 + + if not node.names: + return None + + return node + + +class SerdeImportFromAmazonCommonRenamer(Modifier): + """A class to update import statements starting with ``from sagemaker.amazon.common``.""" + + def node_should_be_modified(self, node): + """Checks if the import statement imports a SerDe from the + ``sagemaker.amazon.common`` module. + + This checks for: + - ``sagemaker.amazon.common.numpy_to_record_serializer`` + - ``sagemaker.amazon.common.record_deserializer`` + + Args: + node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: True if and only if the ``ast.ImportFrom`` imports a SerDe from + the ``sagemaker.amazon.common`` module. + """ + return node.module == "sagemaker.amazon.common" and any( + [alias.name in OLD_AMAZON_CLASS_NAMES for alias in node.names] + ) + + def modify_node(self, node): + """Upgrades the ``numpy_to_record_serializer`` and ``record_deserializer`` + imports, as applicable. + + This upgrades the classes to: + - ``sagemaker.amazon.common.RecordSerializer`` + - ``sagemaker.amazon.common.RecordDeserializer`` + + Args: + node (ast.ImportFrom): a node that represents a ``from ... 
import ... `` statement. + For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.ImportFrom: a node that represents an import statement, which has + been modified to import the upgraded class name. + """ + for alias in node.names: + if alias.name in OLD_AMAZON_CLASS_NAMES: + alias.name = OLD_CLASS_NAME_TO_NEW_CLASS_NAME[alias.name] + return node + + +class _ImportInserter(Modifier): + """A class to insert import statements into the Python module.""" + + def __init__(self, class_names, import_node): + """Initialize the ``class_names`` and ``import_node`` attributes. + + Args: + class_names (set): If any of these class names are referenced in the + module, then ``import_node`` is inserted. + import_node (ast.ImportFrom): The AST node to insert. + """ + self.class_names = class_names + self.import_node = import_node + + def node_should_be_modified(self, module): + """Checks if the ``ast.Module`` node contains references to the + specified class names. + + Args: + module (ast.Module): a node that represents a Python module. For more, + see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + bool: If the ``ast.Module`` references one of the specified classes. + """ + for node in ast.walk(module): + if isinstance(node, ast.Name) and node.id in self.class_names: + return True + if isinstance(node, ast.Attribute) and node.attr in self.class_names: + return True + return False + + def modify_node(self, module): + """Modifies the ``ast.Module`` node by inserting the specified node. + + The ``import_node`` object is inserted immediately before the first + import statement, or at the top of the module if it has none. + + Args: + module (ast.Module): a node that represents a Python module. + + Returns: + ast.Module: a node that represents a Python module, which has been + modified to import a module.
+ """ + for i, node in enumerate(module.body): + if isinstance(node, (ast.Import, ast.ImportFrom)): + module.body.insert(i, self.import_node) + return module + + module.body.insert(0, self.import_node) + return module + + +class SerializerImportInserter(_ImportInserter): + """A class to import the ``sagemaker.serializers`` module, if necessary. + + This looks for references to the following classes: + + - ``sagemaker.serializers.CSVSerializer`` + - ``sagemaker.serializers.JSONSerializer`` + - ``sagemaker.serializers.NumpySerializer`` + + Because ``SerializerImportInserter`` is guaranteed to run after + ``SerdeConstructorRenamer`` and ``SerdeObjectRenamer``, + we only need to check for the new serializer class names. + """ + + def __init__(self): + # Amazon SerDe are not defined in the sagemaker.serializers module. + class_names = { + class_name + for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES + if "Serializer" in class_name + } + import_node = pasta.parse("from sagemaker import serializers\n").body[0] + super().__init__(class_names, import_node) + + +class DeserializerImportInserter(_ImportInserter): + """A class to import the ``sagemaker.deserializers`` module, if necessary. + + This looks for references to the following classes: + + - ``sagemaker.deserializers.CSVDeserializer`` + - ``sagemaker.deserializers.BytesDeserializer`` + - ``sagemaker.deserializers.StringDeserializer`` + - ``sagemaker.deserializers.StreamDeserializer`` + - ``sagemaker.deserializers.NumpyDeserializer`` + - ``sagemaker.deserializers.JSONDeserializer`` + + Because ``DeserializerImportInserter`` is guaranteed to run after + ``SerdeConstructorRenamer`` and ``SerdeObjectRenamer``, + we only need to check for the new deserializer class names. + """ + + def __init__(self): + # Amazon SerDe are not defined in the sagemaker.deserializers module.
+ class_names = { + class_name + for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES + if "Deserializer" in class_name + } + import_node = pasta.parse("from sagemaker import deserializers\n").body[0] + super().__init__(class_names, import_node) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py index 69a9268768..2852d0ecdc 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py @@ -100,6 +100,9 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents a TensorFlow constructor. + + Returns: + ast.AST: the original node, which has been potentially modified. """ base_hps = {} additional_hps = {} @@ -130,6 +133,7 @@ def modify_node(self, node): node.keywords.append(ast.keyword(arg="image_uri", value=ast.Str(s=image_uri))) node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False))) + return node def _hyperparameter_key_for_param(self, arg): """Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter.""" @@ -210,7 +214,11 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents ``fit`` being called with ``run_tensorboard_locally`` set. + + Returns: + ast.AST: the original node, which has been potentially modified. """ for kw in node.keywords: if kw.arg == "run_tensorboard_locally": node.keywords.remove(kw) + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py b/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py index 2ecb18957b..5c80fcb2d4 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/tfs.py @@ -61,12 +61,16 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents a TensorFlow Serving constructor. + + Returns: + ast.AST: the original node, which has been potentially modified. 
""" if isinstance(node.func, ast.Name): node.func.id = self._new_cls_name(node.func.id) else: node.func.attr = self._new_cls_name(node.func.attr) node.func.value = node.func.value.value + return node def _new_cls_name(self, cls_name): """Returns the updated class name.""" @@ -95,11 +99,15 @@ def modify_node(self, node): Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.AST: the original node, which has been potentially modified. """ node.module = "sagemaker.tensorflow" for cls in node.names: cls.name = "TensorFlow{}".format(cls.name) + return node class TensorFlowServingImportRenamer(Modifier): @@ -112,7 +120,11 @@ def check_and_modify_node(self, node): Args: node (ast.Import): a node that represents an import statement. For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.AST: the original node, which has been potentially modified. """ for module in node.names: if module.name == "sagemaker.tensorflow.serving": module.name = "sagemaker.tensorflow" + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py b/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py index edf5e37c63..171d52f570 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/training_input.py @@ -56,6 +56,7 @@ def modify_node(self, node): elif matching.matches_attr(node, S3_INPUT_NAME): node.func.attr = "TrainingInput" _rename_namespace(node, "session") + return node def _rename_namespace(node, name): @@ -89,9 +90,13 @@ def modify_node(self, node): Args: node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement. For more, see https://docs.python.org/3/library/ast.html#abstract-grammar. + + Returns: + ast.AST: the original node, which has been potentially modified. 
""" for name in node.names: if name.name == S3_INPUT_NAME: name.name = "TrainingInput" if node.module == "sagemaker.session": node.module = "sagemaker.inputs" + return node diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py b/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py index 3368e694d4..48ddf14201 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/training_params.py @@ -91,7 +91,11 @@ def modify_node(self, node): Args: node (ast.Call): a node that represents an estimator constructor. + + Returns: + ast.AST: the original node, which has been potentially modified. """ for kw in node.keywords: if kw.arg in PARAMS: kw.arg = kw.arg.replace("train_", "") + return node diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py new file mode 100644 index 0000000000..d48a4dfc11 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -0,0 +1,347 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. 
+from __future__ import absolute_import + +import ast + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import serde +from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import + + +@pytest.mark.parametrize( + "src, expected", + [ + ("sagemaker.predictor._CsvSerializer()", True), + ("sagemaker.predictor._JsonSerializer()", True), + ("sagemaker.predictor._NpySerializer()", True), + ("sagemaker.predictor._CsvDeserializer()", True), + ("sagemaker.predictor.BytesDeserializer()", True), + ("sagemaker.predictor.StringDeserializer()", True), + ("sagemaker.predictor.StreamDeserializer()", True), + ("sagemaker.predictor._NumpyDeserializer()", True), + ("sagemaker.predictor._JsonDeserializer()", True), + ("sagemaker.amazon.common.numpy_to_record_serializer()", True), + ("sagemaker.amazon.common.record_deserializer()", True), + ("_CsvSerializer()", True), + ("_JsonSerializer()", True), + ("_NpySerializer()", True), + ("_CsvDeserializer()", True), + ("BytesDeserializer()", True), + ("StringDeserializer()", True), + ("StreamDeserializer()", True), + ("_NumpyDeserializer()", True), + ("_JsonDeserializer()", True), + ("numpy_to_record_serializer()", True), + ("record_deserializer()", True), + ("OtherClass()", False), + ], +) +def test_constructor_node_should_be_modified(src, expected): + modifier = serde.SerdeConstructorRenamer() + node = ast_call(src) + assert modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "src, expected", + [ + ("sagemaker.predictor._CsvSerializer()", "serializers.CSVSerializer()"), + ("sagemaker.predictor._JsonSerializer()", "serializers.JSONSerializer()"), + ("sagemaker.predictor._NpySerializer()", "serializers.NumpySerializer()"), + ("sagemaker.predictor._CsvDeserializer()", "deserializers.CSVDeserializer()"), + ("sagemaker.predictor.BytesDeserializer()", "deserializers.BytesDeserializer()"), + ("sagemaker.predictor.StringDeserializer()", 
"deserializers.StringDeserializer()",), + ("sagemaker.predictor.StreamDeserializer()", "deserializers.StreamDeserializer()",), + ("sagemaker.predictor._NumpyDeserializer()", "deserializers.NumpyDeserializer()"), + ("sagemaker.predictor._JsonDeserializer()", "deserializers.JSONDeserializer()"), + ( + "sagemaker.amazon.common.numpy_to_record_serializer()", + "sagemaker.amazon.common.RecordSerializer()", + ), + ( + "sagemaker.amazon.common.record_deserializer()", + "sagemaker.amazon.common.RecordDeserializer()", + ), + ("_CsvSerializer()", "serializers.CSVSerializer()"), + ("_JsonSerializer()", "serializers.JSONSerializer()"), + ("_NpySerializer()", "serializers.NumpySerializer()"), + ("_CsvDeserializer()", "deserializers.CSVDeserializer()"), + ("BytesDeserializer()", "deserializers.BytesDeserializer()"), + ("StringDeserializer()", "deserializers.StringDeserializer()"), + ("StreamDeserializer()", "deserializers.StreamDeserializer()"), + ("_NumpyDeserializer()", "deserializers.NumpyDeserializer()"), + ("_JsonDeserializer()", "deserializers.JSONDeserializer()"), + ("numpy_to_record_serializer()", "RecordSerializer()"), + ("record_deserializer()", "RecordDeserializer()"), + ], +) +def test_constructor_modify_node(src, expected): + modifier = serde.SerdeConstructorRenamer() + node = ast_call(src) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node) + assert isinstance(modified_node, ast.Call) + + +@pytest.mark.parametrize( + "src, expected", + [ + ("sagemaker.predictor.csv_serializer", True,), + ("sagemaker.predictor.json_serializer", True,), + ("sagemaker.predictor.npy_serializer", True,), + ("sagemaker.predictor.csv_deserializer", True,), + ("sagemaker.predictor.json_deserializer", True,), + ("sagemaker.predictor.numpy_deserializer", True,), + ("csv_serializer", True,), + ("json_serializer", True,), + ("npy_serializer", True,), + ("csv_deserializer", True,), + ("json_deserializer", True,), + ("numpy_deserializer", True,), + ], +) 
+def test_name_node_should_be_modified(src, expected): + modifier = serde.SerdeObjectRenamer() + node = ast_call(src) + assert modifier.node_should_be_modified(node) is True + + +@pytest.mark.parametrize( + "src, expected", + [ + ("sagemaker.predictor.csv_serializer", "serializers.CSVSerializer()"), + ("sagemaker.predictor.json_serializer", "serializers.JSONSerializer()"), + ("sagemaker.predictor.npy_serializer", "serializers.NumpySerializer()"), + ("sagemaker.predictor.csv_deserializer", "deserializers.CSVDeserializer()"), + ("sagemaker.predictor.json_deserializer", "deserializers.JSONDeserializer()"), + ("sagemaker.predictor.numpy_deserializer", "deserializers.NumpyDeserializer()"), + ("csv_serializer", "serializers.CSVSerializer()"), + ("json_serializer", "serializers.JSONSerializer()"), + ("npy_serializer", "serializers.NumpySerializer()"), + ("csv_deserializer", "deserializers.CSVDeserializer()"), + ("json_deserializer", "deserializers.JSONDeserializer()"), + ("numpy_deserializer", "deserializers.NumpyDeserializer()"), + ], +) +def test_name_modify_node(src, expected): + modifier = serde.SerdeObjectRenamer() + node = ast_call(src) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node) + assert isinstance(modified_node, ast.Call) + + +@pytest.mark.parametrize( + "src, expected", + [ + ("from sagemaker.predictor import _CsvSerializer", True), + ("from sagemaker.predictor import _JsonSerializer", True), + ("from sagemaker.predictor import _NpySerializer", True), + ("from sagemaker.predictor import _CsvDeserializer", True), + ("from sagemaker.predictor import BytesDeserializer", True), + ("from sagemaker.predictor import StringDeserializer", True), + ("from sagemaker.predictor import StreamDeserializer", True), + ("from sagemaker.predictor import _NumpyDeserializer", True), + ("from sagemaker.predictor import _JsonDeserializer", True), + ("from sagemaker.predictor import csv_serializer", True), + ("from sagemaker.predictor 
import json_serializer", True), + ("from sagemaker.predictor import npy_serializer", True), + ("from sagemaker.predictor import csv_deserializer", True), + ("from sagemaker.predictor import json_deserializer", True), + ("from sagemaker.predictor import numpy_deserializer", True), + ("from sagemaker.predictor import RealTimePredictor, _CsvSerializer", True), + ("from sagemaker.predictor import RealTimePredictor", False), + ("from sagemaker.amazon.common import numpy_to_record_serializer", False), + ], +) +def test_import_from_predictor_node_should_be_modified(src, expected): + modifier = serde.SerdeImportFromPredictorRenamer() + node = ast_import(src) + assert modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "src, expected", + [ + ("from sagemaker.predictor import _CsvSerializer", None), + ("from sagemaker.predictor import _JsonSerializer", None), + ("from sagemaker.predictor import _NpySerializer", None), + ("from sagemaker.predictor import _CsvDeserializer", None), + ("from sagemaker.predictor import BytesDeserializer", None), + ("from sagemaker.predictor import StringDeserializer", None), + ("from sagemaker.predictor import StreamDeserializer", None), + ("from sagemaker.predictor import _NumpyDeserializer", None), + ("from sagemaker.predictor import _JsonDeserializer", None), + ("from sagemaker.predictor import csv_serializer", None), + ("from sagemaker.predictor import json_serializer", None), + ("from sagemaker.predictor import npy_serializer", None), + ("from sagemaker.predictor import csv_deserializer", None), + ("from sagemaker.predictor import json_deserializer", None), + ("from sagemaker.predictor import numpy_deserializer", None), + ( + "from sagemaker.predictor import RealTimePredictor, _NpySerializer", + "from sagemaker.predictor import RealTimePredictor", + ), + ], +) +def test_import_from_predictor_modify_node(src, expected): + modifier = serde.SerdeImportFromPredictorRenamer() + node = ast_import(src) + 
modified_node = modifier.modify_node(node) + assert expected == (modified_node and pasta.dump(modified_node)) + + +@pytest.mark.parametrize( + "import_statement, expected", + [ + ("from sagemaker.amazon.common import numpy_to_record_serializer", True), + ("from sagemaker.amazon.common import record_deserializer", True), + ("from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor", False), + ], +) +def test_import_from_amazon_common_node_should_be_modified(import_statement, expected): + modifier = serde.SerdeImportFromAmazonCommonRenamer() + node = ast_import(import_statement) + assert modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "import_statement, expected", + [ + ( + "from sagemaker.amazon.common import numpy_to_record_serializer", + "from sagemaker.amazon.common import RecordSerializer", + ), + ( + "from sagemaker.amazon.common import record_deserializer", + "from sagemaker.amazon.common import RecordDeserializer", + ), + ( + "from sagemaker.amazon.common import numpy_to_record_serializer, record_deserializer", + "from sagemaker.amazon.common import RecordSerializer, RecordDeserializer", + ), + ( + "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, numpy_to_record_serializer", + "from sagemaker.amazon.common import write_spmatrix_to_sparse_tensor, RecordSerializer", + ), + ], +) +def test_import_from_amazon_common_modify_node(import_statement, expected): + modifier = serde.SerdeImportFromAmazonCommonRenamer() + node = ast_import(import_statement) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node) + + +@pytest.mark.parametrize( + "src, expected", + [ + ("serializers.CSVSerializer()", True), + ("serializers.JSONSerializer()", True), + ("serializers.NumpySerializer()", True), + ("pass", False), + ], +) +def test_serializer_module_node_should_be_modified(src, expected): + modifier = serde.SerializerImportInserter() + node = pasta.parse(src) + assert 
modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "src, expected", + [ + ( + "serializers.CSVSerializer()", + "from sagemaker import serializers\nserializers.CSVSerializer()", + ), + ( + "serializers.JSONSerializer()", + "from sagemaker import serializers\nserializers.JSONSerializer()", + ), + ( + "serializers.NumpySerializer()", + "from sagemaker import serializers\nserializers.NumpySerializer()", + ), + ( + "pass\nimport random\nserializers.CSVSerializer()", + "pass\nfrom sagemaker import serializers\nimport random\nserializers.CSVSerializer()", + ), + ], +) +def test_serializer_module_modify_node(src, expected): + modifier = serde.SerializerImportInserter() + node = pasta.parse(src) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node) + + +@pytest.mark.parametrize( + "src, expected", + [ + ("deserializers.CSVDeserializer()", True), + ("deserializers.BytesDeserializer()", True), + ("deserializers.StringDeserializer()", True), + ("deserializers.StreamDeserializer()", True), + ("deserializers.NumpyDeserializer()", True), + ("deserializers.JSONDeserializer()", True), + ("pass", False), + ], +) +def test_deserializer_module_node_should_be_modified(src, expected): + modifier = serde.DeserializerImportInserter() + node = pasta.parse(src) + assert modifier.node_should_be_modified(node) is expected + + +@pytest.mark.parametrize( + "src, expected", + [ + ( + "deserializers.CSVDeserializer()", + "from sagemaker import deserializers\ndeserializers.CSVDeserializer()", + ), + ( + "deserializers.BytesDeserializer()", + "from sagemaker import deserializers\ndeserializers.BytesDeserializer()", + ), + ( + "deserializers.StringDeserializer()", + "from sagemaker import deserializers\ndeserializers.StringDeserializer()", + ), + ( + "deserializers.StreamDeserializer()", + "from sagemaker import deserializers\ndeserializers.StreamDeserializer()", + ), + ( + "deserializers.NumpyDeserializer()", + "from sagemaker 
import deserializers\ndeserializers.NumpyDeserializer()", + ), + ( + "deserializers.JSONDeserializer()", + "from sagemaker import deserializers\ndeserializers.JSONDeserializer()", + ), + ( + "pass\nimport random\ndeserializers.CSVDeserializer()", + "pass\nfrom sagemaker import deserializers\nimport random\ndeserializers.CSVDeserializer()", + ), + ], +) +def test_deserializer_module_modify_node(src, expected): + modifier = serde.DeserializerImportInserter() + node = pasta.parse(src) + modified_node = modifier.modify_node(node) + assert expected == pasta.dump(modified_node) From 2009c6a00ce6a058c63469db9fffe250660a5528 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Tue, 21 Jul 2020 19:58:12 -0500 Subject: [PATCH 02/11] Fix typo --- src/sagemaker/cli/compatibility/v2/modifiers/serde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 94b25153ec..2c9358a5ee 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -310,7 +310,7 @@ class _ImportInserter(Modifier): """A class to insert import statements into the Python module.""" def __init__(self, class_names, import_node): - """Initialize the ``class_names`` and ``import_node attributes``. + """Initialize the ``class_names`` and ``import_node`` attributes. 
Args: class_names (set): If any of these class names are referenced in the From b8d6c07781c6f180664bc048c003dae77f361bf3 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Tue, 21 Jul 2020 19:58:57 -0500 Subject: [PATCH 03/11] Add test arguments --- .../unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py index d48a4dfc11..11da4fde6d 100644 --- a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_serde.py @@ -33,6 +33,7 @@ ("sagemaker.predictor.StreamDeserializer()", True), ("sagemaker.predictor._NumpyDeserializer()", True), ("sagemaker.predictor._JsonDeserializer()", True), + ("sagemaker.predictor.OtherClass()", False), ("sagemaker.amazon.common.numpy_to_record_serializer()", True), ("sagemaker.amazon.common.record_deserializer()", True), ("_CsvSerializer()", True), From df021c9c129825fab689ee0fdd86330ba8188c9e Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 22 Jul 2020 11:34:31 -0500 Subject: [PATCH 04/11] Fix style issue --- src/sagemaker/cli/compatibility/v2/ast_transformer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 3e12f21a9a..098dd1f0b1 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -91,10 +91,7 @@ def visit_Name(self, node): """ for name_checker in NAME_MODIFIERS: node = name_checker.check_and_modify_node(node) - if node is None: - return None - node = ast.fix_missing_locations(node) - return node + return ast.fix_missing_locations(node) if node else None def visit_Import(self, node): """Visits an ``ast.Import`` node and returns a modified node or None. 
From 6eab4949c14e4ac1cd35d8a3de7a359208176255 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Wed, 22 Jul 2020 16:37:49 -0500 Subject: [PATCH 05/11] Update v2.rst --- doc/v2.rst | 55 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) diff --git a/doc/v2.rst b/doc/v2.rst index 669ea84be2..33ff8fad6a 100644 --- a/doc/v2.rst +++ b/doc/v2.rst @@ -64,6 +64,30 @@ For more information, see `Upgrade from Legacy TensorFlow Support Date: Fri, 24 Jul 2020 15:02:37 -0500 Subject: [PATCH 06/11] Update src/sagemaker/cli/compatibility/v2/modifiers/serde.py Co-authored-by: Lauren Yu <6631887+laurenyu@users.noreply.github.com> --- src/sagemaker/cli/compatibility/v2/modifiers/serde.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 2c9358a5ee..56e24b5255 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -159,7 +159,7 @@ def modify_node(self, node): return node namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] - subpackage_name = namespace_name[namespace_name.find(".") + 1 :] + subpackage_name = namespace_name.split(".")[1] assert subpackage_name in {"serializers", "deserializers"} return pasta.parse("%s.%s()" % (subpackage_name, new_class_name)).body[0].value From b9ce3782fee6ec51ee0001d0f224d07e5c478c07 Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 24 Jul 2020 15:09:40 -0500 Subject: [PATCH 07/11] Update src/sagemaker/cli/compatibility/v2/modifiers/serde.py Co-authored-by: Eric Johnson <65414824+metrizable@users.noreply.github.com> --- src/sagemaker/cli/compatibility/v2/modifiers/serde.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 56e24b5255..3f59c8d0cd 100644 
--- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -255,10 +255,7 @@ def modify_node(self, node): else: i += 1 - if not node.names: - return None - - return node + return node if node.names else None class SerdeImportFromAmazonCommonRenamer(Modifier): From bbf6f2afc21e576150c738387430097f433b0b2a Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 24 Jul 2020 15:16:48 -0500 Subject: [PATCH 08/11] Address review comments --- .../cli/compatibility/v2/modifiers/serde.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 2c9358a5ee..ae2ea39100 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -22,6 +22,9 @@ from sagemaker.cli.compatibility.v2.modifiers import matching from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier +OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer", "record_deserializer"} +NEW_AMAZON_CLASS_NAMES = {"RecordSerializer", "RecordDeserializer"} + # The values are tuples so that the object can be passed to matching.matches_any. OLD_CLASS_NAME_TO_NAMESPACES = { "_CsvSerializer": ("sagemaker.predictor",), @@ -33,9 +36,10 @@ "StreamDeserializer": ("sagemaker.predictor",), "_NumpyDeserializer": ("sagemaker.predictor",), "_JsonDeserializer": ("sagemaker.predictor",), - "numpy_to_record_serializer": ("sagemaker.amazon.common",), - "record_deserializer": ("sagemaker.amazon.common",), } +OLD_CLASS_NAMES_TO_NAMESPACES.update({ + class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES +}) # The values are tuples so that the object can be passed to matching.matches_any. 
NEW_CLASS_NAME_TO_NAMESPACES = { @@ -75,21 +79,6 @@ "numpy_deserializer": "NumpyDeserializer", } -OLD_AMAZON_CLASS_NAMES = set( - { - class_name - for class_name, namespaces in OLD_CLASS_NAME_TO_NAMESPACES.items() - if "sagemaker.amazon.common" in namespaces - } -) -NEW_AMAZON_CLASS_NAMES = set( - { - class_name - for class_name, namespaces in NEW_CLASS_NAME_TO_NAMESPACES.items() - if "sagemaker.amazon.common" in namespaces - } -) - NEW_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.values()) OLD_CLASS_NAMES = set(OLD_CLASS_NAME_TO_NEW_CLASS_NAME.keys()) @@ -102,7 +91,7 @@ class SerdeConstructorRenamer(Modifier): def node_should_be_modified(self, node): """Checks if the ``ast.Call`` node instantiates a SerDe class. - This looks for the following calls: + This looks for the following calls (both with and without namespaces): - ``sagemaker.predictor._CsvSerializer`` - ``sagemaker.predictor._JsonSerializer`` @@ -126,7 +115,9 @@ def node_should_be_modified(self, node): return matching.matches_any(node, OLD_CLASS_NAME_TO_NAMESPACES) def modify_node(self, node): - """Modifies the ``ast.Call`` node to use the classes for SerDe + """Updates the name and namespace of the ``ast.Call`` node, as applicable. + + This method modifies the ``ast.Call`` node to use the SerDe classes available in version 2.0 and later of the Python SDK: - ``sagemaker.serializers.CSVSerializer`` @@ -232,7 +223,7 @@ def node_should_be_modified(self, node): from the ``sagemaker.predictor`` module. 
""" return node.module == "sagemaker.predictor" and any( - [name.name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES) for name in node.names] + name.name in (OLD_CLASS_NAMES | OLD_OBJECT_NAMES) for name in node.names ) def modify_node(self, node): From 68843c1ad05f0fb4d804d5ca61abc540d478ffbe Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Fri, 24 Jul 2020 15:50:53 -0500 Subject: [PATCH 09/11] Address review comments --- doc/v2.rst | 4 ++-- .../cli/compatibility/v2/modifiers/serde.py | 20 ++++++++++++------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/doc/v2.rst b/doc/v2.rst index 33ff8fad6a..06cd16374f 100644 --- a/doc/v2.rst +++ b/doc/v2.rst @@ -67,14 +67,14 @@ Please use :func:`sagemaker.predictor.Predictor.delete_endpoint` instead. Pre-instantiated Serializer and Deserializer Objects ---------------------------------------------------- -The ``csv_serializer``, ``json_serializer``, ``npy_serializer``, ``csv_deserializer``, +The ``csv_serializer``, ``json_serializer``, ``npy_serializer``, ``csv_deserializer``, ``json_deserializer``, and ``numpy_deserializer`` objects have been deprecated. Please instantiate the objects instead. 
+--------------------------------------------+------------------------------------------------+ | v1.x | v2.0 and later | -+============================================|================================================+ ++============================================+================================================+ | ``sagemaker.predictor.csv_serializer`` | ``sagemaker.serializers.CSVSerializer()`` | +--------------------------------------------+------------------------------------------------+ | ``sagemaker.predictor.json_serializer`` | ``sagemaker.serializers.JSONSerializer()`` | diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index bd9fdde8b2..7e885b35ec 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -37,9 +37,9 @@ "_NumpyDeserializer": ("sagemaker.predictor",), "_JsonDeserializer": ("sagemaker.predictor",), } -OLD_CLASS_NAMES_TO_NAMESPACES.update({ - class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES -}) +OLD_CLASS_NAME_TO_NAMESPACES.update( + {class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES} +) # The values are tuples so that the object can be passed to matching.matches_any.
NEW_CLASS_NAME_TO_NAMESPACES = { @@ -151,8 +151,11 @@ def modify_node(self, node): namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] subpackage_name = namespace_name.split(".")[1] - assert subpackage_name in {"serializers", "deserializers"} - return pasta.parse("%s.%s()" % (subpackage_name, new_class_name)).body[0].value + return ast.Call( + func=ast.Attribute(value=ast.Name(id=subpackage_name), attr=new_class_name), + args=[], + keywords=[], + ) class SerdeObjectRenamer(Modifier): @@ -203,8 +206,11 @@ def modify_node(self, node): new_class_name = OLD_OBJECT_NAME_TO_NEW_CLASS_NAME[object_name] namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] subpackage_name = namespace_name[namespace_name.find(".") + 1 :] - assert subpackage_name in {"serializers", "deserializers"} - return pasta.parse("%s.%s()" % (subpackage_name, new_class_name)).body[0].value + return ast.Call( + func=ast.Attribute(value=ast.Name(id=subpackage_name), attr=new_class_name), + args=[], + keywords=[], + ) class SerdeImportFromPredictorRenamer(Modifier): From a309b62d67cabcd19cd0e30c152845986c82f26f Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Jul 2020 15:39:32 -0500 Subject: [PATCH 10/11] Appease review comments --- .../cli/compatibility/v2/ast_transformer.py | 21 +++++++---- .../cli/compatibility/v2/modifiers/serde.py | 37 ++++++++++--------- 2 files changed, 33 insertions(+), 25 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index 098dd1f0b1..cfc80afecf 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -63,6 +63,7 @@ class ASTTransformer(ast.NodeTransformer): def visit_Call(self, node): """Visits an ``ast.Call`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. 
Args: @@ -79,6 +80,7 @@ def visit_Call(self, node): def visit_Name(self, node): """Visits an ``ast.Name`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: @@ -86,8 +88,8 @@ def visit_Name(self, node): Returns: ast.AST: if the returned node is None, the original node is removed - from its location. Otherwise, the original node is replaced with the - returned node. + from its location. Otherwise, the original node is replaced with + the returned node. """ for name_checker in NAME_MODIFIERS: node = name_checker.check_and_modify_node(node) @@ -95,6 +97,7 @@ def visit_Name(self, node): def visit_Import(self, node): """Visits an ``ast.Import`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: @@ -102,8 +105,8 @@ def visit_Import(self, node): Returns: ast.AST: if the returned node is None, the original node is removed - from its location. Otherwise, the original node is replaced with the - returned node. + from its location. Otherwise, the original node is replaced with + the returned node. """ for import_checker in IMPORT_MODIFIERS: node = import_checker.check_and_modify_node(node) @@ -111,6 +114,7 @@ def visit_Import(self, node): def visit_Module(self, node): """Visits an ``ast.Module`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. The ``ast.NodeTransformer`` walks the abstract syntax tree and modifies @@ -121,8 +125,8 @@ def visit_Module(self, node): Returns: ast.AST: if the returned node is None, the original node is removed - from its location. Otherwise, the original node is replaced with the - returned node. + from its location. Otherwise, the original node is replaced with + the returned node. 
""" self.generic_visit(node) for module_checker in MODULE_MODIFIERS: @@ -131,6 +135,7 @@ def visit_Module(self, node): def visit_ImportFrom(self, node): """Visits an ``ast.ImportFrom`` node and returns a modified node or None. + See https://docs.python.org/3/library/ast.html#ast.NodeTransformer. Args: @@ -138,8 +143,8 @@ def visit_ImportFrom(self, node): Returns: ast.AST: if the returned node is None, the original node is removed - from its location. Otherwise, the original node is replaced with the - returned node. + from its location. Otherwise, the original node is replaced with + the returned node. """ for import_checker in IMPORT_FROM_MODIFIERS: node = import_checker.check_and_modify_node(node) diff --git a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py index 7e885b35ec..bdf6b6c73c 100644 --- a/src/sagemaker/cli/compatibility/v2/modifiers/serde.py +++ b/src/sagemaker/cli/compatibility/v2/modifiers/serde.py @@ -10,32 +10,31 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -"""Classes to modify serializer and deserializer code to be compatible with -version 2.0 and later of the SageMaker Python SDK. 
-""" +"""Classes to modify SerDe code to be compatible with version 2.0 and later.""" from __future__ import absolute_import import ast -import pasta - from sagemaker.cli.compatibility.v2.modifiers import matching from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier OLD_AMAZON_CLASS_NAMES = {"numpy_to_record_serializer", "record_deserializer"} NEW_AMAZON_CLASS_NAMES = {"RecordSerializer", "RecordDeserializer"} +OLD_PREDICTOR_CLASS_NAMES = { + "_CsvSerializer", + "_JsonSerializer", + "_NpySerializer", + "_CsvDeserializer", + "BytesDeserializer", + "StringDeserializer", + "StreamDeserializer", + "_NumpyDeserializer", + "_JsonDeserializer", +} # The values are tuples so that the object can be passed to matching.matches_any. OLD_CLASS_NAME_TO_NAMESPACES = { - "_CsvSerializer": ("sagemaker.predictor",), - "_JsonSerializer": ("sagemaker.predictor",), - "_NpySerializer": ("sagemaker.predictor",), - "_CsvDeserializer": ("sagemaker.predictor",), - "BytesDeserializer": ("sagemaker.predictor",), - "StringDeserializer": ("sagemaker.predictor",), - "StreamDeserializer": ("sagemaker.predictor",), - "_NumpyDeserializer": ("sagemaker.predictor",), - "_JsonDeserializer": ("sagemaker.predictor",), + class_name: ("sagemaker.predictor",) for class_name in OLD_PREDICTOR_CLASS_NAMES } OLD_CLASS_NAME_TO_NAMESPACES.update( {class_name: ("sagemaker.amazon.common",) for class_name in OLD_AMAZON_CLASS_NAMES} @@ -205,7 +204,7 @@ def modify_node(self, node): object_name = node.id if isinstance(node, ast.Name) else node.attr new_class_name = OLD_OBJECT_NAME_TO_NEW_CLASS_NAME[object_name] namespace_name = NEW_CLASS_NAME_TO_NAMESPACES[new_class_name][0] - subpackage_name = namespace_name[namespace_name.find(".") + 1 :] + subpackage_name = namespace_name.split(".")[1] return ast.Call( func=ast.Attribute(value=ast.Name(id=subpackage_name), attr=new_class_name), args=[], @@ -375,7 +374,9 @@ def __init__(self): for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES if
"Serializer" in class_name } - import_node = pasta.parse("from sagemaker import serializers\n").body[0] + import_node = ast.ImportFrom( + module="sagemaker", names=[ast.alias(name="serializers", asname=None)], level=0 + ) super().__init__(class_names, import_node) @@ -403,5 +404,7 @@ def __init__(self): for class_name in NEW_CLASS_NAMES - NEW_AMAZON_CLASS_NAMES if "Deserializer" in class_name } - import_node = pasta.parse("from sagemaker import deserializers\n").body[0] + import_node = ast.ImportFrom( + module="sagemaker", names=[ast.alias(name="deserializers", asname=None)], level=0 + ) super().__init__(class_names, import_node) From 573787215a998127f41bbeb12fe3ce7fb067fb0d Mon Sep 17 00:00:00 2001 From: Balaji Veeramani Date: Mon, 27 Jul 2020 15:40:26 -0500 Subject: [PATCH 11/11] Update ast_transformer.py --- src/sagemaker/cli/compatibility/v2/ast_transformer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/sagemaker/cli/compatibility/v2/ast_transformer.py b/src/sagemaker/cli/compatibility/v2/ast_transformer.py index cfc80afecf..228b68b594 100644 --- a/src/sagemaker/cli/compatibility/v2/ast_transformer.py +++ b/src/sagemaker/cli/compatibility/v2/ast_transformer.py @@ -71,8 +71,8 @@ def visit_Call(self, node): Returns: ast.AST: if the returned node is None, the original node is removed - from its location. Otherwise, the original node is replaced with the - returned node. + from its location. Otherwise, the original node is replaced with + the returned node. """ for function_checker in FUNCTION_CALL_MODIFIERS: node = function_checker.check_and_modify_node(node)