# ---- tests/unit/v2/test_transformer.py ----
def test_code_needs_transform():
    """Framework constructors lacking ``framework_version`` get a default one."""
    source = """
TensorFlow(entry_point="foo.py")
sagemaker.tensorflow.TensorFlow()
m = MXNet()
sagemaker.mxnet.MXNet()
"""
    want = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
m = MXNet(framework_version='1.2.0')
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""

    upgraded_tree = ASTTransformer().visit(ast.parse(source))

    assert pasta.dump(upgraded_tree) == want


def test_code_does_not_need_transform():
    """Code that already passes ``framework_version`` is dumped unchanged."""
    # The input and the expected output are the same text.
    already_upgraded = """TensorFlow(entry_point='foo.py', framework_version='1.11.0')
sagemaker.tensorflow.TensorFlow(framework_version='1.11.0')
m = MXNet(framework_version='1.2.0')
sagemaker.mxnet.MXNet(framework_version='1.2.0')\n"""

    upgraded_tree = ASTTransformer().visit(ast.parse(already_upgraded))

    assert pasta.dump(upgraded_tree) == already_upgraded
You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tools to assist with using the SageMaker Python SDK.""" +from __future__ import absolute_import diff --git a/tools/compatibility/__init__.py b/tools/compatibility/__init__.py new file mode 100644 index 0000000000..e3a46fe406 --- /dev/null +++ b/tools/compatibility/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Tools to assist with compatibility between SageMaker Python SDK versions.""" +from __future__ import absolute_import diff --git a/tools/compatibility/v2/__init__.py b/tools/compatibility/v2/__init__.py new file mode 100644 index 0000000000..b44e22749e --- /dev/null +++ b/tools/compatibility/v2/__init__.py @@ -0,0 +1,14 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. 
# ---- tools/compatibility/v2/ast_transformer.py ----
# Modifiers applied, in order, to every ``ast.Call`` node the transformer visits.
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]


class ASTTransformer(ast.NodeTransformer):
    """An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
    modifies nodes to upgrade the given SageMaker Python SDK code.
    """

    def visit_Call(self, node):
        """Visits an ``ast.Call`` node and returns a modified node, if needed.
        See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.

        Child nodes are visited first, so calls nested inside this call
        (for example ``TensorFlow().fit()`` or ``wrap(MXNet())``) are also
        offered to every modifier.

        Args:
            node (ast.Call): a node that represents a function call.

        Returns:
            ast.Call: a node that represents a function call, which has
                potentially been modified from the original input.
        """
        # A ``visit_*`` override replaces the default traversal, so the
        # children are only reached if we descend into them explicitly.
        self.generic_visit(node)

        for modifier in FUNCTION_CALL_MODIFIERS:
            modifier.check_and_modify_node(node)
        return node


# ---- tools/compatibility/v2/files.py ----
LOGGER = logging.getLogger(__name__)


class FileUpdater(object):
    """An abstract class for updating files.

    Note: the class does not use ``ABCMeta``, so ``@abstractmethod`` on
    ``update`` is documentation only and is not enforced at instantiation.
    """

    def __init__(self, input_path, output_path):
        """Creates a ``FileUpdater`` for updating a file so that
        it is compatible with v2 of the SageMaker Python SDK.

        Args:
            input_path (str): Location of the input file.
            output_path (str): Desired location for the output file.
                If the directories don't already exist, then they are created.
                If a file exists at ``output_path``, then it is overwritten.
        """
        self.input_path = input_path
        self.output_path = output_path

    @abstractmethod
    def update(self):
        """Reads the input file, updates the code so that it is
        compatible with v2 of the SageMaker Python SDK, and writes the
        updated code to an output file.
        """

    def _prepare_output_path(self):
        """Creates any missing parent directories of ``output_path`` and logs
        a warning if a file already exists there (it is about to be overwritten).
        """
        output_dir = os.path.dirname(self.output_path)
        # ``output_dir`` is empty when ``output_path`` is a bare file name.
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if os.path.exists(self.output_path):
            LOGGER.warning("Overwriting file %s", self.output_path)


class PyFileUpdater(FileUpdater):
    """A class for updating Python (``*.py``) files."""

    def update(self):
        """Reads the input Python file, updates the code so that it is
        compatible with v2 of the SageMaker Python SDK, and writes the
        updated code to an output file.
        """
        output = self._update_ast(self._read_input_file())
        self._write_output_file(output)

    def _update_ast(self, input_ast):
        """Updates an abstract syntax tree (AST) so that it is compatible
        with v2 of the SageMaker Python SDK.

        Args:
            input_ast (ast.Module): AST to be updated for use with Python SDK v2.

        Returns:
            ast.Module: Updated AST that is compatible with Python SDK v2.
        """
        return ASTTransformer().visit(input_ast)

    def _read_input_file(self):
        """Reads input file and parses it as an abstract syntax tree (AST).

        Returns:
            ast.Module: AST representation of the input file.
        """
        with open(self.input_path) as input_file:
            return pasta.parse(input_file.read())

    def _write_output_file(self, output):
        """Writes abstract syntax tree (AST) to output file.
        Creates the directories for the output path, if needed.

        Args:
            output (ast.Module): AST to save as the output file.
        """
        self._prepare_output_path()

        with open(self.output_path, "w") as output_file:
            output_file.write(pasta.dump(output))


