Skip to content

Commit b9899ec

Browse files
author
Chuyang Deng
committed
remove framework mapping and use func.value to check for 'sagemaker'
1 parent 2d05070 commit b9899ec

File tree

2 files changed

+18
-40
lines changed

2 files changed

+18
-40
lines changed

src/sagemaker/cli/compatibility/v2/modifiers/image_uris.py

+11-35
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414
from __future__ import absolute_import
1515

1616
import ast
17-
import logging
18-
import pasta
1917

2018
from sagemaker.cli.compatibility.v2.modifiers import matching
2119
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
@@ -29,30 +27,6 @@
2927
"amazon.amazon_estimator",
3028
)
3129

32-
ALGORITHM_NAME_FROM_REPO = {
33-
"blazingtext": "blazingtext",
34-
"factorization-machine": "factorization-machines",
35-
"forecasting-deepar": "forecasting-deepar",
36-
"image-classification": "image-classification",
37-
"image-classification-neo": "image-classification-neo",
38-
"ipinsights": "ipinsights",
39-
"kmeans": "kmeans",
40-
"knn": "knn",
41-
"lda": "lda",
42-
"linear-learner": "linear-learner",
43-
"ntm": "ntm",
44-
"object2vec": "object2vec",
45-
"object-detection": "object-detection",
46-
"pca": "pca",
47-
"randomcutforest": "randomcutforest",
48-
"semantic-segmentation": "semantic-segmentation",
49-
"seq2seq": "seq2seq",
50-
"sagemaker-xgboost": "xgboost",
51-
"xgboost-neo": "xgboost-neo",
52-
}
53-
54-
logger = logging.getLogger("sagemaker")
55-
5630

5731
class ImageURIRetrieveRefactor(Modifier):
5832
"""A class to refactor *get_image_uri() method."""
@@ -87,9 +61,7 @@ def modify_node(self, node):
8761
original_args = [None] * 3
8862
for kw in node.keywords:
8963
if kw.arg == "repo_name":
90-
arg = kw.value.s
91-
modified_arg = ALGORITHM_NAME_FROM_REPO[arg]
92-
original_args[0] = ast.Str(modified_arg)
64+
original_args[0] = ast.Str(kw.value.s)
9365
elif kw.arg == "repo_region":
9466
original_args[1] = ast.Str(kw.value.s)
9567
elif kw.arg == "repo_version":
@@ -98,9 +70,7 @@ def modify_node(self, node):
9870
if len(node.args) > 0:
9971
original_args[1] = ast.Str(node.args[0].s)
10072
if len(node.args) > 1:
101-
arg = node.args[1].s
102-
modified_arg = ALGORITHM_NAME_FROM_REPO[arg]
103-
original_args[0] = ast.Str(modified_arg)
73+
original_args[0] = ast.Str(node.args[1].s)
10474
if len(node.args) > 2:
10575
original_args[2] = ast.Str(node.args[2].s)
10676

@@ -112,9 +82,15 @@ def modify_node(self, node):
11282
if matching.matches_name(node, GET_IMAGE_URI_NAME) or matching.matches_attr(
11383
node, GET_IMAGE_URI_NAME
11484
):
115-
node_components = list(pasta.dump(node).split("."))
116-
node_modules = node_components[: len(node_components) - 1]
117-
if "sagemaker" in node_modules:
85+
func = node.func
86+
has_sagemaker = False
87+
while hasattr(func, "value"):
88+
if hasattr(func.value, "id") and func.value.id == "sagemaker":
89+
has_sagemaker = True
90+
break
91+
func = func.value
92+
93+
if has_sagemaker:
11894
node.func = ast.Attribute(
11995
value=ast.Attribute(attr="image_uris", value=ast.Name(id="sagemaker")),
12096
attr="retrieve",

tests/unit/sagemaker/cli/compatibility/v2/modifiers/test_algorithm_image_uris.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -54,27 +54,29 @@ def test_methodnode_should_be_modified_random_call():
5454
def test_method_modify_node(methods, caplog):
5555
modifier = image_uris.ImageURIRetrieveRefactor()
5656

57-
method = "get_image_uri('us-west-2', 'sagemaker-xgboost')"
57+
method = "get_image_uri('us-west-2', 'xgboost')"
5858
node = ast_call(method)
5959
modifier.modify_node(node)
6060
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
6161

62-
method = "amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost')"
62+
method = "amazon_estimator.get_image_uri('us-west-2', 'xgboost')"
6363
node = ast_call(method)
6464
modifier.modify_node(node)
6565
assert "image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
6666

67-
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='sagemaker-xgboost')"
67+
method = "sagemaker.get_image_uri(repo_region='us-west-2', repo_name='xgboost')"
6868
node = ast_call(method)
6969
modifier.modify_node(node)
7070
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
7171

72-
method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='sagemaker-xgboost')"
72+
method = "sagemaker.amazon_estimator.get_image_uri('us-west-2', repo_name='xgboost')"
7373
node = ast_call(method)
7474
modifier.modify_node(node)
7575
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2')" == pasta.dump(node)
7676

77-
method = "sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'sagemaker-xgboost', repo_version='1')"
77+
method = (
78+
"sagemaker.amazon.amazon_estimator.get_image_uri('us-west-2', 'xgboost', repo_version='1')"
79+
)
7880
node = ast_call(method)
7981
modifier.modify_node(node)
8082
assert "sagemaker.image_uris.retrieve('xgboost', 'us-west-2', '1')" == pasta.dump(node)

0 commit comments

Comments
 (0)