Skip to content

Commit 160339d

Browse files
authored
Merge branch 'zwei' into add-bytes-deserializer
2 parents c78b440 + e10b29b commit 160339d

File tree

9 files changed

+244
-75
lines changed

9 files changed

+244
-75
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
modifiers.tfs.TensorFlowServingConstructorRenamer(),
2626
modifiers.predictors.PredictorConstructorRefactor(),
2727
modifiers.airflow.ModelConfigArgModifier(),
28+
modifiers.airflow.ModelConfigImageURIRenamer(),
2829
modifiers.renamed_params.DistributionParameterRenamer(),
2930
modifiers.renamed_params.S3SessionRenamer(),
3031
]

src/sagemaker/cli/compatibility/v2/modifiers/airflow.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import matching
18+
from sagemaker.cli.compatibility.v2.modifiers import matching, renamed_params
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

2121
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
@@ -61,3 +61,32 @@ def modify_node(self, node):
6161
"""
6262
instance_type = node.args.pop(0)
6363
node.keywords.append(ast.keyword(arg="instance_type", value=instance_type))
64+
65+
66+
class ModelConfigImageURIRenamer(renamed_params.ParamRenamer):
67+
"""A class to rename the ``image`` attribute to ``image_uri`` in Airflow model config functions.
68+
69+
This looks for the following formats:
70+
71+
- ``model_config``
72+
- ``airflow.model_config``
73+
- ``workflow.airflow.model_config``
74+
- ``sagemaker.workflow.airflow.model_config``
75+
76+
where ``model_config`` is either ``model_config`` or ``model_config_from_estimator``.
77+
"""
78+
79+
@property
80+
def calls_to_modify(self):
81+
"""A dictionary mapping Airflow model config functions to their respective namespaces."""
82+
return FUNCTIONS
83+
84+
@property
85+
def old_param_name(self):
86+
"""The previous name for the image URI argument."""
87+
return "image"
88+
89+
@property
90+
def new_param_name(self):
91+
"""The new name for the image URI argument."""
92+
return "image_uri"

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+10-18
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
import ast
1717

18-
from sagemaker.cli.compatibility.v2.modifiers import matching
18+
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
1919
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2020

2121
FRAMEWORK_ARG = "framework_version"
@@ -98,14 +98,14 @@ def modify_node(self, node):
9898
framework, is_model = _framework_from_node(node)
9999

100100
# if framework_version is not supplied, get default and append keyword
101-
framework_version = _arg_value(node, FRAMEWORK_ARG)
102-
if framework_version is None:
101+
if matching.has_arg(node, FRAMEWORK_ARG):
102+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
103+
else:
103104
framework_version = FRAMEWORK_DEFAULTS[framework]
104105
node.keywords.append(ast.keyword(arg=FRAMEWORK_ARG, value=ast.Str(s=framework_version)))
105106

106107
# if py_version is not supplied, get a conditional default, and if not None, append keyword
107-
py_version = _arg_value(node, PY_ARG)
108-
if py_version is None:
108+
if not matching.has_arg(node, PY_ARG):
109109
py_version = _py_version_defaults(framework, framework_version, is_model)
110110
if py_version:
111111
node.keywords.append(ast.keyword(arg=PY_ARG, value=ast.Str(s=py_version)))
@@ -175,28 +175,20 @@ def _version_args_needed(node, image_arg):
175175
Applies similar logic as ``validate_version_or_image_args``
176176
"""
177177
# if image_arg is present, no need to supply version arguments
178-
image_name = _arg_value(node, image_arg)
179-
if image_name:
178+
if matching.has_arg(node, image_arg):
180179
return False
181180

182181
# if framework_version is None, need args
183-
framework_version = _arg_value(node, FRAMEWORK_ARG)
184-
if framework_version is None:
182+
if matching.has_arg(node, FRAMEWORK_ARG):
183+
framework_version = parsing.arg_value(node, FRAMEWORK_ARG)
184+
else:
185185
return True
186186

187187
# check if we expect py_version and we don't get it -- framework and model dependent
188188
framework, is_model = _framework_from_node(node)
189189
expecting_py_version = _py_version_defaults(framework, framework_version, is_model)
190190
if expecting_py_version:
191-
py_version = _arg_value(node, PY_ARG)
191+
py_version = parsing.arg_value(node, PY_ARG)
192192
return py_version is None
193193

194194
return False
195-
196-
197-
def _arg_value(node, arg):
198-
"""Gets the value associated with the arg keyword, if present"""
199-
for kw in node.keywords:
200-
if kw.arg == arg and kw.value:
201-
return kw.value.s
202-
return None

src/sagemaker/cli/compatibility/v2/modifiers/matching.py

+19
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515

1616
import ast
1717

18+
from sagemaker.cli.compatibility.v2.modifiers import parsing
19+
1820

