Skip to content

Commit 242f81b

Browse files
authored
change: make v2 migration script remove script_mode param and set model_dir=False if needed (#1540)
1 parent 778a4ee commit 242f81b

13 files changed

+306
-39
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
modifiers.framework_version.FrameworkVersionEnforcer(),
2222
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2323
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
24+
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
2425
]
2526

2627

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from __future__ import absolute_import
1515

1616
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
deprecated_params,
1718
framework_version,
1819
tf_legacy_mode,
1920
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to remove deprecated parameters."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
21+
class TensorFlowScriptModeParameterRemover(Modifier):
22+
"""A class to remove ``script_mode`` from TensorFlow estimators (because it's the only mode)."""
23+
24+
def node_should_be_modified(self, node):
25+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with
26+
``script_mode`` set.
27+
28+
This looks for the following formats:
29+
30+
- ``TensorFlow``
31+
- ``sagemaker.tensorflow.TensorFlow``
32+
33+
Args:
34+
node (ast.Call): a node that represents a function call. For more,
35+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
36+
37+
Returns:
38+
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with ``script_mode``.
39+
"""
40+
return self._is_tf_constructor(node) and self._has_script_mode_param(node)
41+
42+
def _is_tf_constructor(self, node):
43+
"""Checks if the ``ast.Call`` node represents a call of the form
44+
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
45+
"""
46+
# Check for TensorFlow()
47+
if isinstance(node.func, ast.Name):
48+
return node.func.id == "TensorFlow"
49+
50+
# Check for sagemaker.tensorflow.TensorFlow()
51+
ends_with_tensorflow_constructor = (
52+
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
53+
)
54+
55+
is_in_tensorflow_module = (
56+
isinstance(node.func.value, ast.Attribute)
57+
and node.func.value.attr == "tensorflow"
58+
and isinstance(node.func.value.value, ast.Name)
59+
and node.func.value.value.id == "sagemaker"
60+
)
61+
62+
return ends_with_tensorflow_constructor and is_in_tensorflow_module
63+
64+
def _has_script_mode_param(self, node):
65+
"""Checks if the ``ast.Call`` node's keywords include ``script_mode``."""
66+
for kw in node.keywords:
67+
if kw.arg == "script_mode":
68+
return True
69+
70+
return False
71+
72+
def modify_node(self, node):
73+
"""Modifies the ``ast.Call`` node's keywords to remove ``script_mode``.
74+
75+
Args:
76+
node (ast.Call): a node that represents a TensorFlow constructor.
77+
"""
78+
for kw in node.keywords:
79+
if kw.arg == "script_mode":
80+
node.keywords.remove(kw)

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

+44-6
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515

1616
import ast
1717

18+
import boto3
1819
import six
1920

21+
from sagemaker.cli.compatibility.v2.modifiers import framework_version
2022
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
23+
from sagemaker import fw_utils
2124

2225

