Skip to content

Commit de6ebfe

Browse files
committed
add unit tests
1 parent e0dd8db commit de6ebfe

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
18+
19+
20+
def test_node_should_be_modified_tf_constructor_legacy_mode():
21+
tf_legacy_mode_constructors = (
22+
"TensorFlow(script_mode=False)",
23+
"TensorFlow(script_mode=None)",
24+
"TensorFlow(py_version='py2')",
25+
"TensorFlow()",
26+
"sagemaker.tensorflow.TensorFlow(script_mode=False)",
27+
"sagemaker.tensorflow.TensorFlow(script_mode=None)",
28+
"sagemaker.tensorflow.TensorFlow(py_version='py2')",
29+
"sagemaker.tensorflow.TensorFlow()",
30+
)
31+
32+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
33+
34+
for constructor in tf_legacy_mode_constructors:
35+
node = _ast_call(constructor)
36+
assert modifier.node_should_be_modified(node) is True
37+
38+
39+
def test_node_should_be_modified_tf_constructor_script_mode():
40+
tf_script_mode_constructors = (
41+
"TensorFlow(script_mode=True)",
42+
"TensorFlow(py_version='py3')",
43+
"TensorFlow(py_version='py37')",
44+
"TensorFlow(py_version='py3', script_mode=False)",
45+
"sagemaker.tensorflow.TensorFlow(script_mode=True)",
46+
"sagemaker.tensorflow.TensorFlow(py_version='py3')",
47+
"sagemaker.tensorflow.TensorFlow(py_version='py37')",
48+
"sagemaker.tensorflow.TensorFlow(py_version='py3', script_mode=False)",
49+
)
50+
51+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
52+
53+
for constructor in tf_script_mode_constructors:
54+
node = _ast_call(constructor)
55+
assert modifier.node_should_be_modified(node) is False
56+
57+
58+
def test_node_should_be_modified_random_function_call():
59+
node = _ast_call("MXNet(py_version='py3')")
60+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
61+
assert modifier.node_should_be_modified(node) is False
62+
63+
64+
def test_modify_node_set_script_mode_false():
65+
tf_constructors = (
66+
"TensorFlow()",
67+
"TensorFlow(script_mode=False)",
68+
"TensorFlow(script_mode=None)",
69+
)
70+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
71+
72+
for constructor in tf_constructors:
73+
node = _ast_call(constructor)
74+
modifier.modify_node(node)
75+
assert "TensorFlow(script_mode=False)" == pasta.dump(node)
76+
77+
78+
def test_modify_node_set_hyperparameters():
79+
tf_constructor = """TensorFlow(
80+
checkpoint_path='s3://foo/bar',
81+
training_steps=100,
82+
evaluation_steps=10,
83+
requirements_file='source/requirements.txt',
84+
)"""
85+
86+
node = _ast_call(tf_constructor)
87+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
88+
modifier.modify_node(node)
89+
90+
expected_hyperparameters = {
91+
"checkpoint_path": "s3://foo/bar",
92+
"evaluation_steps": 10,
93+
"sagemaker_requirements": "source/requirements.txt",
94+
"training_steps": 100,
95+
}
96+
97+
assert expected_hyperparameters == _hyperparameters_from_node(node)
98+
99+
100+
def test_modify_node_preserve_other_hyperparameters():
101+
tf_constructor = """sagemaker.tensorflow.TensorFlow(
102+
training_steps=100,
103+
evaluation_steps=10,
104+
requirements_file='source/requirements.txt',
105+
hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
106+
)"""
107+
108+
node = _ast_call(tf_constructor)
109+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
110+
modifier.modify_node(node)
111+
112+
expected_hyperparameters = {
113+
"optimizer": "sgd",
114+
"lr": 0.1,
115+
"checkpoint_path": "s3://foo/bar",
116+
"evaluation_steps": 10,
117+
"sagemaker_requirements": "source/requirements.txt",
118+
"training_steps": 100,
119+
}
120+
121+
assert expected_hyperparameters == _hyperparameters_from_node(node)
122+
123+
124+
def test_modify_node_prefer_param_over_hyperparameter():
125+
tf_constructor = """sagemaker.tensorflow.TensorFlow(
126+
training_steps=100,
127+
requirements_file='source/requirements.txt',
128+
hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
129+
)"""
130+
131+
node = _ast_call(tf_constructor)
132+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
133+
modifier.modify_node(node)
134+
135+
expected_hyperparameters = {
136+
"sagemaker_requirements": "source/requirements.txt",
137+
"training_steps": 100,
138+
}
139+
140+
assert expected_hyperparameters == _hyperparameters_from_node(node)
141+
142+
143+
def _hyperparameters_from_node(node):
144+
for kw in node.keywords:
145+
if kw.arg == "hyperparameters":
146+
keys = [k.s for k in kw.value.keys]
147+
values = [getattr(v, v._fields[0]) for v in kw.value.values]
148+
return dict(zip(keys, values))
149+
150+
151+
def _ast_call(code):
152+
return pasta.parse(code).body[0].value

0 commit comments

Comments
 (0)