Skip to content

Commit 1d409b3

Browse files
committed
change: make v2 migration script add version args when needed
1 parent 49bab2b commit 1d409b3

File tree

2 files changed

+277
-168
lines changed

2 files changed

+277
-168
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+148-64
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1919

20+
FRAMEWORK_ARG = "framework_version"
21+
PY_ARG = "py_version"
22+
2023
FRAMEWORK_DEFAULTS = {
2124
"Chainer": "4.1.0",
2225
"MXNet": "1.2.0",
@@ -25,10 +28,11 @@
2528
"TensorFlow": "1.11.0",
2629
}
2730

28-
FRAMEWORKS = list(FRAMEWORK_DEFAULTS.keys())
31+
FRAMEWORK_CLASSES = list(FRAMEWORK_DEFAULTS.keys())
32+
MODEL_CLASSES = ["{}Model".format(fw) for fw in FRAMEWORK_CLASSES]
33+
2934
# TODO: check for sagemaker.tensorflow.serving.Model
30-
FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS]
31-
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS]
35+
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
3236
FRAMEWORK_SUBMODULES = ("model", "estimator")
3337

3438

@@ -39,7 +43,8 @@ class FrameworkVersionEnforcer(Modifier):
3943

4044
def node_should_be_modified(self, node):
4145
"""Checks if the ast.Call node instantiates a framework estimator or model,
42-
but doesn't specify the ``framework_version`` parameter.
46+
but doesn't specify the ``framework_version`` and ``py_version`` parameter,
47+
as appropriate.
4348
4449
This looks for the following formats:
4550
@@ -56,49 +61,12 @@ def node_should_be_modified(self, node):
5661
bool: If the ``ast.Call`` is instantiating a framework class that
5762
should specify ``framework_version``, but doesn't.
5863
"""
59-
if self._is_framework_constructor(node):
60-
return not self._fw_version_in_keywords(node)
64+
if _is_named_constructor(node, FRAMEWORK_CLASSES):
65+
return _version_args_needed(node, "image_name")
6166

62-
return False
67+
if _is_named_constructor(node, MODEL_CLASSES):
68+
return _version_args_needed(node, "image")
6369

64-
def _is_framework_constructor(self, node):
65-
"""Checks if the ``ast.Call`` node represents a call of the form
66-
<Framework> or sagemaker.<framework>.<Framework>.
67-
"""
68-
# Check for <Framework> call
69-
if isinstance(node.func, ast.Name):
70-
return node.func.id in FRAMEWORK_CLASSES
71-
72-
# Check for something.that.ends.with.<framework>.<Framework> call
73-
if not (isinstance(node.func, ast.Attribute) and node.func.attr in FRAMEWORK_CLASSES):
74-
return False
75-
76-
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
77-
if (
78-
isinstance(node.func.value, ast.Attribute)
79-
and node.func.value.attr in FRAMEWORK_SUBMODULES
80-
):
81-
return self._is_in_framework_module(node.func.value)
82-
83-
# Check for sagemaker.<framework>.<Framework> call
84-
return self._is_in_framework_module(node.func)
85-
86-
def _is_in_framework_module(self, node):
87-
"""Checks if the node is an ``ast.Attribute`` that represents a
88-
``sagemaker.<framework>`` module.
89-
"""
90-
return (
91-
isinstance(node.value, ast.Attribute)
92-
and node.value.attr in FRAMEWORK_MODULES
93-
and isinstance(node.value.value, ast.Name)
94-
and node.value.value.id == "sagemaker"
95-
)
96-
97-
def _fw_version_in_keywords(self, node):
98-
"""Checks if the ``ast.Call`` node's keywords contain ``framework_version``."""
99-
for kw in node.keywords:
100-
if kw.arg == "framework_version" and kw.value:
101-
return True
10270
return False
10371

10472
def modify_node(self, node):
@@ -112,30 +80,146 @@ def modify_node(self, node):
11280
- SKLearn: "0.20.0"
11381
- TensorFlow: "1.11.0"
11482
83+
The ``py_version`` value is determined by the framework, framework_version, and if it is a
84+
model, whether the model accepts a py_version
85+
11586
Args:
11687
node (ast.Call): a node that represents the constructor of a framework class.
11788
"""
118-
framework = self._framework_name_from_node(node)
119-
node.keywords.append(
120-
ast.keyword(arg="framework_version", value=ast.Str(s=FRAMEWORK_DEFAULTS[framework]))
121-
)
89+
framework, is_model = _framework_from_node(node)
12290

123-
def _framework_name_from_node(self, node):
124-
"""Retrieves the framework name based on the function call.
91+
# if framework_version is not supplied, get default and append keyword
92+
framework_version = _arg_value(node, FRAMEWORK_ARG)
93+
if framework_version is None:
94+
framework_version = FRAMEWORK_DEFAULTS[framework]
95+
node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version)))
12596

