14
14
from __future__ import absolute_import
15
15
16
16
import ast
17
- import logging
18
- import pasta
19
17
20
18
from sagemaker .cli .compatibility .v2 .modifiers import matching
21
19
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
29
27
"amazon.amazon_estimator" ,
30
28
)
31
29
32
- ALGORITHM_NAME_FROM_REPO = {
33
- "blazingtext" : "blazingtext" ,
34
- "factorization-machine" : "factorization-machines" ,
35
- "forecasting-deepar" : "forecasting-deepar" ,
36
- "image-classification" : "image-classification" ,
37
- "image-classification-neo" : "image-classification-neo" ,
38
- "ipinsights" : "ipinsights" ,
39
- "kmeans" : "kmeans" ,
40
- "knn" : "knn" ,
41
- "lda" : "lda" ,
42
- "linear-learner" : "linear-learner" ,
43
- "ntm" : "ntm" ,
44
- "object2vec" : "object2vec" ,
45
- "object-detection" : "object-detection" ,
46
- "pca" : "pca" ,
47
- "randomcutforest" : "randomcutforest" ,
48
- "semantic-segmentation" : "semantic-segmentation" ,
49
- "seq2seq" : "seq2seq" ,
50
- "sagemaker-xgboost" : "xgboost" ,
51
- "xgboost-neo" : "xgboost-neo" ,
52
- }
53
-
54
- logger = logging .getLogger ("sagemaker" )
55
-
56
30
57
31
class ImageURIRetrieveRefactor (Modifier ):
58
32
"""A class to refactor *get_image_uri() method."""
@@ -87,9 +61,7 @@ def modify_node(self, node):
87
61
original_args = [None ] * 3
88
62
for kw in node .keywords :
89
63
if kw .arg == "repo_name" :
90
- arg = kw .value .s
91
- modified_arg = ALGORITHM_NAME_FROM_REPO [arg ]
92
- original_args [0 ] = ast .Str (modified_arg )
64
+ original_args [0 ] = ast .Str (kw .value .s )
93
65
elif kw .arg == "repo_region" :
94
66
original_args [1 ] = ast .Str (kw .value .s )
95
67
elif kw .arg == "repo_version" :
@@ -98,9 +70,7 @@ def modify_node(self, node):
98
70
if len (node .args ) > 0 :
99
71
original_args [1 ] = ast .Str (node .args [0 ].s )
100
72
if len (node .args ) > 1 :
101
- arg = node .args [1 ].s
102
- modified_arg = ALGORITHM_NAME_FROM_REPO [arg ]
103
- original_args [0 ] = ast .Str (modified_arg )
73
+ original_args [0 ] = ast .Str (node .args [1 ].s )
104
74
if len (node .args ) > 2 :
105
75
original_args [2 ] = ast .Str (node .args [2 ].s )
106
76
@@ -112,9 +82,15 @@ def modify_node(self, node):
112
82
if matching .matches_name (node , GET_IMAGE_URI_NAME ) or matching .matches_attr (
113
83
node , GET_IMAGE_URI_NAME
114
84
):
115
- node_components = list (pasta .dump (node ).split ("." ))
116
- node_modules = node_components [: len (node_components ) - 1 ]
117
- if "sagemaker" in node_modules :
85
+ func = node .func
86
+ has_sagemaker = False
87
+ while hasattr (func , "value" ):
88
+ if hasattr (func .value , "id" ) and func .value .id == "sagemaker" :
89
+ has_sagemaker = True
90
+ break
91
+ func = func .value
92
+
93
+ if has_sagemaker :
118
94
node .func = ast .Attribute (
119
95
value = ast .Attribute (attr = "image_uris" , value = ast .Name (id = "sagemaker" )),
120
96
attr = "retrieve" ,
0 commit comments