16
16
from __future__ import absolute_import
17
17
18
18
import ast
19
+ import logging
20
+ import pasta
19
21
20
22
from sagemaker .cli .compatibility .v2 .modifiers import matching
21
23
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
22
24
23
25
GET_IMAGE_URI_NAME = "get_image_uri"
24
- GET_IMAGE_URI_NAMESPACES = ("sagemaker" , "sagemaker.amazon_estimator" )
26
+ GET_IMAGE_URI_NAMESPACES = (
27
+ "sagemaker" ,
28
+ "sagemaker.amazon_estimator" ,
29
+ "sagemaker.amazon.amazon_estimator" ,
30
+ "amazon_estimator" ,
31
+ "amazon.amazon_estimator" ,
32
+ )
33
+
34
+ ALGORITHM_NAME_FROM_REPO = {
35
+ "blazingtext" : "blazingtext" ,
36
+ "sagemaker-rl-mxnet" : "coach-mxnet" ,
37
+ "sagemaker-rl-tensorflow" : ["coach-tensorflow" , "ray-tensorflow" ],
38
+ "factorization-machine" : "factorization-machines" ,
39
+ "forecasting-deepar" : "forecasting-deepar" ,
40
+ "image-classification" : "image-classification" ,
41
+ "image-classification-neo" : "image-classification-neo" ,
42
+ "ipinsights" : "ipinsights" ,
43
+ "kmeans" : "kmeans" ,
44
+ "knn" : "knn" ,
45
+ "lda" : "lda" ,
46
+ "linear-learner" : "linear-learner" ,
47
+ "ntm" : "ntm" ,
48
+ "object2vec" : "object2vec" ,
49
+ "object-detection" : "object-detection" ,
50
+ "pca" : "pca" ,
51
+ "randomcutforest" : "randomcutforest" ,
52
+ "sagemaker-rl-ray-container" : "ray-pytorch" ,
53
+ "semantic-segmentation" : "semantic-segmentation" ,
54
+ "seq2seq" : "seq2seq" ,
55
+ "sagemaker-scikit-learn" : "sklearn" ,
56
+ "sagemaker-sparkml-serving" : "sparkml-serving" ,
57
+ "sagemaker-rl-vw-container" : "vw" ,
58
+ "sagemaker-xgboost" : "xgboost" ,
59
+ "xgboost-neo" : "xgboost-neo" ,
60
+ }
61
+
62
+ logger = logging .getLogger ("sagemaker" )
25
63
26
64
27
65
class ImageURIRetrieveRefactor (Modifier ):
@@ -43,7 +81,9 @@ def node_should_be_modified(self, node):
43
81
Returns:
44
82
bool: If the ``ast.Call`` instantiates a class of interest.
45
83
"""
46
- return matching .matches_name_or_namespaces (node , GET_IMAGE_URI_NAME , GET_IMAGE_URI_NAMESPACES )
84
+ return matching .matches_name_or_namespaces (
85
+ node , GET_IMAGE_URI_NAME , GET_IMAGE_URI_NAMESPACES
86
+ )
47
87
48
88
def modify_node (self , node ):
49
89
"""Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.
@@ -52,12 +92,59 @@ def modify_node(self, node):
52
92
Args:
53
93
node (ast.Call): a node that represents a *image_uris.retrieve call.
54
94
"""
55
- if matching .matches_name (node , GET_IMAGE_URI_NAME ):
56
- node .func .id = "image_uris.retrieve"
57
- node .func .params .argOne , node .func .params .argTwo = node .func .params .argTwo , node .func .params .argOne
58
- elif matching .matches_attr (node , GET_IMAGE_URI_NAME ):
59
- node .func .attr = "image_uris.retrieve"
60
- node .func .params .argOne , node .func .params .argTwo = node .func .params .argTwo , node .func .params .argOne
95
+ original_args = [None ] * 3
96
+ for kw in node .keywords :
97
+ if kw .arg == "repo_name" :
98
+ arg = kw .value .s
99
+ modified_arg = ALGORITHM_NAME_FROM_REPO [arg ]
100
+ if isinstance (modified_arg , list ):
101
+ logger .warning (
102
+ "There are more than one value mapping to {}, {} will be used" .format (
103
+ arg , modified_arg [0 ]
104
+ )
105
+ )
106
+ modified_arg = modified_arg [0 ]
107
+ original_args [0 ] = ast .Str (modified_arg )
108
+ elif kw .arg == "repo_region" :
109
+ original_args [1 ] = ast .Str (kw .value .s )
110
+ elif kw .arg == "repo_version" :
111
+ original_args [2 ] = ast .Str (kw .value .s )
112
+
113
+ if len (node .args ) > 0 :
114
+ original_args [1 ] = ast .Str (node .args [0 ].s )
115
+ if len (node .args ) > 1 :
116
+ arg = node .args [1 ].s
117
+ modified_arg = ALGORITHM_NAME_FROM_REPO [arg ]
118
+ if isinstance (modified_arg , list ):
119
+ logger .warning (
120
+ "There are more than one value mapping to {}, {} will be used" .format (
121
+ arg , modified_arg [0 ]
122
+ )
123
+ )
124
+ modified_arg = modified_arg [0 ]
125
+ original_args [0 ] = ast .Str (modified_arg )
126
+ if len (node .args ) > 2 :
127
+ original_args [2 ] = ast .Str (node .args [2 ].s )
128
+
129
+ args = []
130
+ for arg in original_args :
131
+ if arg :
132
+ args .append (arg )
133
+
134
+ if matching .matches_name (node , GET_IMAGE_URI_NAME ) or matching .matches_attr (
135
+ node , GET_IMAGE_URI_NAME
136
+ ):
137
+ node_components = list (pasta .dump (node ).split ("." ))
138
+ node_modules = node_components [: len (node_components ) - 1 ]
139
+ if "sagemaker" in node_modules :
140
+ node .func = ast .Attribute (
141
+ value = ast .Attribute (attr = "image_uris" , value = ast .Name (id = "sagemaker" )),
142
+ attr = "retrieve" ,
143
+ )
144
+ else :
145
+ node .func = ast .Attribute (value = ast .Name (id = "image_uris" ), attr = "retrieve" )
146
+ node .args = args
147
+ node .keywords = []
61
148
return node
62
149
63
150
@@ -75,7 +162,7 @@ def node_should_be_modified(self, node):
75
162
bool: If the import statement imports ``get_image_uri`` from the correct module.
76
163
"""
77
164
return node .module in GET_IMAGE_URI_NAMESPACES and any (
78
- name .name == GET_IMAGE_URI_NAMESPACES for name in node .names
165
+ name .name == GET_IMAGE_URI_NAME for name in node .names
79
166
)
80
167
81
168
def modify_node (self , node ):
@@ -91,6 +178,6 @@ def modify_node(self, node):
91
178
for name in node .names :
92
179
if name .name == GET_IMAGE_URI_NAME :
93
180
name .name = "image_uris"
94
- if node .module == "sagemaker.amazon_estimator" :
181
+ if node .module in GET_IMAGE_URI_NAMESPACES :
95
182
node .module = "sagemaker"
96
183
return node
0 commit comments