2326
class TensorFlowLegacyModeConstructorUpgrader(Modifier):
24-
"""A class to turn legacy mode parameters into hyperparameters when
25-
instantiating a TensorFlow estimator.
27+
"""A class to turn legacy mode parameters into hyperparameters, disable the ``model_dir``
28+
hyperparameter, and set the image URI when instantiating a TensorFlow estimator.
2629
"""
2730

2831
LEGACY_MODE_PARAMETERS = (
@@ -32,6 +35,18 @@ class TensorFlowLegacyModeConstructorUpgrader(Modifier):
3235
"training_steps",
3336
)
3437

38+
def __init__(self):
39+
"""Initializes a ``TensorFlowLegacyModeConstructorUpgrader``."""
40+
self._region = None
41+
42+
@property
43+
def region(self):
44+
"""Returns the AWS region for constructing an ECR image URI."""
45+
if self._region is None:
46+
self._region = boto3.Session().region_name
47+
48+
return self._region
49+
3550
def node_should_be_modified(self, node):
3651
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.
3752
@@ -89,7 +104,7 @@ def _is_legacy_mode(self, node):
89104

90105
def modify_node(self, node):
91106
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
92-
into hyperparameters and set ``script_mode=False``.
107+
into hyperparameters and sets ``model_dir=False``.
93108
94109
The parameters that are converted into hyperparameters:
95110
@@ -105,9 +120,11 @@ def modify_node(self, node):
105120
additional_hps = {}
106121
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
107122

123+
add_image_uri = True
124+
108125
for kw in node.keywords:
109-
if kw.arg == "script_mode":
110-
# remove here because is set to False later regardless of current value
126+
if kw.arg in ("script_mode", "model_dir"):
127+
# model_dir is removed so that it can be set to False later
111128
kw_to_remove.append(kw)
112129
if kw.arg == "hyperparameters" and kw.value:
113130
base_hps = dict(zip(kw.value.keys, kw.value.values))
@@ -116,11 +133,17 @@ def modify_node(self, node):
116133
hp_key = self._hyperparameter_key_for_param(kw.arg)
117134
additional_hps[hp_key] = kw.value
118135
kw_to_remove.append(kw)
136+
if kw.arg == "image_name":
137+
add_image_uri = False
119138

120139
self._remove_keywords(node, kw_to_remove)
121140
self._add_updated_hyperparameters(node, base_hps, additional_hps)
122141

123-
node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))
142+
if add_image_uri:
143+
image_uri = self._image_uri_from_args(node.keywords)
144+
node.keywords.append(ast.keyword(arg="image_name", value=ast.Str(s=image_uri)))
145+
146+
node.keywords.append(ast.keyword(arg="model_dir", value=ast.NameConstant(value=False)))
124147

125148
def _hyperparameter_key_for_param(self, arg):
126149
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
@@ -148,6 +171,21 @@ def _to_ast_keyword(self, hps):
148171

149172
return None
150173

174+
def _image_uri_from_args(self, keywords):
175+
"""Returns a legacy TensorFlow image URI based on the estimator arguments."""
176+
tf_version = framework_version.FRAMEWORK_DEFAULTS["TensorFlow"]
177+
instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
178+
179+
for kw in keywords:
180+
if kw.arg == "framework_version":
181+
tf_version = kw.value.s
182+
if kw.arg == "train_instance_type":
183+
instance_type = kw.value.s
184+
185+
return fw_utils.create_image_uri(
186+
self.region, "tensorflow", instance_type, tf_version, "py2"
187+
)
188+
151189

152190
class TensorBoardParameterRemover(Modifier):
153191
"""A class for removing the ``run_tensorboard_locally`` parameter from ``fit()``."""

tests/unit/sagemaker/cli/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
18+
def ast_call(code):
19+
return pasta.parse(code).body[0].value

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_framework_version.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pytest
1919

2020
from sagemaker.cli.compatibility.v2.modifiers import framework_version
21+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
2122

2223

2324
@pytest.fixture(autouse=True)
@@ -54,7 +55,7 @@ def test_node_should_be_modified_fw_constructor_no_fw_version():
5455
modifier = framework_version.FrameworkVersionEnforcer()
5556

5657
for constructor in fw_constructors:
57-
node = _ast_call(constructor)
58+
node = ast_call(constructor)
5859
assert modifier.node_should_be_modified(node) is True
5960

6061

@@ -85,12 +86,12 @@ def test_node_should_be_modified_fw_constructor_with_fw_version():
8586
modifier = framework_version.FrameworkVersionEnforcer()
8687

8788
for constructor in fw_constructors:
88-
node = _ast_call(constructor)
89+
node = ast_call(constructor)
8990
assert modifier.node_should_be_modified(node) is False
9091

9192

9293
def test_node_should_be_modified_random_function_call():
93-
node = _ast_call("sagemaker.session.Session()")
94+
node = ast_call("sagemaker.session.Session()")
9495
modifier = framework_version.FrameworkVersionEnforcer()
9596
assert modifier.node_should_be_modified(node) is False
9697

@@ -139,14 +140,10 @@ def test_modify_node_sklearn():
139140
_test_modify_node(classes, "0.20.0")
140141

141142

142-
def _ast_call(code):
143-
return pasta.parse(code).body[0].value
144-
145-
146143
def _test_modify_node(classes, default_version):
147144
modifier = framework_version.FrameworkVersionEnforcer()
148145
for cls in classes:
149-
node = _ast_call("{}()".format(cls))
146+
node = ast_call("{}()".format(cls))
150147
modifier.modify_node(node)
151148

152149
expected_result = "{}(framework_version='{}')".format(cls, default_version)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_tensorboard_remove_param.py

+5-8
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pasta
1616

1717
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1819

1920

2021
def test_node_should_be_modified_fit_with_tensorboard():
@@ -26,7 +27,7 @@ def test_node_should_be_modified_fit_with_tensorboard():
2627
modifier = tf_legacy_mode.TensorBoardParameterRemover()
2728

2829
for call in fit_calls:
29-
node = _ast_call(call)
30+
node = ast_call(call)
3031
assert modifier.node_should_be_modified(node) is True
3132

3233

@@ -36,12 +37,12 @@ def test_node_should_be_modified_fit_without_tensorboard():
3637
modifier = tf_legacy_mode.TensorBoardParameterRemover()
3738

3839
for call in fit_calls:
39-
node = _ast_call(call)
40+
node = ast_call(call)
4041
assert modifier.node_should_be_modified(node) is False
4142

4243

4344
def test_node_should_be_modified_random_function_call():
44-
node = _ast_call("estimator.deploy(1, 'local')")
45+
node = ast_call("estimator.deploy(1, 'local')")
4546
modifier = tf_legacy_mode.TensorBoardParameterRemover()
4647
assert modifier.node_should_be_modified(node) is False
4748

@@ -54,10 +55,6 @@ def test_modify_node():
5455
modifier = tf_legacy_mode.TensorBoardParameterRemover()
5556

5657
for call in fit_calls:
57-
node = _ast_call(call)
58+
node = ast_call(call)
5859
modifier.modify_node(node)
5960
assert "estimator.fit()" == pasta.dump(node)
60-
61-
62-
def _ast_call(code):
63-
return pasta.parse(code).body[0].value

0 commit comments

Comments
 (0)