Skip to content

Commit 7f1150e

Browse files
authored
Merge branch 'zwei' into update-endpoint-method
2 parents b7d6a59 + ed2e428 commit 7f1150e

File tree

9 files changed

+218
-175
lines changed

9 files changed

+218
-175
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,12 @@
1818
from sagemaker.cli.compatibility.v2 import modifiers
1919

2020
FUNCTION_CALL_MODIFIERS = [
21-
modifiers.predictors.PredictorConstructorRefactor(),
2221
modifiers.framework_version.FrameworkVersionEnforcer(),
2322
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2423
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
2524
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
2625
modifiers.tfs.TensorFlowServingConstructorRenamer(),
26+
modifiers.predictors.PredictorConstructorRefactor(),
2727
modifiers.airflow.ModelConfigArgModifier(),
2828
]
2929

src/sagemaker/cli/compatibility/v2/modifiers/airflow.py

+6-23
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@
1515

1616
import ast
1717

18+
from sagemaker.cli.compatibility.v2.modifiers import matching
1819
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1920

21+
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
22+
NAMESPACES = ("sagemaker.workflow.airflow", "workflow.airflow", "airflow")
23+
FUNCTIONS = {name: NAMESPACES for name in FUNCTION_NAMES}
24+
2025

2126
class ModelConfigArgModifier(Modifier):
2227
"""A class to handle argument changes for Airflow model config functions."""
2328

24-
FUNCTION_NAMES = ("model_config", "model_config_from_estimator")
25-
2629
def node_should_be_modified(self, node):
2730
"""Checks if the ``ast.Call`` node creates an Airflow model config and
2831
contains positional arguments.
@@ -44,27 +47,7 @@ def node_should_be_modified(self, node):
4447
bool: If the ``ast.Call`` is either a ``model_config`` call or
4548
a ``model_config_from_estimator`` call and has positional arguments.
4649
"""
47-
return self._is_model_config_call(node) and len(node.args) > 0
48-
49-
def _is_model_config_call(self, node):
50-
"""Checks if the node is a ``model_config`` or ``model_config_from_estimator`` call."""
51-
if isinstance(node.func, ast.Name):
52-
return node.func.id in self.FUNCTION_NAMES
53-
54-
if not (isinstance(node.func, ast.Attribute) and node.func.attr in self.FUNCTION_NAMES):
55-
return False
56-
57-
return self._is_in_module(node.func, "sagemaker.workflow.airflow".split("."))
58-
59-
def _is_in_module(self, node, module):
60-
"""Checks if the node is in the module, including partial matches to the module path."""
61-
if isinstance(node.value, ast.Name):
62-
return node.value.id == module[-1]
63-
64-
if isinstance(node.value, ast.Attribute) and node.value.attr == module[-1]:
65-
return self._is_in_module(node.value, module[:-1])
66-
67-
return False
50+
return matching.matches_any(node, FUNCTIONS) and len(node.args) > 0
6851

6952
def modify_node(self, node):
7053
"""Modifies the ``ast.Call`` node's arguments.

src/sagemaker/cli/compatibility/v2/modifiers/deprecated_params.py

+5-25
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@
1313
"""Classes to remove deprecated parameters."""
1414
from __future__ import absolute_import
1515

16-
import ast
17-
16+
from sagemaker.cli.compatibility.v2.modifiers import matching
1817
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1918

19+
TF_NAMESPACES = ("sagemaker.tensorflow", "sagemaker.tensorflow.estimator")
20+
2021

2122
class TensorFlowScriptModeParameterRemover(Modifier):
2223
"""A class to remove ``script_mode`` from TensorFlow estimators (because it's the only mode)."""
@@ -37,29 +38,8 @@ def node_should_be_modified(self, node):
3738
Returns:
3839
bool: If the ``ast.Call`` is instantiating a TensorFlow estimator with ``script_mode``.
3940
"""
40-
return self._is_tf_constructor(node) and self._has_script_mode_param(node)
41-
42-
def _is_tf_constructor(self, node):
43-
"""Checks if the ``ast.Call`` node represents a call of the form
44-
``TensorFlow`` or ``sagemaker.tensorflow.TensorFlow``.
45-
"""
46-
# Check for TensorFlow()
47-
if isinstance(node.func, ast.Name):
48-
return node.func.id == "TensorFlow"
49-
50-
# Check for sagemaker.tensorflow.TensorFlow()
51-
ends_with_tensorflow_constructor = (
52-
isinstance(node.func, ast.Attribute) and node.func.attr == "TensorFlow"
53-
)
54-
55-
is_in_tensorflow_module = (
56-
isinstance(node.func.value, ast.Attribute)
57-
and node.func.value.attr == "tensorflow"
58-
and isinstance(node.func.value.value, ast.Name)
59-
and node.func.value.value.id == "sagemaker"
60-
)
61-
62-
return ends_with_tensorflow_constructor and is_in_tensorflow_module
41+
is_tf_constructor = matching.matches_name_or_namespaces(node, "TensorFlow", TF_NAMESPACES)
42+
return is_tf_constructor and self._has_script_mode_param(node)
6343

6444
def _has_script_mode_param(self, node):
6545
"""Checks if the ``ast.Call`` node's keywords include ``script_mode``."""

src/sagemaker/cli/compatibility/v2/modifiers/framework_version.py

+14-37
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import ast
1717

18+
from sagemaker.cli.compatibility.v2.modifiers import matching
1819
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
1920

2021
FRAMEWORK_ARG = "framework_version"
@@ -29,11 +30,19 @@
2930
}
3031

