Skip to content

Commit b381033

Browse files
committed
change: make v2 migration script remove script_mode param and set model_dir=False if needed
1 parent d0eb4a2 commit b381033

File tree

6 files changed

+211
-36
lines changed

6 files changed

+211
-36
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to remove deprecated parameters."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
21+
class TensorFlowScriptModeParamRemover(Modifier):
22+
"""A class to remove ``script_mode`` from TensorFlow estimators (because it's the only mode)."""
23+
24+
def node_should_be_modified(self, node):
25+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with
26+
``script_mode`` set.
27+
28+
This looks for the following formats:
29+
30+
- ``TensorFlow``
31+
- ``sagemaker.tensorflow.TensorFlow``
32+
33+
Args:
34+
node (ast.Call): a node that represents a function call. For more,
35+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
36+
37+
Returns:
38+
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with ``script_mode``.
39+
"""
40+
return self._is_tf_constructor(node) and self._has_script_mode_param(node)
41+
42+
def _is_tf_constructor(self, node):
43+
"""Checks if the ``ast.Call`` node represents a call of the form
44+
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
45+
"""
46+
# Check for TensorFlow()
47+
if isinstance(node.func, ast.Name):
48+
return node.func.id == "TensorFlow"
49+
50+
# Check for sagemaker.tensorflow.TensorFlow()
51+
ends_with_tensorflow_constructor = (
52+
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
53+
)
54+
55+
is_in_tensorflow_module = (
56+
isinstance(node.func.value, ast.Attribute)
57+
and node.func.value.attr == "tensorflow"
58+
and isinstance(node.func.value.value, ast.Name)
59+
and node.func.value.value.id == "sagemaker"
60+
)
61+
62+
return ends_with_tensorflow_constructor and is_in_tensorflow_module
63+
64+
def _has_script_mode_param(self, node):
65+
"""Checks if the ``ast.Call`` node's keywords include ``script_mode``."""
66+
for kw in node.keywords:
67+
if kw.arg == "script_mode":
68+
return True
69+
70+
return False
71+
72+
def modify_node(self, node):
73+
"""Modifies the ``ast.Call`` node's keywords to remove ``script_mode``.
74+
75+
Args:
76+
node (ast.Call): a node that represents a TensorFlow constructor.
77+
"""
78+
for kw in node.keywords:
79+
if kw.arg == "script_mode":
80+
node.keywords.remove(kw)

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

+31-6
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515

1616
import ast
1717

18+
import boto3
1819
import six
1920

21+
from sagemaker.cli.compatibility.v2.modifiers import framework_version
2022
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
23+
from sagemaker import fw_utils
2124

2225

