Skip to content

Commit 3f08e2e

Browse files
author
Chuyang Deng
committed
feature: add 1p algorithm image_uris migration tool
1 parent b34d680 commit 3f08e2e

File tree

4 files changed

+220
-10
lines changed

4 files changed

+220
-10
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
modifiers.training_params.TrainPrefixRemover(),
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
3838
modifiers.serde.SerdeConstructorRenamer(),
39+
modifiers.image_uris.ImageURIRetrieveRefactor(),
3940
]
4041

4142
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
@@ -53,6 +54,7 @@
5354
modifiers.training_input.TrainingInputImportFromRenamer(),
5455
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5556
modifiers.serde.SerdeImportFromPredictorRenamer(),
57+
modifiers.image_uris.ImageURIRetrieveImportFromRenamer(),
5658
]
5759

5860

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@
2424
tfs,
2525
training_params,
2626
training_input,
27+
image_uris,
2728
)

src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py

+97-10
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,50 @@
1616
from __future__ import absolute_import
1717

1818
import ast
19+
import logging
20+
import pasta
1921

2022
from sagemaker.cli.compatibility.v2.modifiers import matching
2123
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
2224

2325
# Name of the v1 helper function that this modifier rewrites.
GET_IMAGE_URI_NAME = "get_image_uri"

# Dotted module paths through which ``get_image_uri`` may be referenced or imported.
GET_IMAGE_URI_NAMESPACES = (
    "sagemaker",
    "sagemaker.amazon_estimator",
    "sagemaker.amazon.amazon_estimator",
    "amazon_estimator",
    "amazon.amazon_estimator",
)

# Maps the repo name given to ``get_image_uri`` to the name passed to
# ``image_uris.retrieve``. A list value means the repo name corresponds to
# more than one target name.
ALGORITHM_NAME_FROM_REPO = {
    "blazingtext": "blazingtext",
    "sagemaker-rl-mxnet": "coach-mxnet",
    "sagemaker-rl-tensorflow": ["coach-tensorflow", "ray-tensorflow"],
    "factorization-machine": "factorization-machines",
    # The v1 repo name is the plural form; the singular key above is kept
    # only for backward compatibility with existing lookups.
    "factorization-machines": "factorization-machines",
    "forecasting-deepar": "forecasting-deepar",
    "image-classification": "image-classification",
    "image-classification-neo": "image-classification-neo",
    "ipinsights": "ipinsights",
    "kmeans": "kmeans",
    "knn": "knn",
    "lda": "lda",
    "linear-learner": "linear-learner",
    "ntm": "ntm",
    "object2vec": "object2vec",
    "object-detection": "object-detection",
    "pca": "pca",
    "randomcutforest": "randomcutforest",
    "sagemaker-rl-ray-container": "ray-pytorch",
    "semantic-segmentation": "semantic-segmentation",
    "seq2seq": "seq2seq",
    "sagemaker-scikit-learn": "sklearn",
    "sagemaker-sparkml-serving": "sparkml-serving",
    "sagemaker-rl-vw-container": "vw",
    "sagemaker-xgboost": "xgboost",
    "xgboost-neo": "xgboost-neo",
}

logger = logging.getLogger("sagemaker")
2563

2664

2765
class ImageURIRetrieveRefactor(Modifier):
@@ -43,7 +81,9 @@ def node_should_be_modified(self, node):
4381
Returns:
4482
bool: If the ``ast.Call`` instantiates a class of interest.
4583
"""
46-
return matching.matches_name_or_namespaces(node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES)
84+
return matching.matches_name_or_namespaces(
85+
node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES
86+
)
4787

4888
def modify_node(self, node):
    """Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.

    The ``get_image_uri`` arguments (region, repo name, repo version; positional or
    keyword) are re-emitted positionally as (algorithm name, region, version). The new
    call is prefixed with ``sagemaker.`` when the original call was reached through the
    ``sagemaker`` module.

    String literals are rebuilt as new string nodes. Any other expression (a variable,
    a number, a call, ...) is carried over unchanged instead of raising, since its
    value cannot be known statically.

    Args:
        node (ast.Call): a node that represents a ``get_image_uri`` call.

    Returns:
        ast.Call: the same node, rewritten in place.
    """

    def literal_or_node(value):
        """Rebuilds a string literal as a fresh node; keeps any other expression as is."""
        text = getattr(value, "s", None)
        return ast.Str(text) if isinstance(text, str) else value

    def algorithm_name(value):
        """Maps a literal repo name through ``ALGORITHM_NAME_FROM_REPO``."""
        repo_name = getattr(value, "s", None)
        if not isinstance(repo_name, str):
            # Not a string literal, so there is nothing to look up.
            return value

        mapped = ALGORITHM_NAME_FROM_REPO.get(repo_name)
        if mapped is None:
            logger.warning(
                "No algorithm name is known for repo name %s, it is left unchanged", repo_name
            )
            return ast.Str(repo_name)
        if isinstance(mapped, list):
            logger.warning(
                "There are more than one value mapping to {}, {} will be used".format(
                    repo_name, mapped[0]
                )
            )
            mapped = mapped[0]
        return ast.Str(mapped)

    # Slots follow the order of the rewritten call: [algorithm name, region, version].
    original_args = [None] * 3
    for kw in node.keywords:
        if kw.arg == "repo_name":
            original_args[0] = algorithm_name(kw.value)
        elif kw.arg in ("repo_region", "region_name"):
            # "region_name" is the parameter name in the v1 signature; "repo_region"
            # is still accepted because earlier revisions of this tool matched on it.
            original_args[1] = literal_or_node(kw.value)
        elif kw.arg == "repo_version":
            original_args[2] = literal_or_node(kw.value)

    # Positional order of the old call is (region, repo name, version). Positional
    # values are applied after keywords, so they win on conflict.
    for slot, value in zip((1, 0, 2), node.args):
        original_args[slot] = algorithm_name(value) if slot == 0 else literal_or_node(value)

    args = [arg for arg in original_args if arg is not None]

    if matching.matches_name(node, GET_IMAGE_URI_NAME) or matching.matches_attr(
        node, GET_IMAGE_URI_NAME
    ):
        # Collect the dotted qualifiers of the called function only, so that dots
        # appearing inside the arguments cannot influence the prefix decision.
        qualifiers = []
        owner = getattr(node.func, "value", None)
        while isinstance(owner, ast.Attribute):
            qualifiers.append(owner.attr)
            owner = owner.value
        if isinstance(owner, ast.Name):
            qualifiers.append(owner.id)

        if "sagemaker" in qualifiers:
            node.func = ast.Attribute(
                value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")),
                attr="retrieve",
            )
        else:
            node.func = ast.Attribute(value=ast.Name(id="image_uris"), attr="retrieve")
        node.args = args
        node.keywords = []
    return node
62149

63150

@@ -75,7 +162,7 @@ def node_should_be_modified(self, node):
75162
bool: If the import statement imports ``get_image_uri`` from the correct module.
76163
"""
77164
return node.module in GET_IMAGE_URI_NAMESPACES and any(
78-
name.name == GET_IMAGE_URI_NAMESPACES for name in node.names
165+
name.name == GET_IMAGE_URI_NAME for name in node.names
79166
)
80167

