Skip to content

Commit 4fb16f2

Browse files
authored
feature: add 1p algorithm image_uris migration tool (#1792)
1 parent c4bb695 commit 4fb16f2

File tree

4 files changed

+251
-0
lines changed

4 files changed

+251
-0
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
modifiers.training_input.TrainingInputConstructorRefactor(),
3838
modifiers.training_input.ShuffleConfigModuleRenamer(),
3939
modifiers.serde.SerdeConstructorRenamer(),
40+
modifiers.image_uris.ImageURIRetrieveRefactor(),
4041
]
4142

4243
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
@@ -55,6 +56,7 @@
5556
modifiers.training_input.ShuffleConfigImportFromRenamer(),
5657
modifiers.serde.SerdeImportFromAmazonCommonRenamer(),
5758
modifiers.serde.SerdeImportFromPredictorRenamer(),
59+
modifiers.image_uris.ImageURIRetrieveImportFromRenamer(),
5860
]
5961

6062

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@
2424
tfs,
2525
training_params,
2626
training_input,
27+
image_uris,
2728
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify image uri retrieve methods for Python SDK v2.0 and later."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import matching
19+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
20+
21+
GET_IMAGE_URI_NAME = "get_image_uri"
22+
GET_IMAGE_URI_NAMESPACES = (
23+
"sagemaker",
24+
"sagemaker.amazon_estimator",
25+
"sagemaker.amazon.amazon_estimator",
26+
"amazon_estimator",
27+
"amazon.amazon_estimator",
28+
)
29+
30+
31+
class ImageURIRetrieveRefactor(Modifier):
32+
"""A class to refactor *get_image_uri() method."""
33+
34+
def node_should_be_modified(self, node):
35+
"""Checks if the ``ast.Call`` node calls a function of interest.
36+
37+
This looks for the following calls:
38+
39+
- ``sagemaker.get_image_uri``
40+
- ``sagemaker.amazon_estimator.get_image_uri``
41+
- ``get_image_uri``
42+
43+
Args:
44+
node (ast.Call): a node that represents a function call. For more,
45+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
46+
47+
Returns:
48+
bool: If the ``ast.Call`` instantiates a class of interest.
49+
"""
50+
return matching.matches_name_or_namespaces(
51+
node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES
52+
)
53+
54+
def modify_node(self, node):
55+
"""Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.
56+
And switch the first two parameters from (region, repo) to (framework, region)
57+
58+
Args:
59+
node (ast.Call): a node that represents a *image_uris.retrieve call.
60+
"""
61+
original_args = [None] * 3
62+
for kw in node.keywords:
63+
if kw.arg == "repo_name":
64+
original_args[0] = ast.Str(kw.value.s)
65+
elif kw.arg == "repo_region":
66+
original_args[1] = ast.Str(kw.value.s)
67+
elif kw.arg == "repo_version":
68+
original_args[2] = ast.Str(kw.value.s)
69+
70+
if len(node.args) > 0:
71+
original_args[1] = ast.Str(node.args[0].s)
72+
if len(node.args) > 1:
73+
original_args[0] = ast.Str(node.args[1].s)
74+
if len(node.args) > 2:
75+
original_args[2] = ast.Str(node.args[2].s)
76+
77+
args = []
78+
for arg in original_args:
79+
if arg:
80+
args.append(arg)
81+
82+
func = node.func
83+
has_sagemaker = False
84+
while hasattr(func, "value"):
85+
if hasattr(func.value, "id") and func.value.id == "sagemaker":
86+
has_sagemaker = True
87+
break
88+
func = func.value
89+
90+
if has_sagemaker:
91+
node.func = ast.Attribute(
92+
value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")),
93+
attr="retrieve",
94+
)
95+
else:
96+
node.func = ast.Attribute(value=ast.Name(id="image_uris"), attr="retrieve")
97+
node.args = args
98+
node.keywords = []
99+
return node
100+
101+
102+
class ImageURIRetrieveImportFromRenamer(Modifier):
103+
"""A class to update import statements of ``get_image_uri``."""
104+
105+
def node_should_be_modified(self, node):
106+
"""Checks if the import statement imports ``get_image_uri`` from the correct module.
107+
108+
Args:
109+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
110+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
111+
112+
Returns:
113+
bool: If the import statement imports ``get_image_uri`` from the correct module.
114+
"""
115+
return node.module in GET_IMAGE_URI_NAMESPACES and any(
116+
name.name == GET_IMAGE_URI_NAME for name in node.names
117+
)
118+
119+
def modify_node(self, node):
120+
"""Changes the ``ast.ImportFrom`` node's name from ``get_image_uri`` to ``image_uris``.
121+
122+
Args:
123+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
124+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
125+
126+
Returns:
127+
ast.AST: the original node, which has been potentially modified.
128+
"""
129+
for name in node.names:
130+
if name.name == GET_IMAGE_URI_NAME:
131+
name.name = "image_uris"
132+
if node.module in GET_IMAGE_URI_NAMESPACES:
133+
node.module = "sagemaker"
134+
return node
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
import pytest
17+
18+
from sagemaker.cli.compatibility.v2.modifiers import image_uris
19+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
20+
21+
22+
@pytest.fixture
23+
def methods():
24+
return (
25+
"get_image_uri('us-west-2', 'sagemaker-xgboost')",
26+
"sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')",
27+
"sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')",
28+
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')",
29+
)
30+
31+
32+
@pytest.fixture
33+
def import_statements():
34+
return (
35+
"from sagemaker import get_image_uri",
36+
"from sagemaker.amazon_estimator import get_image_uri",
37+
"from sagemaker.amazon.amazon_estimator import get_image_uri",
38+
)
39+
40+
41+
def test_method_node_should_be_modified(methods):
42+
modifier = image_uris.ImageURIRetrieveRefactor()
43+
for method in methods:
44+
node = ast_call(method)
45+
assert modifier.node_should_be_modified(node)
46+
47+
48+
def test_methodnode_should_be_modified_random_call():
49+
modifier = image_uris.ImageURIRetrieveRefactor()
50+
node = ast_call("create_image_uri()")
51+
assert not modifier.node_should_be_modified(node)
52+
53+
54+
def test_method_modify_node(methods, caplog):
55+
modifier = image_uris.ImageURIRetrieveRefactor()
56+
57+
method = "get_image_uri('us-west-2', 'xgboost')"
58+
node = ast_call(method)
59+
modifier.modify_node(node)
60+
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
61+
62+
method = "amazon_estimator.get_image_uri('us-west-2', 'xgboost')"
63+
node = ast_call(method)
64+
modifier.modify_node(node)
65+
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
66+
67+
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='xgboost')"
68+
node = ast_call(method)
69+
modifier.modify_node(node)
70+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
71+
72+
method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='xgboost')"
73+
node = ast_call(method)
74+
modifier.modify_node(node)
75+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
76+
77+
method = (
78+
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'xgboost', repo_version='1')"
79+
)
80+
node = ast_call(method)
81+
modifier.modify_node(node)
82+
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')" == pasta.dump(node)
83+
84+
85+
def test_import_from_node_should_be_modified_image_uris_input(import_statements):
86+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
87+
88+
statement = "from sagemaker import get_image_uri"
89+
node = ast_import(statement)
90+
assert modifier.node_should_be_modified(node)
91+
92+
statement = "from sagemaker.amazon_estimator import get_image_uri"
93+
node = ast_import(statement)
94+
assert modifier.node_should_be_modified(node)
95+
96+
statement = "from sagemaker.amazon.amazon_estimator import get_image_uri"
97+
node = ast_import(statement)
98+
assert modifier.node_should_be_modified(node)
99+
100+
101+
def test_import_from_node_should_be_modified_random_import():
102+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
103+
node = ast_import("from sagemaker.amazon_estimator import registry")
104+
assert not modifier.node_should_be_modified(node)
105+
106+
107+
def test_import_from_modify_node(import_statements):
108+
modifier = image_uris.ImageURIRetrieveImportFromRenamer()
109+
expected_result = "from sagemaker import image_uris"
110+
111+
for import_statement in import_statements:
112+
node = ast_import(import_statement)
113+
modifier.modify_node(node)
114+
assert expected_result == pasta.dump(node)

0 commit comments

Comments
 (0)