3132
FRAMEWORK_CLASSES = list(FRAMEWORK_DEFAULTS.keys())
32-
MODEL_CLASSES = ["{}Model".format(fw) for fw in FRAMEWORK_CLASSES]
3333

34+
ESTIMATORS = {
35+
fw: ("sagemaker.{}".format(fw.lower()), "sagemaker.{}.estimator".format(fw.lower()))
36+
for fw in FRAMEWORK_CLASSES
37+
}
3438
# TODO: check for sagemaker.tensorflow.serving.Model
35-
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORK_CLASSES]
36-
FRAMEWORK_SUBMODULES = ("model", "estimator")
39+
MODELS = {
40+
"{}Model".format(fw): (
41+
"sagemaker.{}".format(fw.lower()),
42+
"sagemaker.{}.model".format(fw.lower()),
43+
)
44+
for fw in FRAMEWORK_CLASSES
45+
}
3746

3847

3948
class FrameworkVersionEnforcer(Modifier):
@@ -61,10 +70,10 @@ def node_should_be_modified(self, node):
6170
bool: If the ``ast.Call`` is instantiating a framework class that
6271
should specify ``framework_version``, but doesn't.
6372
"""
64-
if _is_named_constructor(node, FRAMEWORK_CLASSES):
73+
if matching.matches_any(node, ESTIMATORS):
6574
return _version_args_needed(node, "image_name")
6675

67-
if _is_named_constructor(node, MODEL_CLASSES):
76+
if matching.matches_any(node, MODELS):
6877
return _version_args_needed(node, "image")
6978

7079
return False
@@ -160,38 +169,6 @@ def _framework_from_node(node):
160169
return framework, is_model
161170

162171

163-
def _is_named_constructor(node, names):
164-
"""Checks if the ``ast.Call`` node represents a call to particular named constructors.
165-
166-
Forms that qualify are either <Framework> or sagemaker.<framework>.<Framework>
167-
where <Framework> belongs to the list of names passed in.
168-
"""
169-
# Check for call from particular names of constructors
170-
if isinstance(node.func, ast.Name):
171-
return node.func.id in names
172-
173-
# Check for something.that.ends.with.<framework>.<Framework> call for Framework in names
174-
if not (isinstance(node.func, ast.Attribute) and node.func.attr in names):
175-
return False
176-
177-
# Check for sagemaker.<frameworks>.<estimator/model>.<Framework> call
178-
if isinstance(node.func.value, ast.Attribute) and node.func.value.attr in FRAMEWORK_SUBMODULES:
179-
return _is_in_framework_module(node.func.value)
180-
181-
# Check for sagemaker.<framework>.<Framework> call
182-
return _is_in_framework_module(node.func)
183-
184-
185-
def _is_in_framework_module(node):
186-
"""Checks if node is an ``ast.Attribute`` representing a ``sagemaker.<framework>`` module."""
187-
return (
188-
isinstance(node.value, ast.Attribute)
189-
and node.value.attr in FRAMEWORK_MODULES
190-
and isinstance(node.value.value, ast.Name)
191-
and node.value.value.id == "sagemaker"
192-
)
193-
194-
195172
def _version_args_needed(node, image_arg):
196173
"""Determines if image_arg or version_arg was supplied
197174
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Functions for checking AST nodes for matches."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
19+
def matches_any(node, name_to_namespaces_dict):
20+
"""Determines if the ``ast.Call`` node matches any of the provided names and namespaces.
21+
22+
Args:
23+
node (ast.Call): a node that represents a function call. For more,
24+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
25+
name_to_namespaces_dict (dict[str, tuple]): a mapping of names to appropriate namespaces.
26+
27+
Returns:
28+
bool: if the node matches any of the names and namespaces.
29+
"""
30+
return any(
31+
matches_name_or_namespaces(node, name, namespaces)
32+
for name, namespaces in name_to_namespaces_dict.items()
33+
)
34+
35+
36+
def matches_name_or_namespaces(node, name, namespaces):
37+
"""Determines if the ``ast.Call`` node matches the function name in the right namespace.
38+
39+
Args:
40+
node (ast.Call): a node that represents a function call. For more,
41+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
42+
name (str): the function name.
43+
namespaces (tuple): the possible namespaces to match to.
44+
45+
Returns:
46+
bool: if the node matches the name and any of the namespaces.
47+
"""
48+
if matches_name(node, name):
49+
return True
50+
51+
if not matches_attr(node, name):
52+
return False
53+
54+
return any(matches_namespace(node, namespace) for namespace in namespaces)
55+
56+
57+
def matches_name(node, name):
58+
"""Determines if the ``ast.Call`` node points to an ``ast.Name`` node with a matching name.
59+
60+
Args:
61+
node (ast.Call): a node that represents a function call. For more,
62+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
63+
name (str): the function name.
64+
65+
Returns:
66+
bool: if ``node.func`` is an ``ast.Name`` node with a matching name.
67+
"""
68+
return isinstance(node.func, ast.Name) and node.func.id == name
69+
70+
71+
def matches_attr(node, name):
72+
"""Determines if the ``ast.Call`` node points to an ``ast.Attribute`` node with a matching name.
73+
74+
Args:
75+
node (ast.Call): a node that represents a function call. For more,
76+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
77+
name (str): the function name.
78+
79+
Returns:
80+
bool: if ``node.func`` is an ``ast.Attribute`` node with a matching name.
81+
"""
82+
return isinstance(node.func, ast.Attribute) and node.func.attr == name
83+
84+
85+
def matches_namespace(node, namespace):
86+
"""Determines if the ``ast.Call`` node corresponds to a matching namespace.
87+
88+
Args:
89+
node (ast.Call): a node that represents a function call. For more,
90+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
91+
namespace (str): the namespace.
92+
93+
Returns:
94+
bool: if the node's namespaces matches the given namespace.
95+
"""
96+
names = namespace.split(".")
97+
name, value = names.pop(), node.func.value
98+
while isinstance(value, ast.Attribute) and len(names) > 0:
99+
if value.attr != name:
100+
return False
101+
name, value = names.pop(), value.value
102+
103+
return isinstance(value, ast.Name) and value.id == name