class JupyterNotebookFileUpdater(FileUpdater):
    """A class for updating Jupyter notebook (``*.ipynb``) files.

    For more on this file format, see
    https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat.
    """

    def update(self):
        """Reads the input Jupyter notebook file, updates the code so that it is
        compatible with v2 of the SageMaker Python SDK, and writes the
        updated code to an output file.

        Code cells whose source cannot be parsed as Python are left
        unchanged and a warning is logged, so one such cell does not abort
        the update of the whole notebook.
        """
        nb_json = self._read_input_file()
        for cell in nb_json["cells"]:
            if cell["cell_type"] != "code":
                continue
            try:
                cell["source"] = self._update_code_from_cell(cell)
            except SyntaxError as error:
                # e.g. IPython magics or shell escapes, which are not Python.
                LOGGER.warning("Skipping cell that could not be parsed as Python: %s", error)

        self._write_output_file(nb_json)

    def _update_code_from_cell(self, cell):
        """Updates the code from a code cell so that it is
        compatible with v2 of the SageMaker Python SDK.

        Args:
            cell (dict): A dictionary representation of a code cell from
                a Jupyter notebook. For more info, see
                https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells.

        Returns:
            list[str]: A list of strings containing the lines of updated code that
                can be used for the "source" attribute of a Jupyter notebook code cell.
        """
        code = "".join(cell["source"])
        updated_ast = ASTTransformer().visit(pasta.parse(code))
        updated_code = pasta.dump(updated_ast)
        return self._code_str_to_source_list(updated_code)

    def _code_str_to_source_list(self, code):
        """Converts a string of code into a list for a Jupyter notebook code cell.

        Args:
            code (str): Code to be converted.

        Returns:
            list[str]: A list of strings containing the lines of code that
                can be used for the "source" attribute of a Jupyter notebook code cell.
                Each element of the list (i.e. line of code) contains a
                trailing newline character ("\\n") except for the last element.
        """
        source_list = ["{}\n".format(s) for s in code.split("\n")]
        source_list[-1] = source_list[-1].rstrip("\n")
        return source_list

    def _read_input_file(self):
        """Reads input file and parses it as JSON.

        Returns:
            dict: JSON representation of the input file.
        """
        with open(self.input_path) as input_file:
            return json.load(input_file)

    def _write_output_file(self, output):
        """Writes JSON to output file. Creates the directories for the output path, if needed.

        Args:
            output (dict): JSON to save as the output file.
        """
        self._prepare_output_path()

        with open(self.output_path, "w") as output_file:
            json.dump(output, output_file, indent=1)
            output_file.write("\n")  # json.dump does not write trailing newline
# ---- tools/compatibility/v2/modifiers/modifier.py ----
class Modifier(object):
    """Abstract class to take in an AST node, check if it needs modification,
    and potentially modify the node.
    """

    def check_and_modify_node(self, node):
        """Check an AST node, and modify it if applicable."""
        if self.node_should_be_modified(node):
            self.modify_node(node)

    @abstractmethod
    def node_should_be_modified(self, node):
        """Check if an AST node should be modified."""

    @abstractmethod
    def modify_node(self, node):
        """Modify an AST node."""


# ---- tools/compatibility/v2/modifiers/framework_version.py ----
# Version inserted for each framework when the call does not specify one.
FRAMEWORK_DEFAULTS = {
    "Chainer": "4.1.0",
    "MXNet": "1.2.0",
    "PyTorch": "0.4.0",
    "SKLearn": "0.20.0",
    "TensorFlow": "1.11.0",
}

FRAMEWORKS = list(FRAMEWORK_DEFAULTS.keys())
# TODO: check for sagemaker.tensorflow.serving.Model
FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS]
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS]


def _string_literal_node(value):
    """Builds the AST node for the string literal ``value``.

    ``ast.Str`` is deprecated since Python 3.8 in favor of ``ast.Constant``
    and is no longer available in the newest Python releases, so it is only
    used on interpreters that predate ``ast.Constant`` string literals.
    """
    import sys  # local import: only needed for this version check

    if sys.version_info >= (3, 8):
        return ast.Constant(value=value)
    return ast.Str(s=value)