81168
def modify_node(self, node):
@@ -91,6 +178,6 @@ def modify_node(self, node):
91178
for name in node.names:
92179
if name.name == GET_IMAGE_URI_NAME:
93180
name.name = "image_uris"
94-
if node.module == "sagemaker.amazon_estimator":
181+
if node.module in GET_IMAGE_URI_NAMESPACES:
95182
node.module = "sagemaker"
96183
return node
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import image_uris
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def methods():
24+
return (
25+
"get_image_uri('us-west-2', 'kmeans')",
26+
"sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-scikil-learn')",
27+
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-scikil-learn')",
28+
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikil-learn', repo_version='1')",
29+
)
30+
31+
32+
@pytest.fixture
33+
def import_statements():
34+
return (
35+
"from sagemaker import get_image_uri",
36+
"from sagemaker.amazon_estimator import get_image_uri",
37+
"from sagemaker.amazon.amazon_estimator import get_image_uri",
38+
)
39+
40+
41+
def test_method_node_should_be_modified(methods):
42+
modifier = image_uris.ImageURIRetrieveRefactor()
43+
for method in methods:
44+
node = ast_call(method)
45+
assert modifier.node_should_be_modified(node)
46+
47+
48+
def test_methodnode_should_be_modified_random_call():
49+
modifier = image_uris.ImageURIRetrieveRefactor()
50+
node = ast_call("create_image_uri()")
51+
assert not modifier.node_should_be_modified(node)
52+
53+
54+
def test_method_modify_node(methods, caplog):
55+
modifier = image_uris.ImageURIRetrieveRefactor()
56+
57+
method = "get_image_uri('us-west-2', 'sagemaker-scikit-learn')"
58+
node = ast_call(method)
59+
modifier.modify_node(node)
60+
assert "image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
61+
62+
method = "amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikit-learn')"
63+
node = ast_call(method)
64+
modifier.modify_node(node)
65+
assert "image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
66+
67+
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-scikit-learn')"
68+
node = ast_call(method)
69+
modifier.modify_node(node)
70+
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
71+
72+
method = (
73+
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-scikit-learn')"
74+
)
75+
node = ast_call(method)
76+
modifier.modify_node(node)
77+
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2')" == pasta.dump(node)
78+
79+
method = "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-scikit-learn', repo_version='1')"
80+
node = ast_call(method)
81+
modifier.modify_node(node)
82+
assert "sagemaker.image_uris.retrieve('sklearn', 'us-west-2', '1')" == pasta.dump(node)
83+
84+
method = "get_image_uri('us-west-2', 'sagemaker-rl-tensorflow')"
85+
node = ast_call(method)
86+
modifier.modify_node(node)
87+
assert "image_uris.retrieve('coach-tensorflow', 'us-west-2')"
88+
assert "There are more than one value mapping to" in caplog.text
89+
90+
91+
def test_import_from_node_should_be_modified_image_uris_input(import_statements):
92+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
93+
94+
statement = "from sagemaker import get_image_uri"
95+
node = ast_import(statement)
96+
assert modifier.node_should_be_modified(node)
97+
98+
statement = "from sagemaker.amazon_estimator import get_image_uri"
99+
node = ast_import(statement)
100+
assert modifier.node_should_be_modified(node)
101+
102+
statement = "from sagemaker.amazon.amazon_estimator import get_image_uri"
103+
node = ast_import(statement)
104+
assert modifier.node_should_be_modified(node)
105+
106+
107+
def test_import_from_node_should_be_modified_random_import():
108+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
109+
node = ast_import("from sagemaker.amazon_estimator import registry")
110+
assert not modifier.node_should_be_modified(node)
111+
112+
113+
def test_import_from_modify_node(import_statements):
114+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
115+
expected_result = "from sagemaker import image_uris"
116+
117+
for import_statement in import_statements:
118+
node = ast_import(import_statement)
119+
modifier.modify_node(node)
120+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)