diff --git a/src/sagemaker/cli/compatibility/v2/files.py b/src/sagemaker/cli/compatibility/v2/files.py index e00bf35245..bfed230f98 100644 --- a/src/sagemaker/cli/compatibility/v2/files.py +++ b/src/sagemaker/cli/compatibility/v2/files.py @@ -48,6 +48,18 @@ def update(self): updated code to an output file. """ + def _make_output_dirs_if_needed(self): + """Checks if the directory path for ``self.output_path`` exists, + and creates the directories if not. This function also logs a warning if + ``self.output_path`` already exists. + """ + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + if os.path.exists(self.output_path): + LOGGER.warning("Overwriting file %s", self.output_path) + class PyFileUpdater(FileUpdater): """A class for updating Python (``*.py``) files.""" @@ -88,12 +100,7 @@ def _write_output_file(self, output): Args: output (ast.Module): AST to save as the output file. """ - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) - - if os.path.exists(self.output_path): - LOGGER.warning("Overwriting file %s", self.output_path) + self._make_output_dirs_if_needed() with open(self.output_path, "w") as output_file: output_file.write(pasta.dump(output)) @@ -168,12 +175,7 @@ def _write_output_file(self, output): Args: output (dict): JSON to save as the output file. """ - output_dir = os.path.dirname(self.output_path) - if output_dir and not os.path.exists(output_dir): - os.makedirs(output_dir) - - if os.path.exists(self.output_path): - LOGGER.warning("Overwriting file %s", self.output_path) + self._make_output_dirs_if_needed() with open(self.output_path, "w") as output_file: json.dump(output, output_file, indent=1) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py new file mode 100644 index 0000000000..eb876c6668 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py @@ -0,0 +1,153 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import sys + +import pasta +import pytest + +from sagemaker.cli.compatibility.v2.modifiers import framework_version + + +@pytest.fixture(autouse=True) +def skip_if_py2(): + # Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed. + if sys.version_info.major < 3: + pytest.skip("v2 migration script doesn't support Python 2.") + + +def test_node_should_be_modified_fw_constructor_no_fw_version(): + fw_constructors = ( + "TensorFlow()", + "sagemaker.tensorflow.TensorFlow()", + "TensorFlowModel()", + "sagemaker.tensorflow.TensorFlowModel()", + "MXNet()", + "sagemaker.mxnet.MXNet()", + "MXNetModel()", + "sagemaker.mxnet.MXNetModel()", + "Chainer()", + "sagemaker.chainer.Chainer()", + "ChainerModel()", + "sagemaker.chainer.ChainerModel()", + "PyTorch()", + "sagemaker.pytorch.PyTorch()", + "PyTorchModel()", + "sagemaker.pytorch.PyTorchModel()", + "SKLearn()", + "sagemaker.sklearn.SKLearn()", + "SKLearnModel()", + "sagemaker.sklearn.SKLearnModel()", + ) + + modifier = framework_version.FrameworkVersionEnforcer() + + for constructor in fw_constructors: + node = _ast_call(constructor) + assert modifier.node_should_be_modified(node) is True + + +def test_node_should_be_modified_fw_constructor_with_fw_version(): + fw_constructors = ( + "TensorFlow(framework_version='2.2')", + "sagemaker.tensorflow.TensorFlow(framework_version='2.2')", + "TensorFlowModel(framework_version='1.10')", + "sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')", + "MXNet(framework_version='1.6')", + "sagemaker.mxnet.MXNet(framework_version='1.6')", + "MXNetModel(framework_version='1.6')", + "sagemaker.mxnet.MXNetModel(framework_version='1.6')", + "PyTorch(framework_version='1.4')", + "sagemaker.pytorch.PyTorch(framework_version='1.4')", + "PyTorchModel(framework_version='1.4')", + "sagemaker.pytorch.PyTorchModel(framework_version='1.4')", + "Chainer(framework_version='5.0')", + "sagemaker.chainer.Chainer(framework_version='5.0')", + "ChainerModel(framework_version='5.0')", + "sagemaker.chainer.ChainerModel(framework_version='5.0')", + "SKLearn(framework_version='0.20.0')", + "sagemaker.sklearn.SKLearn(framework_version='0.20.0')", + "SKLearnModel(framework_version='0.20.0')", + "sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')", + ) + + modifier = framework_version.FrameworkVersionEnforcer() + + for constructor in fw_constructors: + node = _ast_call(constructor) + assert modifier.node_should_be_modified(node) is False + + +def test_node_should_be_modified_random_function_call(): + node = _ast_call("sagemaker.session.Session()") + modifier = framework_version.FrameworkVersionEnforcer() + assert modifier.node_should_be_modified(node) is False + + +def test_modify_node_tf(): + classes = ( + "TensorFlow" "sagemaker.tensorflow.TensorFlow", + "TensorFlowModel", + "sagemaker.tensorflow.TensorFlowModel", + ) + _test_modify_node(classes, "1.11.0") + + +def test_modify_node_mx(): + classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel") + _test_modify_node(classes, "1.2.0") + + +def test_modify_node_chainer(): + classes = ( + "Chainer", + "sagemaker.chainer.Chainer", + "ChainerModel", + "sagemaker.chainer.ChainerModel", + ) + _test_modify_node(classes, "4.1.0") + + +def test_modify_node_pt(): + classes = ( + "PyTorch", + "sagemaker.pytorch.PyTorch", + "PyTorchModel", + "sagemaker.pytorch.PyTorchModel", + ) + _test_modify_node(classes, "0.4.0") + + +def test_modify_node_sklearn(): + classes = ( + "SKLearn", + "sagemaker.sklearn.SKLearn", + "SKLearnModel", + "sagemaker.sklearn.SKLearnModel", + ) + _test_modify_node(classes, "0.20.0") + + +def _ast_call(code): + return pasta.parse(code).body[0].value + + +def _test_modify_node(classes, default_version): + modifier = framework_version.FrameworkVersionEnforcer() + for cls in classes: + node = _ast_call("{}()".format(cls)) + modifier.modify_node(node) + + expected_result = "{}(framework_version='{}')".format(cls, default_version) + assert expected_result == pasta.dump(node) diff --git a/tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py b/tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py new file mode 100644 index 0000000000..c30ffe29b6 --- /dev/null +++ b/tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py @@ -0,0 +1,154 @@ +# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import + +import json +import os + +from mock import call, Mock, mock_open, patch + +from sagemaker.cli.compatibility.v2 import files + + +def test_init(): + input_file = "input.py" + output_file = "output.py" + + updater = files.FileUpdater(input_file, output_file) + assert input_file == updater.input_path + assert output_file == updater.output_path + + +@patch("six.moves.builtins.open", mock_open()) +@patch("os.makedirs") +def test_make_output_dirs_if_needed_make_path(makedirs): + output_dir = "dir" + output_path = os.path.join(output_dir, "output.py") + + updater = files.FileUpdater("input.py", output_path) + updater._make_output_dirs_if_needed() + + makedirs.assert_called_with(output_dir) + + +@patch("six.moves.builtins.open", mock_open()) +@patch("os.path.exists", return_value=True) +def test_make_output_dirs_if_needed_overwrite_with_warning(os_path_exists, caplog): + output_file = "output.py" + + updater = files.FileUpdater("input.py", output_file) + updater._make_output_dirs_if_needed() + + assert "Overwriting file {}".format(output_file) in caplog.text + + +@patch("pasta.dump") +@patch("pasta.parse") +@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer") +def test_py_file_update(ast_transformer, pasta_parse, pasta_dump): + input_ast = Mock() + pasta_parse.return_value = input_ast + + output_ast = Mock(_fields=[]) + ast_transformer.return_value.visit.return_value = output_ast + output_code = "print('goodbye')" + pasta_dump.return_value = output_code + + input_file = "input.py" + output_file = "output.py" + + input_code = "print('hello, world!')" + open_mock = mock_open(read_data=input_code) + with patch("six.moves.builtins.open", open_mock): + updater = files.PyFileUpdater(input_file, output_file) + updater.update() + + pasta_parse.assert_called_with(input_code) + ast_transformer.return_value.visit.assert_called_with(input_ast) + + assert call(input_file) in open_mock.mock_calls + assert call(output_file, "w") in open_mock.mock_calls + + open_mock().write.assert_called_with(output_code) + pasta_dump.assert_called_with(output_ast) + + +@patch("json.dump") +@patch("pasta.dump") +@patch("pasta.parse") +@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer") +def test_update(ast_transformer, pasta_parse, pasta_dump, json_dump): + notebook_template = """{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# code to be modified\\n", + "%s" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 + } + """ + input_code = "print('hello, world!')" + input_notebook = notebook_template % input_code + + input_ast = Mock() + pasta_parse.return_value = input_ast + + output_ast = Mock(_fields=[]) + ast_transformer.return_value.visit.return_value = output_ast + output_code = "print('goodbye')" + pasta_dump.return_value = "# code to be modified\n{}".format(output_code) + + input_file = "input.py" + output_file = "output.py" + + open_mock = mock_open(read_data=input_notebook) + with patch("six.moves.builtins.open", open_mock): + updater = files.JupyterNotebookFileUpdater(input_file, output_file) + updater.update() + + pasta_parse.assert_called_with("# code to be modified\n{}".format(input_code)) + ast_transformer.return_value.visit.assert_called_with(input_ast) + pasta_dump.assert_called_with(output_ast) + + assert call(input_file) in open_mock.mock_calls + assert call(output_file, "w") in open_mock.mock_calls + + json_dump.assert_called_with(json.loads(notebook_template % output_code), open_mock(), indent=1) + open_mock().write.assert_called_with("\n")