1921
def matches_any(node, name_to_namespaces_dict):
2022
"""Determines if the ``ast.Call`` node matches any of the provided names and namespaces.
@@ -101,3 +103,20 @@ def matches_namespace(node, namespace):
101103
name, value = names.pop(), value.value
102104

103105
return isinstance(value, ast.Name) and value.id == name
106+
107+
108+
def has_arg(node, arg):
109+
"""Checks if the call has the given argument.
110+
111+
Args:
112+
node (ast.Call): a node that represents a function call. For more,
113+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
114+
arg (str): the name of the argument.
115+
116+
Returns:
117+
bool: if the node has the given argument.
118+
"""
119+
try:
120+
return parsing.arg_value(node, arg) is not None
121+
except KeyError:
122+
return False
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for parsing AST nodes."""
14+
from __future__ import absolute_import
15+
16+
import pasta
17+
18+
19+
def arg_from_keywords(node, arg):
20+
"""Retrieves a keyword argument from the node's keywords.
21+
22+
Args:
23+
node (ast.Call): a node that represents a function call. For more,
24+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
25+
arg (str): the name of the argument.
26+
27+
Returns:
28+
ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``.
29+
"""
30+
for kw in node.keywords:
31+
if kw.arg == arg:
32+
return kw
33+
34+
return None
35+
36+
37+
def arg_value(node, arg):
38+
"""Retrieves a keyword argument's value from the node's keywords.
39+
40+
Args:
41+
node (ast.Call): a node that represents a function call. For more,
42+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
43+
arg (str): the name of the argument.
44+
45+
Returns:
46+
obj: the keyword argument's value.
47+
48+
Raises:
49+
KeyError: if the node's keywords do not contain the argument.
50+
"""
51+
keyword = arg_from_keywords(node, arg)
52+
if keyword is None:
53+
raise KeyError("arg '{}' not found in call: {}".format(arg, pasta.dump(node)))
54+
55+
return getattr(keyword.value, keyword.value._fields[0], None) if keyword.value else None

src/sagemaker/cli/compatibility/v2/modifiers/renamed_params.py

+5-25
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import ast
1919
from abc import abstractmethod
2020

21-
from sagemaker.cli.compatibility.v2.modifiers import matching
21+
from sagemaker.cli.compatibility.v2.modifiers import matching, parsing
2222
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2323

2424