2326
class TensorFlowLegacyModeConstructorUpgrader(Modifier):
24-
"""A class to turn legacy mode parameters into hyperparameters when
25-
instantiating a TensorFlow estimator.
27+
"""A class to turn legacy mode parameters into hyperparameters, disable the ``model_dir``
28+
hyperparameter, and set the image URI when instantiating a TensorFlow estimator.
2629
"""
2730

2831
LEGACY_MODE_PARAMETERS = (
@@ -89,7 +92,7 @@ def _is_legacy_mode(self, node):
8992

9093
def modify_node(self, node):
9194
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
92-
into hyperparameters and set ``script_mode=False``.
95+
into hyperparameters and sets ``model_dir=False``.
9396
9497
The parameters that are converted into hyperparameters:
9598
@@ -105,9 +108,11 @@ def modify_node(self, node):
105108
additional_hps = {}
106109
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
107110

111+
add_image_uri = True
112+
108113
for kw in node.keywords:
109-
if kw.arg == "script_mode":
110-
# remove here because is set to False later regardless of current value
114+
if kw.arg in ("script_mode", "model_dir"):
115+
# model_dir is removed so that it can be set to False later
111116
kw_to_remove.append(kw)
112117
if kw.arg == "hyperparameters" and kw.value:
113118
base_hps = dict(zip(kw.value.keys, kw.value.values))
@@ -116,11 +121,17 @@ def modify_node(self, node):
116121
hp_key = self._hyperparameter_key_for_param(kw.arg)
117122
additional_hps[hp_key] = kw.value
118123
kw_to_remove.append(kw)
124+
if kw.arg == "image_name":
125+
add_image_uri = False
119126

120127
self._remove_keywords(node, kw_to_remove)
121128
self._add_updated_hyperparameters(node, base_hps, additional_hps)
122129

123-
node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))
130+
if add_image_uri:
131+
image_uri = self._image_uri_from_args(node.keywords)
132+
node.keywords.append(ast.keyword(arg="image_name", value=ast.Str(s=image_uri)))
133+
134+
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
124135

125136
def _hyperparameter_key_for_param(self, arg):
126137
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
@@ -148,6 +159,20 @@ def _to_ast_keyword(self, hps):
148159

149160
return None
150161

162+
def _image_uri_from_args(self, keywords):
163+
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
164+
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
165+
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
166+
167+
for kw in keywords:
168+
if kw.arg == "framework_version":
169+
tf_version = kw.value.s
170+
if kw.arg == "train_instance_type":
171+
instance_type = kw.value.s
172+
173+
region = boto3.Session().region_name
174+
return fw_utils.create_image_uri(region, "tensorflow", instance_type, tf_version, "py2")
175+
151176

152177
class TensorBoardParameterRemover(Modifier):
153178
"""A class for removing the ``run_tensorboard_locally`` parameter from ``fit()``."""

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker.cli.compatibility.v2.modifiers import framework_version
21+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
2122

2223

2324
@pytest.fixture(autouse=True)
@@ -54,7 +55,7 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
5455
modifier = framework_version.FrameworkVersionEnforcer()
5556

5657
for constructor in fw_constructors:
57-
node = _ast_call(constructor)
58+
node = ast_call(constructor)
5859
assert modifier.node_should_be_modified(node) is True
5960

6061

@@ -85,12 +86,12 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
8586
modifier = framework_version.FrameworkVersionEnforcer()
8687

8788
for constructor in fw_constructors:
88-
node = _ast_call(constructor)
89+
node = ast_call(constructor)
8990
assert modifier.node_should_be_modified(node) is False
9091

9192

9293
def test_node_should_be_modified_random_function_call():
93-
node = _ast_call("sagemaker.session.Session()")
94+
node = ast_call("sagemaker.session.Session()")
9495
modifier = framework_version.FrameworkVersionEnforcer()
9596
assert modifier.node_should_be_modified(node) is False
9697

@@ -139,14 +140,10 @@ def test_modify_node_sklearn():
139140
_test_modify_node(classes, "0.20.0")
140141

141142

142-
def _ast_call(code):
143-
return pasta.parse(code).body[0].value
144-
145-
146143
def _test_modify_node(classes, default_version):
147144
modifier = framework_version.FrameworkVersionEnforcer()
148145
for cls in classes:
149-
node = _ast_call("{}()".format(cls))
146+
node = ast_call("{}()".format(cls))
150147
modifier.modify_node(node)
151148

152149
expected_result = "{}(framework_version='{}')".format(cls, default_version)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tensorboard_remove_param.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pasta
1616

1717
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1819

1920

2021
def test_node_should_be_modified_fit_with_tensorboard():
@@ -26,7 +27,7 @@ def test_node_should_be_modified_fit_with_tensorboard():
2627
modifier = tf_legacy_mode.TensorBoardParameterRemover()
2728

2829
for call in fit_calls:
29-
node = _ast_call(call)
30+
node = ast_call(call)
3031
assert modifier.node_should_be_modified(node) is True
3132

3233

@@ -36,12 +37,12 @@ def test_node_should_be_modified_fit_without_tensorboard():
3637
modifier = tf_legacy_mode.TensorBoardParameterRemover()
3738

3839
for call in fit_calls:
39-
node = _ast_call(call)
40+
node = ast_call(call)
4041
assert modifier.node_should_be_modified(node) is False
4142

4243

4344
def test_node_should_be_modified_random_function_call():
44-
node = _ast_call("estimator.deploy(1, 'local')")
45+
node = ast_call("estimator.deploy(1, 'local')")
4546
modifier = tf_legacy_mode.TensorBoardParameterRemover()
4647
assert modifier.node_should_be_modified(node) is False
4748

@@ -54,10 +55,6 @@ def test_modify_node():
5455
modifier = tf_legacy_mode.TensorBoardParameterRemover()
5556

5657
for call in fit_calls:
57-
node = _ast_call(call)
58+
node = ast_call(call)
5859
modifier.modify_node(node)
5960
assert "estimator.fit()" == pasta.dump(node)
60-
61-
62-
def _ast_call(code):
63-
return pasta.parse(code).body[0].value

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tf_legacy_mode.py

+43-14
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,13 @@
1616

1717
import pasta
1818
import pytest
19+
from mock import patch
1920

2021
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
22+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
23+
24+
IMAGE_URI = "sagemaker-tensorflow:latest"
25+
REGION_NAME = "us-west-2"
2126

2227

2328
@pytest.fixture(autouse=True)
@@ -42,7 +47,7 @@ def test_node_should_be_modified_tf_constructor_legacy_mode():
4247
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
4348

4449
for constructor in tf_legacy_mode_constructors:
45-
node = _ast_call(constructor)
50+
node = ast_call(constructor)
4651
assert modifier.node_should_be_modified(node) is True
4752

4853

@@ -61,28 +66,56 @@ def test_node_should_be_modified_tf_constructor_script_mode():
6166
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
6267

6368
for constructor in tf_script_mode_constructors:
64-
node = _ast_call(constructor)
69+
node = ast_call(constructor)
6570
assert modifier.node_should_be_modified(node) is False
6671

6772

6873
def test_node_should_be_modified_random_function_call():
69-
node = _ast_call("MXNet(py_version='py3')")
74+
node = ast_call("MXNet(py_version='py3')")
7075
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
7176
assert modifier.node_should_be_modified(node) is False
7277

7378

74-
def test_modify_node_set_script_mode_false():
79+
@patch("boto3.Session")
80+
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
81+
def test_modify_node_set_model_dir_and_image_name(create_image_uri, boto_session):
82+
boto_session.return_value.region_name = REGION_NAME
83+
7584
tf_constructors = (
7685
"TensorFlow()",
7786
"TensorFlow(script_mode=False)",
78-
"TensorFlow(script_mode=None)",
87+
"TensorFlow(model_dir='s3//bucket/model')",
7988
)
8089
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
8190

8291
for constructor in tf_constructors:
83-
node = _ast_call(constructor)
92+
node = ast_call(constructor)
8493
modifier.modify_node(node)
85-
assert "TensorFlow(script_mode=False)" == pasta.dump(node)
94+
95+
assert "TensorFlow(image_name='{}', model_dir=False)".format(IMAGE_URI) == pasta.dump(node)
96+
create_image_uri.assert_called_with(
97+
REGION_NAME, "tensorflow", "ml.m4.xlarge", "1.11.0", "py2"
98+
)
99+
100+
101+
@patch("boto3.Session")
102+
@patch("sagemaker.fw_utils.create_image_uri", return_value=IMAGE_URI)
103+
def test_modify_node_set_image_name_from_args(create_image_uri, boto_session):
104+
boto_session.return_value.region_name = REGION_NAME
105+
106+
tf_constructor = "TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0')"
107+
108+
node = ast_call(tf_constructor)
109+
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
110+
modifier.modify_node(node)
111+
112+
create_image_uri.assert_called_with(REGION_NAME, "tensorflow", "ml.p2.xlarge", "1.4.0", "py2")
113+
114+
expected_string = (
115+
"TensorFlow(train_instance_type='ml.p2.xlarge', framework_version='1.4.0', "
116+
"image_name='{}', model_dir=False)".format(IMAGE_URI)
117+
)
118+
assert expected_string == pasta.dump(node)
86119

87120

88121
def test_modify_node_set_hyperparameters():
@@ -93,7 +126,7 @@ def test_modify_node_set_hyperparameters():
93126
requirements_file='source/requirements.txt',
94127
)"""
95128