src/sagemaker/cli/compatibility/v2/modifiers/predictors.py

+4-38
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@
1515
"""
1616
from __future__ import absolute_import
1717

18-
import ast
19-
18+
from sagemaker.cli.compatibility.v2.modifiers import matching
2019
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2120

2221
BASE_PREDICTOR = "RealTimePredictor"
@@ -54,7 +53,7 @@ def node_should_be_modified(self, node):
5453
Returns:
5554
bool: If the ``ast.Call`` instantiates a class of interest.
5655
"""
57-
return any(_matching(node, name, namespaces) for name, namespaces in PREDICTORS.items())
56+
return matching.matches_any(node, PREDICTORS)
5857

5958
def modify_node(self, node):
6059
"""Modifies the ``ast.Call`` node to call ``Predictor`` instead.
@@ -68,44 +67,11 @@ def modify_node(self, node):
6867
_rename_endpoint(node)
6968

7069

71-
def _matching(node, name, namespaces):
72-
"""Determines if the node matches the constructor name in the right namespace"""
73-
if _matching_name(node, name):
74-
return True
75-
76-
if not _matching_attr(node, name):
77-
return False
78-
79-
return any(_matching_namespace(node, namespace) for namespace in namespaces)
80-
81-
82-
def _matching_name(node, name):
83-
"""Determines if the node is an ast.Name node with a matching name"""
84-
return isinstance(node.func, ast.Name) and node.func.id == name
85-
86-
87-
def _matching_attr(node, name):
88-
"""Determines if the node is an ast.Attribute node with a matching name"""
89-
return isinstance(node.func, ast.Attribute) and node.func.attr == name
90-
91-
92-
def _matching_namespace(node, namespace):
93-
"""Determines if the node corresponds to a matching namespace"""
94-
names = namespace.split(".")
95-
name, value = names.pop(), node.func.value
96-
while isinstance(value, ast.Attribute) and len(names) > 0:
97-
if value.attr != name:
98-
return False
99-
name, value = names.pop(), value.value
100-
101-
return isinstance(value, ast.Name) and value.id == name
102-
103-
10470
def _rename_class(node):
10571
"""Renames the RealTimePredictor base class to Predictor"""
106-
if _matching_name(node, BASE_PREDICTOR):
72+
if matching.matches_name(node, BASE_PREDICTOR):
10773
node.func.id = "Predictor"
108-
elif _matching_attr(node, BASE_PREDICTOR):
74+
elif matching.matches_attr(node, BASE_PREDICTOR):
10975
node.func.attr = "Predictor"
11076

11177

0 commit comments

Comments
 (0)