Skip to content

Commit b34d680

Browse files
author
Chuyang Deng
committed
add migration tool for image_uris.retrieve
1 parent f573f35 commit b34d680

File tree

1 file changed

+96
-0
lines changed

1 file changed

+96
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify image_uris.retrieve() code to be compatible
14+
with version 2.0 and later of the SageMaker Python SDK.
15+
"""
16+
from __future__ import absolute_import
17+
18+
import ast
19+
20+
from sagemaker.cli.compatibility.v2.modifiers import matching
21+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
22+
23+
GET_IMAGE_URI_NAME = "get_image_uri"
24+
GET_IMAGE_URI_NAMESPACES = ("sagemaker", "sagemaker.amazon_estimator")
25+
26+
27+
class ImageURIRetrieveRefactor(Modifier):
28+
"""A class to refactor *get_image_uri() method."""
29+
30+
def node_should_be_modified(self, node):
31+
"""Checks if the ``ast.Call`` node calls a function of interest.
32+
33+
This looks for the following calls:
34+
35+
- ``sagemaker.get_image_uri``
36+
- ``sagemaker.amazon_estimator.get_image_uri``
37+
- ``get_image_uri``
38+
39+
Args:
40+
node (ast.Call): a node that represents a function call. For more,
41+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
42+
43+
Returns:
44+
bool: If the ``ast.Call`` instantiates a class of interest.
45+
"""
46+
return matching.matches_name_or_namespaces(node, GET_IMAGE_URI_NAME, GET_IMAGE_URI_NAMESPACES)
47+
48+
def modify_node(self, node):
49+
"""Modifies the ``ast.Call`` node to call ``image_uris.retrieve`` instead.
50+
And switch the first two parameters from (region, repo) to (framework, region)
51+
52+
Args:
53+
node (ast.Call): a node that represents a *image_uris.retrieve call.
54+
"""
55+
if matching.matches_name(node, GET_IMAGE_URI_NAME):
56+
node.func.id = "image_uris.retrieve"
57+
node.func.params.argOne, node.func.params.argTwo = node.func.params.argTwo, node.func.params.argOne
58+
elif matching.matches_attr(node, GET_IMAGE_URI_NAME):
59+
node.func.attr = "image_uris.retrieve"
60+
node.func.params.argOne, node.func.params.argTwo = node.func.params.argTwo, node.func.params.argOne
61+
return node
62+
63+
64+
class ImageURIRetrieveImportFromRenamer(Modifier):
65+
"""A class to update import statements of ``get_image_uri``."""
66+
67+
def node_should_be_modified(self, node):
68+
"""Checks if the import statement imports ``get_image_uri`` from the correct module.
69+
70+
Args:
71+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
72+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
73+
74+
Returns:
75+
bool: If the import statement imports ``get_image_uri`` from the correct module.
76+
"""
77+
return node.module in GET_IMAGE_URI_NAMESPACES and any(
78+
name.name == GET_IMAGE_URI_NAMESPACES for name in node.names
79+
)
80+
81+
def modify_node(self, node):
82+
"""Changes the ``ast.ImportFrom`` node's name from ``get_image_uri`` to ``image_uris``.
83+
84+
Args:
85+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
86+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
87+
88+
Returns:
89+
ast.AST: the original node, which has been potentially modified.
90+
"""
91+
for name in node.names:
92+
if name.name == GET_IMAGE_URI_NAME:
93+
name.name = "image_uris"
94+
if node.module == "sagemaker.amazon_estimator":
95+
node.module = "sagemaker"
96+
return node

0 commit comments

Comments
 (0)