96-
node = _ast_call(tf_constructor)
129+
node = ast_call(tf_constructor)
97130
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
98131
modifier.modify_node(node)
99132

@@ -115,7 +148,7 @@ def test_modify_node_preserve_other_hyperparameters():
115148
hyperparameters={'optimizer': 'sgd', 'lr': 0.1, 'checkpoint_path': 's3://foo/bar'},
116149
)"""
117150

118-
node = _ast_call(tf_constructor)
151+
node = ast_call(tf_constructor)
119152
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
120153
modifier.modify_node(node)
121154

@@ -138,7 +171,7 @@ def test_modify_node_prefer_param_over_hyperparameter():
138171
hyperparameters={'training_steps': 10, 'sagemaker_requirements': 'foo.txt'},
139172
)"""
140173

141-
node = _ast_call(tf_constructor)
174+
node = ast_call(tf_constructor)
142175
modifier = tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader()
143176
modifier.modify_node(node)
144177

@@ -156,7 +189,3 @@ def _hyperparameters_from_node(node):
156189
keys = [k.s for k in kw.value.keys]
157190
values = [getattr(v, v._fields[0]) for v in kw.value.values]
158191
return dict(zip(keys, values))
159-
160-
161-
def _ast_call(code):
162-
return pasta.parse(code).body[0].value

0 commit comments

Comments
 (0)