Skip to content

Commit 0c5392f

Browse files
authored
change: make v2 migration script remove legacy run_tensorboard_locally parameter (#1537)
1 parent 6eeca73 commit 0c5392f

File tree

3 files changed

+99
-1
lines changed

3 files changed

+99
-1
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
FUNCTION_CALL_MODIFIERS = [
2121
modifiers.framework_version.FrameworkVersionEnforcer(),
2222
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
23+
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
2324
]
2425

2526

src/sagemaker/cli/compatibility/v2/modifiers/tf_legacy_mode.py

+35-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
1313
"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2."""
14-
# TODO: handle fit(run_tensorboard_locally=True)
1514
from __future__ import absolute_import
1615

1716
import ast
@@ -148,3 +147,38 @@ def _to_ast_keyword(self, hps):
148147
return ast.keyword(arg="hyperparameters", value=ast.Dict(keys=keys, values=values))
149148

150149
return None
150+
151+
152+
class TensorBoardParameterRemover(Modifier):
153+
"""A class for removing the ``run_tensorboard_locally`` parameter from ``fit()``."""
154+
155+
def node_should_be_modified(self, node):
156+
"""Checks if the ``ast.Call`` node invokes a function named "fit" and
157+
contains a keyword argument named "run_tensorboard_locally".
158+
159+
Args:
160+
node (ast.Call): a node that represents a function call. For more,
161+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
162+
163+
Returns:
164+
bool: If the ``ast.Call`` is invoking a function named "fit" with
165+
a parameter named "run_tensorboard_locally".
166+
"""
167+
is_fit_call = isinstance(node.func, ast.Attribute) and node.func.attr == "fit"
168+
if is_fit_call:
169+
for kw in node.keywords:
170+
if kw.arg == "run_tensorboard_locally":
171+
return True
172+
173+
return False
174+
175+
def modify_node(self, node):
176+
"""Removes ``run_tensorboard_locally`` from the ``ast.Call`` node's keywords.
177+
178+
Args:
179+
node (ast.Call): a node that represents ``fit`` being called with
180+
``run_tensorboard_locally`` set.
181+
"""
182+
for kw in node.keywords:
183+
if kw.arg == "run_tensorboard_locally":
184+
node.keywords.remove(kw)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import tf_legacy_mode
18+
19+
20+
def test_node_should_be_modified_fit_with_tensorboard():
21+
fit_calls = (
22+
"estimator.fit(run_tensorboard_locally=True)",
23+
"tensorflow.fit(run_tensorboard_locally=False)",
24+
)
25+
26+
modifier = tf_legacy_mode.TensorBoardParameterRemover()
27+
28+
for call in fit_calls:
29+
node = _ast_call(call)
30+
assert modifier.node_should_be_modified(node) is True
31+
32+
33+
def test_node_should_be_modified_fit_without_tensorboard():
34+
fit_calls = ("estimator.fit()", "tensorflow.fit()")
35+
36+
modifier = tf_legacy_mode.TensorBoardParameterRemover()
37+
38+
for call in fit_calls:
39+
node = _ast_call(call)
40+
assert modifier.node_should_be_modified(node) is False
41+
42+
43+
def test_node_should_be_modified_random_function_call():
44+
node = _ast_call("estimator.deploy(1, 'local')")
45+
modifier = tf_legacy_mode.TensorBoardParameterRemover()
46+
assert modifier.node_should_be_modified(node) is False
47+
48+
49+
def test_modify_node():
50+
fit_calls = (
51+
"estimator.fit(run_tensorboard_locally=True)",
52+
"estimator.fit(run_tensorboard_locally=False)",
53+
)
54+
modifier = tf_legacy_mode.TensorBoardParameterRemover()
55+
56+
for call in fit_calls:
57+
node = _ast_call(call)
58+
modifier.modify_node(node)
59+
assert "estimator.fit()" == pasta.dump(node)
60+
61+
62+
def _ast_call(code):
63+
return pasta.parse(code).body[0].value

0 commit comments

Comments
 (0)