Skip to content

Commit 3650882

Browse files
committed
infra: add unit tests for v2 migration script file updaters and modifiers
1 parent a680be1 commit 3650882

File tree

3 files changed

+321
-12
lines changed

3 files changed

+321
-12
lines changed

src/sagemaker/cli/compatibility/v2/files.py

+14-12
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,18 @@ def update(self):
4848
updated code to an output file.
4949
"""
5050

51+
def _make_output_dirs_if_needed(self):
52+
"""Checks if the directory path for ``self.output_path`` exists,
53+
and creates the directories if not. This function also logs a warning if
54+
``self.output_path`` already exists.
55+
"""
56+
output_dir = os.path.dirname(self.output_path)
57+
if output_dir and not os.path.exists(output_dir):
58+
os.makedirs(output_dir)
59+
60+
if os.path.exists(self.output_path):
61+
LOGGER.warning("Overwriting file %s", self.output_path)
62+
5163

5264
class PyFileUpdater(FileUpdater):
5365
"""A class for updating Python (``*.py``) files."""
@@ -88,12 +100,7 @@ def _write_output_file(self, output):
88100
Args:
89101
output (ast.Module): AST to save as the output file.
90102
"""
91-
output_dir = os.path.dirname(self.output_path)
92-
if output_dir and not os.path.exists(output_dir):
93-
os.makedirs(output_dir)
94-
95-
if os.path.exists(self.output_path):
96-
LOGGER.warning("Overwriting file %s", self.output_path)
103+
self._make_output_dirs_if_needed()
97104

