16
16
17
17
import pasta
18
18
import pytest
19
+ from mock import patch
19
20
20
21
from sagemaker .cli .compatibility .v2 .modifiers import tf_legacy_mode
22
+ from tests .unit .sagemaker .cli .compatibility .v2 .modifiers .ast_converter import ast_call
23
+
24
+ IMAGE_URI = "sagemaker-tensorflow:latest"
25
+ REGION_NAME = "us-west-2"
21
26
22
27
23
28
@pytest .fixture (autouse = True )
@@ -42,7 +47,7 @@ def test_node_should_be_modified_tf_constructor_legacy_mode():
42
47
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
43
48
44
49
for constructor in tf_legacy_mode_constructors :
45
- node = _ast_call (constructor )
50
+ node = ast_call (constructor )
46
51
assert modifier .node_should_be_modified (node ) is True
47
52
48
53
@@ -61,28 +66,56 @@ def test_node_should_be_modified_tf_constructor_script_mode():
61
66
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
62
67
63
68
for constructor in tf_script_mode_constructors :
64
- node = _ast_call (constructor )
69
+ node = ast_call (constructor )
65
70
assert modifier .node_should_be_modified (node ) is False
66
71
67
72
68
73
def test_node_should_be_modified_random_function_call ():
69
- node = _ast_call ("MXNet(py_version='py3')" )
74
+ node = ast_call ("MXNet(py_version='py3')" )
70
75
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
71
76
assert modifier .node_should_be_modified (node ) is False
72
77
73
78
74
- def test_modify_node_set_script_mode_false ():
79
+ @patch ("boto3.Session" )
80
+ @patch ("sagemaker.fw_utils.create_image_uri" , return_value = IMAGE_URI )
81
+ def test_modify_node_set_model_dir_and_image_name (create_image_uri , boto_session ):
82
+ boto_session .return_value .region_name = REGION_NAME
83
+
75
84
tf_constructors = (
76
85
"TensorFlow()" ,
77
86
"TensorFlow(script_mode=False)" ,
78
- "TensorFlow(script_mode=None )" ,
87
+ "TensorFlow(model_dir='s3//bucket/model' )" ,
79
88
)
80
89
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
81
90
82
91
for constructor in tf_constructors :
83
- node = _ast_call (constructor )
92
+ node = ast_call (constructor )
84
93
modifier .modify_node (node )
85
- assert "TensorFlow(script_mode=False)" == pasta .dump (node )
94
+
95
+ assert "TensorFlow(image_name='{}', model_dir=False)" .format (IMAGE_URI ) == pasta .dump (node )
96
+ create_image_uri .assert_called_with (
97
+ REGION_NAME , "tensorflow" , "ml.m4.xlarge" , "1.11.0" , "py2"
98
+ )
99
+
100
+
101
+ @patch ("boto3.Session" )
102
+ @patch ("sagemaker.fw_utils.create_image_uri" , return_value = IMAGE_URI )
103
+ def test_modify_node_set_image_name_from_args (create_image_uri , boto_session ):
104
+ boto_session .return_value .region_name = REGION_NAME
105
+
106
+ tf_constructor = "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')"
107
+
108
+ node = ast_call (tf_constructor )
109
+ modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
110
+ modifier .modify_node (node )
111
+
112
+ create_image_uri .assert_called_with (REGION_NAME , "tensorflow" , "ml.p2.xlarge" , "1.4.0" , "py2" )
113
+
114
+ expected_string = (
115
+ "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
116
+ "image_name='{}', model_dir=False)" .format (IMAGE_URI )
117
+ )
118
+ assert expected_string == pasta .dump (node )
86
119
87
120
88
121
def test_modify_node_set_hyperparameters ():
@@ -93,7 +126,7 @@ def test_modify_node_set_hyperparameters():
93
126
requirements_file='source/requirements.txt',
94
127
)"""
95
128
96
- node = _ast_call (tf_constructor )
129
+ node = ast_call (tf_constructor )
97
130
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
98
131
modifier .modify_node (node )
99
132
@@ -115,7 +148,7 @@ def test_modify_node_preserve_other_hyperparameters():
115
148
hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
116
149
)"""
117
150
118
- node = _ast_call (tf_constructor )
151
+ node = ast_call (tf_constructor )
119
152
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
120
153
modifier .modify_node (node )
121
154
@@ -138,7 +171,7 @@ def test_modify_node_prefer_param_over_hyperparameter():
138
171
hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
139
172
)"""
140
173
141
- node = _ast_call (tf_constructor )
174
+ node = ast_call (tf_constructor )
142
175
modifier = tf_legacy_mode .TensorFlowLegacyModeConstructorUpgrader ()
143
176
modifier .modify_node (node )
144
177
@@ -156,7 +189,3 @@ def _hyperparameters_from_node(node):
156
189
keys = [k .s for k in kw .value .keys ]
157
190
values = [getattr (v , v ._fields [0 ]) for v in kw .value .values ]
158
191
return dict (zip (keys , values ))
159
-
160
-
161
- def _ast_call (code ):
162
- return pasta .parse (code ).body [0 ].value
0 commit comments