126-
Args:
127-
node (ast.Call): a node that represents the constructor of a framework class.
128-
This can represent either <Framework> or sagemaker.<framework>.<Framework>.
97+
# if py_version is not supplied, get a conditional default, and if not None, append keyword
98+
py_version = _arg_value(node, PY_ARG)
99+
if py_version is None:
100+
py_version = _py_version_defaults(framework, framework_version, is_model)
101+
if py_version:
102+
node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))
129103

130-
Returns:
131-
str: the (capitalized) framework name.
132-
"""
133-
if isinstance(node.func, ast.Name):
134-
framework = node.func.id
135-
elif isinstance(node.func, ast.Attribute):
136-
framework = node.func.attr
137104

138-
if framework.endswith("Model"):
139-
framework = framework[: framework.find("Model")]
105+
def _py_version_defaults(framework, framework_version, is_model=False):
106+
"""Gets the py_version required for the framework_version and if it's a model
107+
108+
Args:
109+
framework (str): name of the framework
110+
framework_version (str): version of the framework
111+
is_model (bool): whether it is a constructor for a model or not
112+
113+
Returns:
114+
str: the default py version, as appropriate. None if no default py_version
115+
"""
116+
if framework in ("Chainer", "PyTorch"):
117+
return "py3"
118+
if framework == "SKLearn" and not is_model:
119+
return "py3"
120+
if framework == "MXNet":
121+
return "py2"
122+
if framework == "TensorFlow" and not is_model:
123+
return _tf_py_version_default(framework_version)
124+
return None
125+
126+
127+
def _tf_py_version_default(framework_version):
128+
"""Gets the py_version default based on framework_version for TensorFlow."""
129+
if not framework_version:
130+
return "py2"
131+
version = [int(s) for s in framework_version.split(".")]
132+
if version < [1, 12]:
133+
return "py2"
134+
if version < [2, 2]:
135+
return "py3"
136+
return "py37"
137+
138+
139+
def _framework_from_node(node):
140+
"""Retrieves the framework class name based on the function call, and if it was a model
141+
142+
Args:
143+
node (ast.Call): a node that represents the constructor of a framework class.
144+
This can represent either <Framework> or sagemaker.<framework>.<Framework>.
145+
146+
Returns:
147+
str, bool: the (capitalized) framework class name, and if it is a model class
148+
"""
149+
if isinstance(node.func, ast.Name):
150+
framework = node.func.id
151+
elif isinstance(node.func, ast.Attribute):
152+
framework = node.func.attr
153+
else:
154+
framework = ""
155+
156+
is_model = framework.endswith("Model")
157+
if is_model:
158+
framework = framework[: framework.find("Model")]
159+
160+
return framework, is_model
161+
162+
163+
def _is_named_constructor(node, names):
164+
"""Checks if the ``ast.Call`` node represents a call to particular named constructors.
165+
166+
Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
167+
where <Framework> belongs to the list of names passed in.
168+
"""
169+
# Check for call from particular names of constructors
170+
if isinstance(node.func, ast.Name):
171+
return node.func.id in names
172+
173+
# Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
174+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in names):
175+
return False
176+
177+
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
178+
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES:
179+
return _is_in_framework_module(node.func.value)
180+
181+
# Check for sagemaker.<framework>.<Framework> call
182+
return _is_in_framework_module(node.func)
183+
184+
185+
def _is_in_framework_module(node):
186+
"""Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
187+
return (
188+
isinstance(node.value, ast.Attribute)
189+
and node.value.attr in FRAMEWORK_MODULES
190+
and isinstance(node.value.value, ast.Name)
191+
and node.value.value.id == "sagemaker"
192+
)
193+
194+
195+
def _version_args_needed(node, image_arg):
196+
"""Determines if image_arg or version_arg was supplied
197+
198+
Applies similar logic as ``validate_version_or_image_args``
199+
"""
200+
# if image_arg is present, no need to supply version arguments
201+
image_name = _arg_value(node, image_arg)
202+
if image_name:
203+
return False
204+
205+
# if framework_version is None, need args
206+
framework_version = _arg_value(node, FRAMEWORK_ARG)
207+
if framework_version is None:
208+
return True
209+
210+
# check if we expect py_version and we don't get it -- framework and model dependent
211+
framework, is_model = _framework_from_node(node)
212+
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
213+
if expecting_py_version:
214+
py_version = _arg_value(node, PY_ARG)
215+
return py_version is None
216+
217+
return False
218+
140219

141-
return framework
220+
def _arg_value(node, arg):
221+
"""Gets the value associated with the arg keyword, if present"""
222+
for kw in node.keywords:
223+
if kw.arg == arg and kw.value:
224+
return kw.value.s
225+
return None

0 commit comments

Comments
 (0)