@@ -54,40 +54,20 @@ def node_should_be_modified(self, node):
5454
bool: If the ``ast.Call`` matches the relevant function calls and
5555
contains the parameter to be renamed.
5656
"""
57-
return matching.matches_any(node, self.calls_to_modify) and self._has_param_to_rename(node)
58-
59-
def _has_param_to_rename(self, node):
60-
"""Checks if the call has the argument that needs to be renamed."""
61-
return _keyword_from_keywords(node, self.old_param_name) is not None
57+
return matching.matches_any(node, self.calls_to_modify) and matching.has_arg(
58+
node, self.old_param_name
59+
)
6260

6361
def modify_node(self, node):
6462
"""Modifies the ``ast.Call`` node to rename the attribute.
6563
6664
Args:
6765
node (ast.Call): a node that represents the relevant function call.
6866
"""
69-
keyword = _keyword_from_keywords(node, self.old_param_name)
67+
keyword = parsing.arg_from_keywords(node, self.old_param_name)
7068
keyword.arg = self.new_param_name
7169

7270

73-
def _keyword_from_keywords(node, param_name):
74-
"""Retrieves a keyword argument from the node's keywords.
75-
76-
Args:
77-
node (ast.Call): a node that represents a function call. For more,
78-
see https://docs.python.org/3/library/ast.html#abstract-grammar.
79-
param_name (str): the name of the argument.
80-
81-
Returns:
82-
ast.keyword: the keyword argument if it is present. Otherwise, this returns ``None``.
83-
"""
84-
for kw in node.keywords:
85-
if kw.arg == param_name:
86-
return kw
87-
88-
return None
89-
90-
9171
class DistributionParameterRenamer(ParamRenamer):
9272
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
9373
MXNet and TensorFlow estimators.

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_airflow.py

+59-31
Original file line numberDiff line numberDiff line change
@@ -17,52 +17,41 @@
1717
from sagemaker.cli.compatibility.v2.modifiers import airflow
1818
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call
1919

20-
21-
def test_node_should_be_modified_model_config_with_args():
22-
model_config_calls = (
23-
"model_config(instance_type, model)",
24-
"airflow.model_config(instance_type, model)",
25-
"workflow.airflow.model_config(instance_type, model)",
26-
"sagemaker.workflow.airflow.model_config(instance_type, model)",
27-
"model_config_from_estimator(instance_type, model)",
28-
"airflow.model_config_from_estimator(instance_type, model)",
29-
"workflow.airflow.model_config_from_estimator(instance_type, model)",
30-
"sagemaker.workflow.airflow.model_config_from_estimator(instance_type, model)",
31-
)
32-
20+
MODEL_CONFIG_CALL_TEMPLATES = (
21+
"model_config({})",
22+
"airflow.model_config({})",
23+
"workflow.airflow.model_config({})",
24+
"sagemaker.workflow.airflow.model_config({})",
25+
"model_config_from_estimator({})",
26+
"airflow.model_config_from_estimator({})",
27+
"workflow.airflow.model_config_from_estimator({})",
28+
"sagemaker.workflow.airflow.model_config_from_estimator({})",
29+
)
30+
31+
32+
def test_arg_order_node_should_be_modified_model_config_with_args():
3333
modifier = airflow.ModelConfigArgModifier()
3434

35-
for call in model_config_calls:
36-
node = ast_call(call)
35+
for template in MODEL_CONFIG_CALL_TEMPLATES:
36+
node = ast_call(template.format("instance_type, model"))
3737
assert modifier.node_should_be_modified(node) is True
3838

3939

40-
def test_node_should_be_modified_model_config_without_args():
41-
model_config_calls = (
42-
"model_config()",
43-
"airflow.model_config()",
44-
"workflow.airflow.model_config()",
45-
"sagemaker.workflow.airflow.model_config()",
46-
"model_config_from_estimator()",
47-
"airflow.model_config_from_estimator()",
48-
"workflow.airflow.model_config_from_estimator()",
49-
"sagemaker.workflow.airflow.model_config_from_estimator()",
50-
)
51-
40+
def test_arg_order_node_should_be_modified_model_config_without_args():
5241
modifier = airflow.ModelConfigArgModifier()
5342

54-
for call in model_config_calls:
55-
node = ast_call(call)
43+
for template in MODEL_CONFIG_CALL_TEMPLATES:
44+
node = ast_call(template.format(""))
5645
assert modifier.node_should_be_modified(node) is False
5746

5847

59-
def test_node_should_be_modified_random_function_call():
48+
def test_arg_order_node_should_be_modified_random_function_call():
6049
node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()")
6150
modifier = airflow.ModelConfigArgModifier()
6251
assert modifier.node_should_be_modified(node) is False
6352

6453

65-
def test_modify_node():
54+
def test_arg_order_modify_node():
6655
model_config_calls = (
6756
("model_config(instance_type, model)", "model_config(model, instance_type=instance_type)"),
6857
(
@@ -89,3 +78,42 @@ def test_modify_node():
8978
node = ast_call(call)
9079
modifier.modify_node(node)
9180
assert expected == pasta.dump(node)
81+
82+
83+
def test_image_arg_node_should_be_modified_model_config_with_arg():
84+
modifier = airflow.ModelConfigImageURIRenamer()
85+
86+
for template in MODEL_CONFIG_CALL_TEMPLATES:
87+
node = ast_call(template.format("image=my_image"))
88+
assert modifier.node_should_be_modified(node) is True
89+
90+
91+
def test_image_arg_node_should_be_modified_model_config_without_arg():
92+
modifier = airflow.ModelConfigImageURIRenamer()
93+
94+
for template in MODEL_CONFIG_CALL_TEMPLATES:
95+
node = ast_call(template.format(""))
96+
assert modifier.node_should_be_modified(node) is False
97+
98+
99+
def test_image_arg_node_should_be_modified_random_function_call():
100+
node = ast_call("sagemaker.workflow.airflow.prepare_framework_container_def()")
101+
modifier = airflow.ModelConfigImageURIRenamer()
102+
assert modifier.node_should_be_modified(node) is False
103+
104+
105+
def test_image_arg_modify_node():
106+
model_config_calls = (
107+
("model_config(image='image:latest')", "model_config(image_uri='image:latest')"),
108+
(
109+
"model_config_from_estimator(image=my_image)",
110+
"model_config_from_estimator(image_uri=my_image)",
111+
),
112+
)
113+
114+
modifier = airflow.ModelConfigImageURIRenamer()
115+
116+
for call, expected in model_config_calls:
117+
node = ast_call(call)
118+
modifier.modify_node(node)
119+
assert expected == pasta.dump(node)

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_matching.py

+5
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,8 @@ def test_matches_attr():
6666
def test_matches_namespace():
6767
assert matching.matches_namespace(ast_call("sagemaker.mxnet.MXNet()"), "sagemaker.mxnet")
6868
assert not matching.matches_namespace(ast_call("sagemaker.KMeans()"), "sagemaker.mxnet")
69+
70+
71+
def test_has_arg():
72+
assert matching.has_arg(ast_call("MXNet(framework_version=mxnet_version)"), "framework_version")
73+
assert not matching.has_arg(ast_call("MXNet()"), "framework_version")

0 commit comments

Comments
 (0)