Skip to content

Commit aab6779

Browse files
committed
infra: add cli modifier for RealTimePredictor and derived classes
1 parent f9628f8 commit aab6779

File tree

4 files changed

+280
-1
lines changed

4 files changed

+280
-1
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sagemaker.cli.compatibility.v2 import modifiers
1919

2020
FUNCTION_CALL_MODIFIERS = [
21+
modifiers.predictors.PredictorConstructorRefactor(),
2122
modifiers.framework_version.FrameworkVersionEnforcer(),
2223
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2324
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
@@ -28,7 +29,10 @@
2829

2930
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
3031

31-
IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()]
32+
IMPORT_FROM_MODIFIERS = [
33+
modifiers.predictors.PredictorImportFromRenamer(),
34+
modifiers.tfs.TensorFlowServingImportFromRenamer(),
35+
]
3236

3337

3438
class ASTTransformer(ast.NodeTransformer):

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
airflow,
1818
deprecated_params,
1919
framework_version,
20+
predictors,
2021
tf_legacy_mode,
2122
tfs,
2223
)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify Predictor code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
20+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
21+
22+
BASE_PREDICTOR = "RealTimePredictor"
23+
PREDICTORS = {
24+
"FactorizationMachinesPredictor": ("sagemaker", "sagemaker.amazon.factorization_machines"),
25+
"IPInsightsPredictor": ("sagemaker", "sagemaker.amazon.ipinsights"),
26+
"KMeansPredictor": ("sagemaker", "sagemaker.amazon.kmeans"),
27+
"KNNPredictor": ("sagemaker", "sagemaker.amazon.knn"),
28+
"LDAPredictor": ("sagemaker", "sagemaker.amazon.lda"),
29+
"LinearLearnerPredictor": ("sagemaker", "sagemaker.amazon.linear_learner"),
30+
"NTMPredictor": ("sagemaker", "sagemaker.amazon.ntm"),
31+
"PCAPredictor": ("sagemaker", "sagemaker.amazon.pca"),
32+
"RandomCutForestPredictor": ("sagemaker", "sagemaker.amazon.randomcutforest"),
33+
"RealTimePredictor": ("sagemaker", "sagemaker.predictor"),
34+
"SparkMLPredictor": ("sagemaker.sparkml", "sagemaker.sparkml.model"),
35+
}
36+
37+
38+
class PredictorConstructorRefactor(Modifier):
39+
"""A class to refactor *Predictor class and refactor endpoint attribute."""
40+
41+
def node_should_be_modified(self, node):
42+
"""Checks if the ``ast.Call`` node instantiates a class of interest.
43+
44+
This looks for the following calls:
45+
46+
- ``sagemaker.<my>.<namespace>.<MyPredictor>``
47+
- ``sagemaker.<namespace>.<MyPredictor>``
48+
- ``<MyPredictor>``
49+
50+
Args:
51+
node (ast.Call): a node that represents a function call. For more,
52+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
53+
54+
Returns:
55+
bool: If the ``ast.Call`` instantiates a class of interest.
56+
"""
57+
return any(_matching(node, name, namespaces) for name, namespaces in PREDICTORS.items())
58+
59+
def modify_node(self, node):
60+
"""Modifies the ``ast.Call`` node to call ``Predictor`` instead.
61+
62+
Also renames ``endpoint`` attribute to ``endpoint_name``.
63+
64+
Args:
65+
node (ast.Call): a node that represents a *Predictor constructor.
66+
"""
67+
_rename_class(node)
68+
_rename_endpoint(node)
69+
70+
71+
def _matching(node, name, namespaces):
72+
"""Determines if the node matches the constructor name in the right namespace"""
73+
if _matching_name(node, name):
74+
return True
75+
76+
if not _matching_attr(node, name):
77+
return False
78+
79+
return any(_matching_namespace(node, namespace) for namespace in namespaces)
80+
81+
82+
def _matching_name(node, name):
83+
"""Determines if the node is an ast.Name node with a matching name"""
84+
return isinstance(node.func, ast.Name) and node.func.id == name
85+
86+
87+
def _matching_attr(node, name):
88+
"""Determines if the node is an ast.Attribute node with a matching name"""
89+
return isinstance(node.func, ast.Attribute) and node.func.attr == name
90+
91+
92+
def _matching_namespace(node, namespace):
93+
"""Determines if the node corresponds to a matching namespace"""
94+
names = namespace.split(".")
95+
name, value = names.pop(), node.func.value
96+
while isinstance(value, ast.Attribute) and len(names) > 0:
97+
if value.attr != name:
98+
return False
99+
name, value = names.pop(), value.value
100+
101+
return isinstance(value, ast.Name) and value.id == name
102+
103+
104+
def _rename_class(node):
105+
"""Renames the RealTimePredictor base class to Predictor"""
106+
if _matching_name(node, BASE_PREDICTOR):
107+
node.func.id = "Predictor"
108+
elif _matching_attr(node, BASE_PREDICTOR):
109+
node.func.attr = "Predictor"
110+
111+
112+
def _rename_endpoint(node):
113+
"""Renames keyword endpoint argument to endpoint_name"""
114+
for keyword in node.keywords:
115+
if keyword.arg == "endpoint":
116+
keyword.arg = "endpoint_name"
117+
break
118+
119+
120+
class PredictorImportFromRenamer(Modifier):
121+
"""A class to update import statements of ``RealTimePredictor``."""
122+
123+
def node_should_be_modified(self, node):
124+
"""Checks if the import statement imports ``RealTimePredictor`` from the correct module.
125+
126+
Args:
127+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
128+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
129+
130+
Returns:
131+
bool: If the import statement imports ``RealTimePredictor`` from the correct module.
132+
"""
133+
return node.module in PREDICTORS[BASE_PREDICTOR] and any(
134+
name.name == BASE_PREDICTOR for name in node.names
135+
)
136+
137+
def modify_node(self, node):
138+
"""Changes the ``ast.ImportFrom`` node's name from ``RealTimePredictor`` to ``Predictor``.
139+
140+
Args:
141+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
142+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
143+
"""
144+
for name in node.names:
145+
if name.name == BASE_PREDICTOR:
146+
name.name = "Predictor"
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import predictors
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def base_constructors():
24+
return (
25+
"sagemaker.predictor.RealTimePredictor(endpoint='a')",
26+
"sagemaker.RealTimePredictor(endpoint='b')",
27+
"RealTimePredictor(endpoint='c')",
28+
)
29+
30+
31+
@pytest.fixture
32+
def sparkml_constructors():
33+
return (
34+
"sagemaker.sparkml.model.SparkMLPredictor(endpoint='a')",
35+
"sagemaker.sparkml.SparkMLPredictor(endpoint='b')",
36+
"SparkMLPredictor(endpoint='c')",
37+
)
38+
39+
40+
@pytest.fixture
41+
def other_constructors():
42+
return (
43+
"sagemaker.amazon.knn.KNNPredictor(endpoint='a')",
44+
"sagemaker.KNNPredictor(endpoint='b')",
45+
"KNNPredictor(endpoint='c')",
46+
)
47+
48+
49+
@pytest.fixture
50+
def import_statements():
51+
return (
52+
"from sagemaker.predictor import RealTimePredictor",
53+
"from sagemaker import RealTimePredictor",
54+
)
55+
56+
57+
def test_constructor_node_should_be_modified_base(base_constructors):
58+
modifier = predictors.PredictorConstructorRefactor()
59+
for constructor in base_constructors:
60+
node = ast_call(constructor)
61+
assert modifier.node_should_be_modified(node)
62+
63+
64+
def test_constructor_node_should_be_modified_sparkml(sparkml_constructors):
65+
modifier = predictors.PredictorConstructorRefactor()
66+
for constructor in sparkml_constructors:
67+
node = ast_call(constructor)
68+
assert modifier.node_should_be_modified(node)
69+
70+
71+
def test_constructor_node_should_be_modified_other(other_constructors):
72+
modifier = predictors.PredictorConstructorRefactor()
73+
for constructor in other_constructors:
74+
node = ast_call(constructor)
75+
assert modifier.node_should_be_modified(node)
76+
77+
78+
def test_constructor_node_should_be_modified_random_call():
79+
modifier = predictors.PredictorConstructorRefactor()
80+
node = ast_call("Model()")
81+
assert not modifier.node_should_be_modified(node)
82+
83+
84+
def test_constructor_modify_node():
85+
modifier = predictors.PredictorConstructorRefactor()
86+
87+
node = ast_call("sagemaker.RealTimePredictor(endpoint='a')")
88+
modifier.modify_node(node)
89+
assert "sagemaker.Predictor(endpoint_name='a')" == pasta.dump(node)
90+
91+
node = ast_call("RealTimePredictor(endpoint='a')")
92+
modifier.modify_node(node)
93+
assert "Predictor(endpoint_name='a')" == pasta.dump(node)
94+
95+
node = ast_call("sagemaker.amazon.kmeans.KMeansPredictor(endpoint='a')")
96+
modifier.modify_node(node)
97+
assert "sagemaker.amazon.kmeans.KMeansPredictor(endpoint_name='a')" == pasta.dump(node)
98+
99+
node = ast_call("KMeansPredictor(endpoint='a')")
100+
modifier.modify_node(node)
101+
assert "KMeansPredictor(endpoint_name='a')" == pasta.dump(node)
102+
103+
104+
def test_import_from_node_should_be_modified_predictor_module(import_statements):
105+
modifier = predictors.PredictorImportFromRenamer()
106+
for statement in import_statements:
107+
node = ast_import(statement)
108+
assert modifier.node_should_be_modified(node)
109+
110+
111+
def test_import_from_node_should_be_modified_random_import():
112+
modifier = predictors.PredictorImportFromRenamer()
113+
node = ast_import("from sagemaker import Session")
114+
assert not modifier.node_should_be_modified(node)
115+
116+
117+
def test_import_from_modify_node():
118+
modifier = predictors.PredictorImportFromRenamer()
119+
120+
node = ast_import("from sagemaker.predictor import BytesDeserializer, RealTimePredictor")
121+
modifier.modify_node(node)
122+
expected_result = "from sagemaker.predictor import BytesDeserializer, Predictor"
123+
assert expected_result == pasta.dump(node)
124+
125+
node = ast_import("from sagemaker.predictor import RealTimePredictor as RTP")
126+
modifier.modify_node(node)
127+
expected_result = "from sagemaker.predictor import Predictor as RTP"
128+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)