@@ -68,6 +68,34 @@ def modify_node(self, node):
68
68
keyword .arg = self .new_param_name
69
69
70
70
71
+ class MethodParamRenamer (ParamRenamer ):
72
+ """Abstract class to handle parameter renames for methods that belong to objects.
73
+
74
+ This differs from ``ParamRenamer`` in that a node for a standalone function call
75
+ (i.e. where ``node.func`` is an ``ast.Name`` rather than an ``ast.Attribute``) is not modified.
76
+ """
77
+
78
+ def node_should_be_modified (self , node ):
79
+ """Checks if the node matches any of the relevant functions and
80
+ contains the parameter to be renamed.
81
+
82
+ This looks for a call of the form ``<object>.<method>``, and
83
+ assumes the method cannot be called on its own.
84
+
85
+ Args:
86
+ node (ast.Call): a node that represents a function call. For more,
87
+ see https://docs.python.org/3/library/ast.html#abstract-grammar.
88
+
89
+ Returns:
90
+ bool: If the ``ast.Call`` matches the relevant function calls and
91
+ contains the parameter to be renamed.
92
+ """
93
+ if isinstance (node .func , ast .Name ):
94
+ return False
95
+
96
+ return super (MethodParamRenamer , self ).node_should_be_modified (node )
97
+
98
+
71
99
class DistributionParameterRenamer (ParamRenamer ):
72
100
"""A class to rename the ``distributions`` attribute to ``distrbution`` in
73
101
MXNet and TensorFlow estimators.
@@ -100,7 +128,7 @@ def new_param_name(self):
100
128
return "distribution"
101
129
102
130
103
- class S3SessionRenamer (ParamRenamer ):
131
+ class S3SessionRenamer (MethodParamRenamer ):
104
132
"""A class to rename the ``session`` attribute to ``sagemaker_session`` in
105
133
``S3Uploader`` and ``S3Downloader``.
106
134
@@ -139,15 +167,6 @@ def new_param_name(self):
139
167
"""The new name for the SageMaker session argument."""
140
168
return "sagemaker_session"
141
169
142
- def node_should_be_modified (self , node ):
143
- """Checks if the node is one of the S3 utility functions and
144
- contains the ``session`` parameter.
145
- """
146
- if isinstance (node .func , ast .Name ):
147
- return False
148
-
149
- return super (S3SessionRenamer , self ).node_should_be_modified (node )
150
-
151
170
152
171
class EstimatorImageURIRenamer (ParamRenamer ):
153
172
"""A class to rename the ``image_name`` attribute to ``image_uri`` in estimators."""
@@ -209,3 +228,93 @@ def old_param_name(self):
209
228
def new_param_name (self ):
210
229
"""The new name for the image URI argument."""
211
230
return "image_uri"
231
+
232
+
233
+ class EstimatorCreateModelImageURIRenamer (MethodParamRenamer ):
234
+ """A class to rename ``image`` to ``image_uri`` in estimator ``create_model()`` methods."""
235
+
236
+ @property
237
+ def calls_to_modify (self ):
238
+ """A mapping of ``create_model`` to common variable names for estimators."""
239
+ return {
240
+ "create_model" : (
241
+ "estimator" ,
242
+ "chainer" ,
243
+ "mxnet" ,
244
+ "mx" ,
245
+ "pytorch" ,
246
+ "rl" ,
247
+ "sklearn" ,
248
+ "tensorflow" ,
249
+ "tf" ,
250
+ "xgboost" ,
251
+ "xgb" ,
252
+ )
253
+ }
254
+
255
+ @property
256
+ def old_param_name (self ):
257
+ """The previous name for the image URI argument."""
258
+ return "image"
259
+
260
+ @property
261
+ def new_param_name (self ):
262
+ """The new name for the the image URI argument."""
263
+ return "image_uri"
264
+
265
+
266
+ class SessionCreateModelImageURIRenamer (MethodParamRenamer ):
267
+ """A class to rename ``primary_container_image`` to ``image_uri``.
268
+
269
+ This looks for the following calls:
270
+
271
+ - ``sagemaker_session.create_model_from_job()``
272
+ - ``sess.create_model_from_job()``
273
+ """
274
+
275
+ @property
276
+ def calls_to_modify (self ):
277
+ """A mapping of ``create_model_from_job`` to common variable names for Session."""
278
+ return {
279
+ "create_model_from_job" : ("sagemaker_session" , "sess" ),
280
+ }
281
+
282
+ @property
283
+ def old_param_name (self ):
284
+ """The previous name for the image URI argument."""
285
+ return "primary_container_image"
286
+
287
+ @property
288
+ def new_param_name (self ):
289
+ """The new name for the the image URI argument."""
290
+ return "image_uri"
291
+
292
+
293
+ class SessionCreateEndpointImageURIRenamer (MethodParamRenamer ):
294
+ """A class to rename ``deployment_image`` to ``image_uri``.
295
+
296
+ This looks for the following calls:
297
+
298
+ - ``sagemaker_session.endpoint_from_job()``
299
+ - ``sess.endpoint_from_job()``
300
+ - ``sagemaker_session.endpoint_from_model_data()``
301
+ - ``sess.endpoint_from_model_data()``
302
+ """
303
+
304
+ @property
305
+ def calls_to_modify (self ):
306
+ """A mapping of the ``endpoint_from_*`` functions to common variable names for Session."""
307
+ return {
308
+ "endpoint_from_job" : ("sagemaker_session" , "sess" ),
309
+ "endpoint_from_model_data" : ("sagemaker_session" , "sess" ),
310
+ }
311
+
312
+ @property
313
+ def old_param_name (self ):
314
+ """The previous name for the image URI argument."""
315
+ return "deployment_image"
316
+
317
+ @property
318
+ def new_param_name (self ):
319
+ """The new name for the the image URI argument."""
320
+ return "image_uri"
0 commit comments