98105
with open(self.output_path, "w") as output_file:
99106
output_file.write(pasta.dump(output))
@@ -168,12 +175,7 @@ def _write_output_file(self, output):
168175
Args:
169176
output (dict): JSON to save as the output file.
170177
"""
171-
output_dir = os.path.dirname(self.output_path)
172-
if output_dir and not os.path.exists(output_dir):
173-
os.makedirs(output_dir)
174-
175-
if os.path.exists(self.output_path):
176-
LOGGER.warning("Overwriting file %s", self.output_path)
178+
self._make_output_dirs_if_needed()
177179

178180
with open(self.output_path, "w") as output_file:
179181
json.dump(output, output_file, indent=1)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import sys
16+
17+
import pasta
18+
import pytest
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import framework_version
21+
22+
23+
@pytest.fixture(autouse=True)
24+
def skip_if_py2():
25+
# Remove once https://github.com/aws/sagemaker-python-sdk/issues/1461 is addressed.
26+
if sys.version_info.major < 3:
27+
pytest.skip("v2 migration script doesn't support Python 2.")
28+
29+
30+
def test_node_should_be_modified_fw_constructor_no_fw_version():
31+
fw_constructors = (
32+
"TensorFlow()",
33+
"sagemaker.tensorflow.TensorFlow()",
34+
"TensorFlowModel()",
35+
"sagemaker.tensorflow.TensorFlowModel()",
36+
"MXNet()",
37+
"sagemaker.mxnet.MXNet()",
38+
"MXNetModel()",
39+
"sagemaker.mxnet.MXNetModel()",
40+
"Chainer()",
41+
"sagemaker.chainer.Chainer()",
42+
"ChainerModel()",
43+
"sagemaker.chainer.ChainerModel()",
44+
"PyTorch()",
45+
"sagemaker.pytorch.PyTorch()",
46+
"PyTorchModel()",
47+
"sagemaker.pytorch.PyTorchModel()",
48+
"SKLearn()",
49+
"sagemaker.sklearn.SKLearn()",
50+
"SKLearnModel()",
51+
"sagemaker.sklearn.SKLearnModel()",
52+
)
53+
54+
modifier = framework_version.FrameworkVersionEnforcer()
55+
56+
for constructor in fw_constructors:
57+
node = _ast_call(constructor)
58+
assert modifier.node_should_be_modified(node) is True
59+
60+
61+
def test_node_should_be_modified_fw_constructor_with_fw_version():
62+
fw_constructors = (
63+
"TensorFlow(framework_version='2.2')",
64+
"sagemaker.tensorflow.TensorFlow(framework_version='2.2')",
65+
"TensorFlowModel(framework_version='1.10')",
66+
"sagemaker.tensorflow.TensorFlowModel(framework_version='1.10')",
67+
"MXNet(framework_version='1.6')",
68+
"sagemaker.mxnet.MXNet(framework_version='1.6')",
69+
"MXNetModel(framework_version='1.6')",
70+
"sagemaker.mxnet.MXNetModel(framework_version='1.6')",
71+
"PyTorch(framework_version='1.4')",
72+
"sagemaker.pytorch.PyTorch(framework_version='1.4')",
73+
"PyTorchModel(framework_version='1.4')",
74+
"sagemaker.pytorch.PyTorchModel(framework_version='1.4')",
75+
"Chainer(framework_version='5.0')",
76+
"sagemaker.chainer.Chainer(framework_version='5.0')",
77+
"ChainerModel(framework_version='5.0')",
78+
"sagemaker.chainer.ChainerModel(framework_version='5.0')",
79+
"SKLearn(framework_version='0.20.0')",
80+
"sagemaker.sklearn.SKLearn(framework_version='0.20.0')",
81+
"SKLearnModel(framework_version='0.20.0')",
82+
"sagemaker.sklearn.SKLearnModel(framework_version='0.20.0')",
83+
)
84+
85+
modifier = framework_version.FrameworkVersionEnforcer()
86+
87+
for constructor in fw_constructors:
88+
node = _ast_call(constructor)
89+
assert modifier.node_should_be_modified(node) is False
90+
91+
92+
def test_node_should_be_modified_random_function_call():
93+
node = _ast_call("sagemaker.session.Session()")
94+
modifier = framework_version.FrameworkVersionEnforcer()
95+
assert modifier.node_should_be_modified(node) is False
96+
97+
98+
def test_modify_node_tf():
99+
classes = (
100+
"TensorFlow" "sagemaker.tensorflow.TensorFlow",
101+
"TensorFlowModel",
102+
"sagemaker.tensorflow.TensorFlowModel",
103+
)
104+
_test_modify_node(classes, "1.11.0")
105+
106+
107+
def test_modify_node_mx():
108+
classes = ("MXNet", "sagemaker.mxnet.MXNet", "MXNetModel", "sagemaker.mxnet.MXNetModel")
109+
_test_modify_node(classes, "1.2.0")
110+
111+
112+
def test_modify_node_chainer():
113+
classes = (
114+
"Chainer",
115+
"sagemaker.chainer.Chainer",
116+
"ChainerModel",
117+
"sagemaker.chainer.ChainerModel",
118+
)
119+
_test_modify_node(classes, "4.1.0")
120+
121+
122+
def test_modify_node_pt():
123+
classes = (
124+
"PyTorch",
125+
"sagemaker.pytorch.PyTorch",
126+
"PyTorchModel",
127+
"sagemaker.pytorch.PyTorchModel",
128+
)
129+
_test_modify_node(classes, "0.4.0")
130+
131+
132+
def test_modify_node_sklearn():
133+
classes = (
134+
"SKLearn",
135+
"sagemaker.sklearn.SKLearn",
136+
"SKLearnModel",
137+
"sagemaker.sklearn.SKLearnModel",
138+
)
139+
_test_modify_node(classes, "0.20.0")
140+
141+
142+
def _ast_call(code):
143+
return pasta.parse(code).body[0].value
144+
145+
146+
def _test_modify_node(classes, default_version):
147+
modifier = framework_version.FrameworkVersionEnforcer()
148+
for cls in classes:
149+
node = _ast_call("{}()".format(cls))
150+
modifier.modify_node(node)
151+
152+
expected_result = "{}(framework_version='{}')".format(cls, default_version)
153+
assert expected_result == pasta.dump(node)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import json
16+
import os
17+
18+
from mock import call, Mock, mock_open, patch
19+
20+
from sagemaker.cli.compatibility.v2 import files
21+
22+
23+
def test_init():
24+
input_file = "input.py"
25+
output_file = "output.py"
26+
27+
updater = files.FileUpdater(input_file, output_file)
28+
assert input_file == updater.input_path
29+
assert output_file == updater.output_path
30+
31+
32+
@patch("six.moves.builtins.open", mock_open())
33+
@patch("os.makedirs")
34+
def test_make_output_dirs_if_needed_make_path(makedirs):
35+
output_dir = "dir"
36+
output_path = os.path.join(output_dir, "output.py")
37+
38+
updater = files.FileUpdater("input.py", output_path)
39+
updater._make_output_dirs_if_needed()
40+
41+
makedirs.assert_called_with(output_dir)
42+
43+
44+
@patch("six.moves.builtins.open", mock_open())
45+
@patch("os.path.exists", return_value=True)
46+
def test_make_output_dirs_if_needed_overwrite_with_warning(os_path_exists, caplog):
47+
output_file = "output.py"
48+
49+
updater = files.FileUpdater("input.py", output_file)
50+
updater._make_output_dirs_if_needed()
51+
52+
assert "Overwriting file {}".format(output_file) in caplog.text
53+
54+
55+
@patch("pasta.dump")
56+
@patch("pasta.parse")
57+
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
58+
def test_py_file_update(ast_transformer, pasta_parse, pasta_dump):
59+
input_ast = Mock()
60+
pasta_parse.return_value = input_ast
61+
62+
output_ast = Mock(_fields=[])
63+
ast_transformer.return_value.visit.return_value = output_ast
64+
output_code = "print('goodbye')"
65+
pasta_dump.return_value = output_code
66+
67+
input_file = "input.py"
68+
output_file = "output.py"
69+
70+
input_code = "print('hello, world!')"
71+
open_mock = mock_open(read_data=input_code)
72+
with patch("six.moves.builtins.open", open_mock):
73+
updater = files.PyFileUpdater(input_file, output_file)
74+
updater.update()
75+
76+
pasta_parse.assert_called_with(input_code)
77+
ast_transformer.return_value.visit.assert_called_with(input_ast)
78+
79+
assert call(input_file) in open_mock.mock_calls
80+
assert call(output_file, "w") in open_mock.mock_calls
81+
82+
open_mock().write.assert_called_with(output_code)
83+
pasta_dump.assert_called_with(output_ast)
84+
85+
86+
@patch("json.dump")
87+
@patch("pasta.dump")
88+
@patch("pasta.parse")
89+
@patch("sagemaker.cli.compatibility.v2.files.ASTTransformer")
90+
def test_update(ast_transformer, pasta_parse, pasta_dump, json_dump):
91+
notebook_template = """{
92+
"cells": [
93+
{
94+
"cell_type": "code",
95+
"execution_count": 1,
96+
"metadata": {},
97+
"outputs": [],
98+
"source": [
99+
"# code to be modified\\n",
100+
"%s"
101+
]
102+
}
103+
],
104+
"metadata": {
105+
"kernelspec": {
106+
"display_name": "Python 3",
107+
"language": "python",
108+
"name": "python3"
109+
},
110+
"language_info": {
111+
"codemirror_mode": {
112+
"name": "ipython",
113+
"version": 3
114+
},
115+
"file_extension": ".py",
116+
"mimetype": "text/x-python",
117+
"name": "python",
118+
"nbconvert_exporter": "python",
119+
"pygments_lexer": "ipython3",
120+
"version": "3.6.8"
121+
}
122+
},
123+
"nbformat": 4,
124+
"nbformat_minor": 2
125+
}
126+
"""
127+
input_code = "print('hello, world!')"
128+
input_notebook = notebook_template % input_code
129+
130+
input_ast = Mock()
131+
pasta_parse.return_value = input_ast
132+
133+
output_ast = Mock(_fields=[])
134+
ast_transformer.return_value.visit.return_value = output_ast
135+
output_code = "print('goodbye')"
136+
pasta_dump.return_value = "# code to be modified\n{}".format(output_code)
137+
138+
input_file = "input.py"
139+
output_file = "output.py"
140+
141+
open_mock = mock_open(read_data=input_notebook)
142+
with patch("six.moves.builtins.open", open_mock):
143+
updater = files.JupyterNotebookFileUpdater(input_file, output_file)
144+
updater.update()
145+
146+
pasta_parse.assert_called_with("# code to be modified\n{}".format(input_code))
147+
ast_transformer.return_value.visit.assert_called_with(input_ast)
148+
pasta_dump.assert_called_with(output_ast)
149+
150+
assert call(input_file) in open_mock.mock_calls
151+
assert call(output_file, "w") in open_mock.mock_calls
152+
153+
json_dump.assert_called_with(json.loads(notebook_template % output_code), open_mock(), indent=1)
154+
open_mock().write.assert_called_with("\n")

0 commit comments

Comments
 (0)