Skip to content

Commit e0dd8db

Browse files
committed
change: convert TF legacy mode parameters to hyperparameters in v2 migration script
1 parent 5b078f7 commit e0dd8db

File tree

4 files changed

+163
-4
lines changed

4 files changed

+163
-4
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,12 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import framework_version
18+
from sagemaker.cli.compatibility.v2 import modifiers
1919

20-
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
20+
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.framework_version.FrameworkVersionEnforcer(),
22+
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
23+
]
2124

2225

2326
class ASTTransformer(ast.NodeTransformer):
@@ -38,4 +41,6 @@ def visit_Call(self, node):
3841
"""
3942
for function_checker in FUNCTION_CALL_MODIFIERS:
4043
function_checker.check_and_modify_node(node)
44+
45+
ast.fix_missing_locations(node)
4146
return node

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,8 @@
1212
# language governing permissions and limitations under the License.
1313
"""Classes for modifying AST nodes"""
1414
from __future__ import absolute_import
15+
16+
from sagemaker.cli.compatibility.v2.modifiers import ( # noqa: F401 (imported but unused)
17+
framework_version,
18+
tf_legacy_mode,
19+
)

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def _is_framework_constructor(self, node):
6666
"""
6767
# Check for <Framework> call
6868
if isinstance(node.func, ast.Name):
69-
if node.func.id in FRAMEWORK_CLASSES:
70-
return True
69+
return node.func.id in FRAMEWORK_CLASSES
7170

7271
# Check for sagemaker.<framework>.<Framework> call
7372
ends_with_framework_constructor = (
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TensorFlow legacy mode code to be compatible with SageMaker Python SDK v2."""
14+
# TODO: handle fit(run_tensorboard_locally=True)
15+
from __future__ import absolute_import
16+
17+
import ast
18+
19+
import six
20+
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
24+
class TensorFlowLegacyModeConstructorUpgrader(Modifier):
25+
"""A class to turn legacy mode parameters into hyperparameters when
26+
instantiating a TensorFlow estimator.
27+
"""
28+
29+
LEGACY_MODE_PARAMETERS = (
30+
"checkpoint_path",
31+
"evaluation_steps",
32+
"requirements_file",
33+
"training_steps",
34+
)
35+
36+
def node_should_be_modified(self, node):
37+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.
38+
39+
This looks for the following formats:
40+
41+
- ``TensorFlow``
42+
- ``sagemaker.tensorflow.TensorFlow``
43+
44+
Legacy mode is enabled if (1) ``script_mode`` is ``False``, ``None``, or not specified,
45+
and (2) if ``py_version`` is ``py2`` or not specified.
46+
47+
Args:
48+
node (ast.Call): a node that represents a function call. For more,
49+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
50+
51+
Returns:
52+
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with legacy mode.
53+
"""
54+
return self._is_tf_constructor(node) and self._is_legacy_mode(node)
55+
56+
def _is_tf_constructor(self, node):
57+
"""Checks if the ``ast.Call`` node represents a call of the form
58+
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
59+
"""
60+
# Check for TensorFlow()
61+
if isinstance(node.func, ast.Name):
62+
return node.func.id == "TensorFlow"
63+
64+
# Check for sagemaker.tensorflow.TensorFlow()
65+
ends_with_tensorflow_constructor = (
66+
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
67+
)
68+
69+
is_in_tensorflow_module = (
70+
isinstance(node.func.value, ast.Attribute)
71+
and node.func.value.attr == "tensorflow"
72+
and isinstance(node.func.value.value, ast.Name)
73+
and node.func.value.value.id == "sagemaker"
74+
)
75+
76+
return ends_with_tensorflow_constructor and is_in_tensorflow_module
77+
78+
def _is_legacy_mode(self, node):
79+
"""Checks if the ``ast.Call`` node's keywords signal using legacy mode."""
80+
script_mode = False
81+
py_version = "py2"
82+
83+
for kw in node.keywords:
84+
if kw.arg == "script_mode":
85+
script_mode = bool(kw.value.value)
86+
if kw.arg == "py_version":
87+
py_version = kw.value.s
88+
89+
return not (py_version.startswith("py3") or script_mode)
90+
91+
def modify_node(self, node):
92+
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
93+
into hyperparameters and set ``script_mode=False``.
94+
95+
The parameters that are converted into hyperparameters:
96+
97+
- ``training_steps``
98+
- ``evaluation_steps``
99+
- ``checkpoint_path``
100+
- ``requirements_file``
101+
102+
Args:
103+
node (ast.Call): a node that represents a TensorFlow constructor.
104+
"""
105+
base_hps = {}
106+
additional_hps = {}
107+
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
108+
109+
for kw in node.keywords:
110+
if kw.arg == "script_mode":
111+
# remove here because is set to False later regardless of current value
112+
kw_to_remove.append(kw)
113+
if kw.arg == "hyperparameters" and kw.value:
114+
base_hps = dict(zip(kw.value.keys, kw.value.values))
115+
kw_to_remove.append(kw)
116+
if kw.arg in self.LEGACY_MODE_PARAMETERS and kw.value:
117+
hp_key = self._hyperparameter_key_for_param(kw.arg)
118+
additional_hps[hp_key] = kw.value
119+
kw_to_remove.append(kw)
120+
121+
self._remove_keywords(node, kw_to_remove)
122+
self._add_updated_hyperparameters(node, base_hps, additional_hps)
123+
124+
node.keywords.append(ast.keyword(arg="script_mode", value=ast.NameConstant(value=False)))
125+
126+
def _hyperparameter_key_for_param(self, arg):
127+
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
128+
name = "sagemaker_requirements" if arg == "requirements_file" else arg
129+
return ast.Str(s=name)
130+
131+
def _remove_keywords(self, node, keywords):
132+
"""Removes the keywords from the ``ast.Call`` node."""
133+
for kw in keywords:
134+
node.keywords.remove(kw)
135+
136+
def _add_updated_hyperparameters(self, node, base_hps, additional_hps):
137+
"""Combines and adds the hyperparameters to the ``ast.Call`` node's keywords."""
138+
base_hps.update(additional_hps)
139+
updated_hp_keyword = self._to_ast_keyword(base_hps)
140+
141+
if updated_hp_keyword:
142+
node.keywords.append(updated_hp_keyword)
143+
144+
def _to_ast_keyword(self, hps):
145+
"""Returns an ``ast.keyword`` for the ``hyperparameters`` kwarg if there are any."""
146+
if hps:
147+
keys, values = zip(*six.iteritems(hps))
148+
return ast.keyword(arg="hyperparameters", value=ast.Dict(keys=keys, values=values))
149+
150+
return None

0 commit comments

Comments
 (0)