Skip to content

infra: add unit tests for v2 migration script file updaters and modifiers #1536

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions src/sagemaker/cli/compatibility/v2/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,18 @@ def update(self):
updated code to an output file.
"""

def _make_output_dirs_if_needed(self):
"""Checks if the directory path for ``self.output_path`` exists,
and creates the directories if not. This function also logs a warning if
``self.output_path`` already exists.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)


class PyFileUpdater(FileUpdater):
"""A class for updating Python (``*.py``) files."""
Expand Down Expand Up @@ -88,12 +100,7 @@ def _write_output_file(self, output):
Args:
output (ast.Module): AST to save as the output file.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)
self._make_output_dirs_if_needed()

with open(self.output_path, "w") as output_file:
output_file.write(pasta.dump(output))
Expand Down Expand Up @@ -168,12 +175,7 @@ def _write_output_file(self, output):
Args:
output (dict): JSON to save as the output file.
"""
output_dir = os.path.dirname(self.output_path)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

if os.path.exists(self.output_path):
LOGGER.warning("Overwriting file %s", self.output_path)
self._make_output_dirs_if_needed()

with open(self.output_path, "w") as output_file:
json.dump(output, output_file, indent=1)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import sys

import pasta
import pytest

from sagemaker.cli.compatibility.v2.modifiers import framework_version


@pytest.fixture(autouse=True)
def skip_if_py2():
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
if sys.version_info.major < 3:
pytest.skip("v2 migration script doesn't support Python 2.")


def test_node_should_be_modified_fw_constructor_no_fw_version():
fw_constructors = (
"TensorFlow()",
"sagemaker.tensorflow.TensorFlow()",
"TensorFlowModel()",
"sagemaker.tensorflow.TensorFlowModel()",
"MXNet()",
"sagemaker.mxnet.MXNet()",
"MXNetModel()",
"sagemaker.mxnet.MXNetModel()",
"Chainer()",
"sagemaker.chainer.Chainer()",
"ChainerModel()",
"sagemaker.chainer.ChainerModel()",
"PyTorch()",
"sagemaker.pytorch.PyTorch()",
"PyTorchModel()",
"sagemaker.pytorch.PyTorchModel()",
"SKLearn()",
"sagemaker.sklearn.SKLearn()",
"SKLearnModel()",
"sagemaker.sklearn.SKLearnModel()",
)

modifier = framework_version.FrameworkVersionEnforcer()

for constructor in fw_constructors:
node = _ast_call(constructor)
assert modifier.node_should_be_modified(node) is True


def test_node_should_be_modified_fw_constructor_with_fw_version():
fw_constructors = (
"TensorFlow(framework_version='2.2')",
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')",
"TensorFlowModel(framework_version='1.10')",
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')",
"MXNet(framework_version='1.6')",
"sagemaker.mxnet.MXNet(framework_version='1.6')",
"MXNetModel(framework_version='1.6')",
"sagemaker.mxnet.MXNetModel(framework_version='1.6')",
"PyTorch(framework_version='1.4')",
"sagemaker.pytorch.PyTorch(framework_version='1.4')",
"PyTorchModel(framework_version='1.4')",
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')",
"Chainer(framework_version='5.0')",
"sagemaker.chainer.Chainer(framework_version='5.0')",
"ChainerModel(framework_version='5.0')",
"sagemaker.chainer.ChainerModel(framework_version='5.0')",
"SKLearn(framework_version='0.20.0')",
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')",
"SKLearnModel(framework_version='0.20.0')",
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')",
)

modifier = framework_version.FrameworkVersionEnforcer()

for constructor in fw_constructors:
node = _ast_call(constructor)
assert modifier.node_should_be_modified(node) is False


def test_node_should_be_modified_random_function_call():
node = _ast_call("sagemaker.session.Session()")
modifier = framework_version.FrameworkVersionEnforcer()
assert modifier.node_should_be_modified(node) is False


def test_modify_node_tf():
classes = (
"TensorFlow" "sagemaker.tensorflow.TensorFlow",
"TensorFlowModel",
"sagemaker.tensorflow.TensorFlowModel",
)
_test_modify_node(classes, "1.11.0")


def test_modify_node_mx():
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel")
_test_modify_node(classes, "1.2.0")


def test_modify_node_chainer():
classes = (
"Chainer",
"sagemaker.chainer.Chainer",
"ChainerModel",
"sagemaker.chainer.ChainerModel",
)
_test_modify_node(classes, "4.1.0")


def test_modify_node_pt():
classes = (
"PyTorch",
"sagemaker.pytorch.PyTorch",
"PyTorchModel",
"sagemaker.pytorch.PyTorchModel",
)
_test_modify_node(classes, "0.4.0")


def test_modify_node_sklearn():
classes = (
"SKLearn",
"sagemaker.sklearn.SKLearn",
"SKLearnModel",
"sagemaker.sklearn.SKLearnModel",
)
_test_modify_node(classes, "0.20.0")


def _ast_call(code):
return pasta.parse(code).body[0].value


def _test_modify_node(classes, default_version):
modifier = framework_version.FrameworkVersionEnforcer()
for cls in classes:
node = _ast_call("{}()".format(cls))
modifier.modify_node(node)

expected_result = "{}(framework_version='{}')".format(cls, default_version)
assert expected_result == pasta.dump(node)
154 changes: 154 additions & 0 deletions tests/unit/sagemaker/cli/compatibility/v2/test_file_updaters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
from __future__ import absolute_import

import json
import os

from mock import call, Mock, mock_open, patch

from sagemaker.cli.compatibility.v2 import files


def test_init():
input_file = "input.py"
output_file = "output.py"

updater = files.FileUpdater(input_file, output_file)
assert input_file == updater.input_path
assert output_file == updater.output_path


@patch("six.moves.builtins.open", mock_open())
@patch("os.makedirs")
def test_make_output_dirs_if_needed_make_path(makedirs):
output_dir = "dir"
output_path = os.path.join(output_dir, "output.py")

updater = files.FileUpdater("input.py", output_path)
updater._make_output_dirs_if_needed()

makedirs.assert_called_with(output_dir)


@patch("six.moves.builtins.open", mock_open())
@patch("os.path.exists", return_value=True)
def test_make_output_dirs_if_needed_overwrite_with_warning(os_path_exists, caplog):
output_file = "output.py"

updater = files.FileUpdater("input.py", output_file)
updater._make_output_dirs_if_needed()

assert "Overwriting file {}".format(output_file) in caplog.text


@patch("pasta.dump")
@patch("pasta.parse")
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
def test_py_file_update(ast_transformer, pasta_parse, pasta_dump):
input_ast = Mock()
pasta_parse.return_value = input_ast

output_ast = Mock(_fields=[])
ast_transformer.return_value.visit.return_value = output_ast
output_code = "print('goodbye')"
pasta_dump.return_value = output_code

input_file = "input.py"
output_file = "output.py"

input_code = "print('hello, world!')"
open_mock = mock_open(read_data=input_code)
with patch("six.moves.builtins.open", open_mock):
updater = files.PyFileUpdater(input_file, output_file)
updater.update()

pasta_parse.assert_called_with(input_code)
ast_transformer.return_value.visit.assert_called_with(input_ast)

assert call(input_file) in open_mock.mock_calls
assert call(output_file, "w") in open_mock.mock_calls

open_mock().write.assert_called_with(output_code)
pasta_dump.assert_called_with(output_ast)


@patch("json.dump")
@patch("pasta.dump")
@patch("pasta.parse")
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
def test_update(ast_transformer, pasta_parse, pasta_dump, json_dump):
notebook_template = """{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# code to be modified\\n",
"%s"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
"""
input_code = "print('hello, world!')"
input_notebook = notebook_template % input_code

input_ast = Mock()
pasta_parse.return_value = input_ast

output_ast = Mock(_fields=[])
ast_transformer.return_value.visit.return_value = output_ast
output_code = "print('goodbye')"
pasta_dump.return_value = "# code to be modified\n{}".format(output_code)

input_file = "input.py"
output_file = "output.py"

open_mock = mock_open(read_data=input_notebook)
with patch("six.moves.builtins.open", open_mock):
updater = files.JupyterNotebookFileUpdater(input_file, output_file)
updater.update()

pasta_parse.assert_called_with("# code to be modified\n{}".format(input_code))
ast_transformer.return_value.visit.assert_called_with(input_ast)
pasta_dump.assert_called_with(output_ast)

assert call(input_file) in open_mock.mock_calls
assert call(output_file, "w") in open_mock.mock_calls

json_dump.assert_called_with(json.loads(notebook_template % output_code), open_mock(), indent=1)
open_mock().write.assert_called_with("\n")