Skip to content

Commit de54a76

Browse files
authored
change: update v2 migration tool to rename TFS classes/imports (#1552)
1 parent baf1c35 commit de54a76

File tree

5 files changed

+272
-0
lines changed

5 files changed

+272
-0
lines changed

src/sagemaker/cli/compatibility/v2/ast_transformer.py

+39
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,13 @@
2222
modifiers.tf_legacy_mode.TensorFlowLegacyModeConstructorUpgrader(),
2323
modifiers.tf_legacy_mode.TensorBoardParameterRemover(),
2424
modifiers.deprecated_params.TensorFlowScriptModeParameterRemover(),
25+
modifiers.tfs.TensorFlowServingConstructorRenamer(),
2526
]
2627

28+
IMPORT_MODIFIERS = [modifiers.tfs.TensorFlowServingImportRenamer()]
29+
30+
IMPORT_FROM_MODIFIERS = [modifiers.tfs.TensorFlowServingImportFromRenamer()]
31+
2732

2833
class ASTTransformer(ast.NodeTransformer):
2934
"""An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
@@ -46,3 +51,37 @@ def visit_Call(self, node):
4651

4752
ast.fix_missing_locations(node)
4853
return node
54+
55+
def visit_Import(self, node):
56+
"""Visits an ``ast.Import`` node and returns a modified node, if needed.
57+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
58+
59+
Args:
60+
node (ast.Import): a node that represents an import statement.
61+
62+
Returns:
63+
ast.Import: a node that represents an import statement, which has
64+
potentially been modified from the original input.
65+
"""
66+
for import_checker in IMPORT_MODIFIERS:
67+
import_checker.check_and_modify_node(node)
68+
69+
ast.fix_missing_locations(node)
70+
return node
71+
72+
def visit_ImportFrom(self, node):
73+
"""Visits an ``ast.ImportFrom`` node and returns a modified node, if needed.
74+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
75+
76+
Args:
77+
node (ast.ImportFrom): a node that represents an import statement.
78+
79+
Returns:
80+
ast.ImportFrom: a node that represents an import statement, which has
81+
potentially been modified from the original input.
82+
"""
83+
for import_checker in IMPORT_FROM_MODIFIERS:
84+
import_checker.check_and_modify_node(node)
85+
86+
ast.fix_missing_locations(node)
87+
return node

src/sagemaker/cli/compatibility/v2/modifiers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,5 @@
1717
deprecated_params,
1818
framework_version,
1919
tf_legacy_mode,
20+
tfs,
2021
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes to modify TensorFlow Serving code to be compatible with SageMaker Python SDK v2."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from sagemaker.cli.compatibility.v2.modifiers.modifier import Modifier
19+
20+
21+
class TensorFlowServingConstructorRenamer(Modifier):
22+
"""A class to rename TensorFlow Serving classes."""
23+
24+
def node_should_be_modified(self, node):
25+
"""Checks if the ``ast.Call`` node instantiates a TensorFlow Serving class.
26+
27+
This looks for the following calls:
28+
29+
- ``sagemaker.tensorflow.serving.Model``
30+
- ``sagemaker.tensorflow.serving.Predictor``
31+
- ``Predictor``
32+
33+
Because ``Model`` can refer to either ``sagemaker.tensorflow.serving.Model``
34+
or :class:`~sagemaker.model.Model`, ``Model`` on its own is not sufficient
35+
for indicating a TFS Model object.
36+
37+
Args:
38+
node (ast.Call): a node that represents a function call. For more,
39+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
40+
41+
Returns:
42+
bool: If the ``ast.Call`` instantiates a TensorFlow Serving class.
43+
"""
44+
if isinstance(node.func, ast.Name):
45+
return node.func.id == "Predictor"
46+
47+
if not (isinstance(node.func, ast.Attribute) and node.func.attr in ("Model", "Predictor")):
48+
return False
49+
50+
return (
51+
isinstance(node.func.value, ast.Attribute)
52+
and node.func.value.attr == "serving"
53+
and isinstance(node.func.value.value, ast.Attribute)
54+
and node.func.value.value.attr == "tensorflow"
55+
and isinstance(node.func.value.value.value, ast.Name)
56+
and node.func.value.value.value.id == "sagemaker"
57+
)
58+
59+
def modify_node(self, node):
60+
"""Modifies the ``ast.Call`` node to use the v2 classes for TensorFlow Serving:
61+
62+
- ``sagemaker.tensorflow.TensorFlowModel``
63+
- ``sagemaker.tensorflow.TensorFlowPredictor``
64+
65+
Args:
66+
node (ast.Call): a node that represents a TensorFlow Serving constructor.
67+
"""
68+
if isinstance(node.func, ast.Name):
69+
node.func.id = self._new_cls_name(node.func.id)
70+
else:
71+
node.func.attr = self._new_cls_name(node.func.attr)
72+
node.func.value = node.func.value.value
73+
74+
def _new_cls_name(self, cls_name):
75+
"""Returns the v2 class name."""
76+
return "TensorFlow{}".format(cls_name)
77+
78+
79+
class TensorFlowServingImportFromRenamer(Modifier):
80+
"""A class to update import statements starting with ``from sagemaker.tensorflow.serving``."""
81+
82+
def node_should_be_modified(self, node):
83+
"""Checks if the import statement imports from the ``sagemaker.tensorflow.serving`` module.
84+
85+
Args:
86+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
87+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
88+
89+
Returns:
90+
bool: If the ``ast.ImportFrom`` uses the ``sagemaker.tensorflow.serving`` module.
91+
"""
92+
return node.module == "sagemaker.tensorflow.serving"
93+
94+
def modify_node(self, node):
95+
"""Changes the ``ast.ImportFrom`` node's module to ``sagemaker.tensorflow`` and updates the
96+
imported class names to ``TensorFlowModel`` and ``TensorFlowPredictor``, as applicable.
97+
98+
Args:
99+
node (ast.ImportFrom): a node that represents a ``from ... import ... `` statement.
100+
For more, see https://docs.python.org/3/library/ast.html#abstract-grammar.
101+
"""
102+
node.module = "sagemaker.tensorflow"
103+
104+
for cls in node.names:
105+
cls.name = "TensorFlow{}".format(cls.name)
106+
107+
108+
class TensorFlowServingImportRenamer(Modifier):
109+
"""A class to update ``import sagemaker.tensorflow.serving``."""
110+
111+
def check_and_modify_node(self, node):
112+
"""Checks if the ``ast.Import`` node imports the ``sagemaker.tensorflow.serving`` module
113+
and, if so, changes it to ``sagemaker.tensorflow``.
114+
115+
Args:
116+
node (ast.Import): a node that represents an import statement. For more,
117+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
118+
"""
119+
for module in node.names:
120+
if module.name == "sagemaker.tensorflow.serving":
121+
module.name = "sagemaker.tensorflow"

tests/unit/sagemaker/cli/compatibility/v2/modifiers/ast_converter.py

+4
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,7 @@
1717

1818
def ast_call(code):
1919
return pasta.parse(code).body[0].value
20+
21+
22+
def ast_import(code):
23+
return pasta.parse(code).body[0]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
15+
import pasta
16+
17+
from sagemaker.cli.compatibility.v2.modifiers import tfs
18+
from tests.unit.sagemaker.cli.compatibility.v2.modifiers.ast_converter import ast_call, ast_import
19+
20+
21+
def test_constructor_node_should_be_modified_tfs_constructor():
22+
tfs_constructors = (
23+
"sagemaker.tensorflow.serving.Model()",
24+
"sagemaker.tensorflow.serving.Predictor()",
25+
"Predictor()",
26+
)
27+
28+
modifier = tfs.TensorFlowServingConstructorRenamer()
29+
30+
for constructor in tfs_constructors:
31+
node = ast_call(constructor)
32+
assert modifier.node_should_be_modified(node) is True
33+
34+
35+
def test_constructor_node_should_be_modified_random_function_call():
36+
modifier = tfs.TensorFlowServingConstructorRenamer()
37+
node = ast_call("Model()")
38+
assert modifier.node_should_be_modified(node) is False
39+
40+
41+
def test_constructor_modify_node():
42+
modifier = tfs.TensorFlowServingConstructorRenamer()
43+
44+
node = ast_call("sagemaker.tensorflow.serving.Model()")
45+
modifier.modify_node(node)
46+
assert "sagemaker.tensorflow.TensorFlowModel()" == pasta.dump(node)
47+
48+
node = ast_call("sagemaker.tensorflow.serving.Predictor()")
49+
modifier.modify_node(node)
50+
assert "sagemaker.tensorflow.TensorFlowPredictor()" == pasta.dump(node)
51+
52+
node = ast_call("Predictor()")
53+
modifier.modify_node(node)
54+
assert "TensorFlowPredictor()" == pasta.dump(node)
55+
56+
57+
def test_import_from_node_should_be_modified_tfs_module():
58+
import_statements = (
59+
"from sagemaker.tensorflow.serving import Model, Predictor",
60+
"from sagemaker.tensorflow.serving import Predictor",
61+
"from sagemaker.tensorflow.serving import Model as tfsModel",
62+
)
63+
64+
modifier = tfs.TensorFlowServingImportFromRenamer()
65+
66+
for import_from in import_statements:
67+
node = ast_import(import_from)
68+
assert modifier.node_should_be_modified(node) is True
69+
70+
71+
def test_import_from_node_should_be_modified_random_import():
72+
modifier = tfs.TensorFlowServingImportFromRenamer()
73+
node = ast_import("from sagemaker import Session")
74+
assert modifier.node_should_be_modified(node) is False
75+
76+
77+
def test_import_from_modify_node():
78+
modifier = tfs.TensorFlowServingImportFromRenamer()
79+
80+
node = ast_import("from sagemaker.tensorflow.serving import Model, Predictor")
81+
modifier.modify_node(node)
82+
expected_result = "from sagemaker.tensorflow import TensorFlowModel, TensorFlowPredictor"
83+
assert expected_result == pasta.dump(node)
84+
85+
node = ast_import("from sagemaker.tensorflow.serving import Predictor")
86+
modifier.modify_node(node)
87+
assert "from sagemaker.tensorflow import TensorFlowPredictor" == pasta.dump(node)
88+
89+
node = ast_import("from sagemaker.tensorflow.serving import Model as tfsModel")
90+
modifier.modify_node(node)
91+
assert "from sagemaker.tensorflow import TensorFlowModel as tfsModel" == pasta.dump(node)
92+
93+
94+
def test_import_check_and_modify_node_tfs_import():
95+
modifier = tfs.TensorFlowServingImportRenamer()
96+
node = ast_import("import sagemaker.tensorflow.serving")
97+
modifier.check_and_modify_node(node)
98+
assert "import sagemaker.tensorflow" == pasta.dump(node)
99+
100+
101+
def test_import_check_and_modify_node_random_import():
102+
modifier = tfs.TensorFlowServingImportRenamer()
103+
104+
import_statement = "import random"
105+
node = ast_import(import_statement)
106+
modifier.check_and_modify_node(node)
107+
assert import_statement == pasta.dump(node)

0 commit comments

Comments
 (0)