class FrameworkVersionEnforcer(Modifier):
    """A class to ensure that ``framework_version`` is defined when
    instantiating a framework estimator or model.
    """

    def node_should_be_modified(self, node):
        """Checks if the ast.Call node instantiates a framework estimator or model,
        but doesn't specify the ``framework_version`` parameter.

        This looks for the following formats:

        - ``TensorFlow``
        - ``sagemaker.tensorflow.TensorFlow``

        where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.

        Args:
            node (ast.Call): a node that represents a function call. For more,
                see https://docs.python.org/3/library/ast.html#abstract-grammar.

        Returns:
            bool: If the ``ast.Call`` is instantiating a framework class that
                should specify ``framework_version``, but doesn't.
        """
        if self._is_framework_constructor(node):
            return not self._fw_version_in_keywords(node)

        return False

    def _is_framework_constructor(self, node):
        """Checks if the ``ast.Call`` node represents a call of the form
        ``Framework(...)`` or ``sagemaker.framework_module.Framework(...)``.

        Any other callee (an unrelated name, a differently qualified
        attribute, or a non-name expression such as ``f()()``) yields ``False``.
        """
        func = node.func

        # Check for a bare ``Framework(...)`` call.
        if isinstance(func, ast.Name):
            return func.id in FRAMEWORK_CLASSES

        # Everything below inspects ``func.attr`` / ``func.value``, which only
        # exist on attribute access; bail out for any other callee type.
        if not isinstance(func, ast.Attribute) or func.attr not in FRAMEWORK_CLASSES:
            return False

        # Check for a ``sagemaker.framework_module.Framework(...)`` call.
        module = func.value
        return (
            isinstance(module, ast.Attribute)
            and module.attr in FRAMEWORK_MODULES
            and isinstance(module.value, ast.Name)
            and module.value.id == "sagemaker"
        )

    def _fw_version_in_keywords(self, node):
        """Checks if the ``ast.Call`` node's keywords contain ``framework_version``.

        Only explicit keywords are matched by name, so a ``**kwargs``
        expansion never counts as supplying ``framework_version``.
        """
        for kw in node.keywords:
            if kw.arg == "framework_version" and kw.value:
                return True
        return False

    def modify_node(self, node):
        """Modifies the ``ast.Call`` node's keywords to include ``framework_version``.

        The ``framework_version`` value is determined by the framework:

        - Chainer: "4.1.0"
        - MXNet: "1.2.0"
        - PyTorch: "0.4.0"
        - SKLearn: "0.20.0"
        - TensorFlow: "1.11.0"

        Args:
            node (ast.Call): a node that represents the constructor of a framework class.
        """
        framework = self._framework_name_from_node(node)
        node.keywords.append(
            ast.keyword(
                arg="framework_version",
                value=_string_literal_node(FRAMEWORK_DEFAULTS[framework]),
            )
        )

    def _framework_name_from_node(self, node):
        """Retrieves the framework name based on the function call.

        Args:
            node (ast.Call): a node that represents the constructor of a framework class.
                This can represent either ``Framework(...)`` or
                ``sagemaker.framework_module.Framework(...)``.

        Returns:
            str: the (capitalized) framework name.
        """
        func = node.func
        framework = func.id if isinstance(func, ast.Name) else func.attr

        # ``TensorFlowModel`` and friends share the estimator's default version.
        if framework.endswith("Model"):
            framework = framework[: -len("Model")]

        return framework
# ---- tools/compatibility/v2/sagemaker_upgrade_v2.py ----
# Maps a file extension to the updater class that knows how to rewrite it.
_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater}


def _update_file(input_file, output_file):
    """Update a file to be compatible with v2 of the SageMaker Python SDK,
    and write the updated source to the output file.

    Args:
        input_file (str): The path to the file to be updated.
        output_file (str): The output file destination.

    Raises:
        ValueError: If the input and output filename extensions don't match,
            or if the file extensions are neither ".py" nor ".ipynb".
    """
    input_file_ext = os.path.splitext(input_file)[1]
    output_file_ext = os.path.splitext(output_file)[1]

    if input_file_ext != output_file_ext:
        raise ValueError(
            "Mismatched file extensions: input: {}, output: {}".format(
                input_file_ext, output_file_ext
            )
        )

    if input_file_ext not in _EXT_TO_UPDATER_CLS:
        raise ValueError("Unrecognized file extension: {}".format(input_file_ext))

    updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext]
    updater_cls(input_path=input_file, output_path=output_file).update()


def _parse_args(argv=None):
    """Parses CLI arguments.

    Args:
        argv (list[str]): Arguments to parse. Defaults to ``None``, in which
            case ``argparse`` reads them from ``sys.argv``.

    Returns:
        argparse.Namespace: Parsed arguments with ``in_file`` and ``out_file``.
    """
    parser = argparse.ArgumentParser(
        description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. "
        "Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
    )
    # Both paths are required: ``_update_file`` cannot work without them, and
    # argparse then reports a usage error instead of a ``TypeError`` traceback.
    parser.add_argument(
        "--in-file",
        required=True,
        help="If converting a single file, the file to convert. The file's extension "
        "must be either '.py' or '.ipynb'.",
    )
    parser.add_argument(
        "--out-file",
        required=True,
        help="If converting a single file, the output file destination. The file's extension "
        "must be either '.py' or '.ipynb'. If needed, directories in the output path are created. "
        "If the output file already exists, it is overwritten.",
    )

    return parser.parse_args(argv)


if __name__ == "__main__":
    args = _parse_args()
    _update_file(args.in_file, args.out_file)

# ---- tox.ini (pylint target list extended to also lint ``tools``) ----
#   python -m pylint --rcfile=.pylintrc -j 0 src/sagemaker tools