Skip to content

Commit 92e6d57

Browse files
committed
change: create ASTTransformer class to handle migrating Python SDK code for v2
As a start, this class ensures that the framework_version parameter is specified when framework classes are instantiated.
1 parent f5814d5 commit 92e6d57

File tree

7 files changed

+252
-0
lines changed

7 files changed

+252
-0
lines changed

tools/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import

tools/compatibility/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import

tools/compatibility/v2/__init__.py

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""An ast.NodeTransformer subclass for updating SageMaker Python SDK code."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from modifiers import framework_version
19+
20+
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
21+
22+
23+
class ASTTransformer(ast.NodeTransformer):
24+
"""An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
25+
modifies nodes to upgrade the given SageMaker Python SDK code.
26+
"""
27+
28+
def visit_Call(self, node):
29+
"""Visits an ``ast.Call`` node and returns a modified node, if needed.
30+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
31+
32+
Args:
33+
node (ast.Call): a node that represents a function call.
34+
35+
Returns:
36+
ast.Call: a node that represents a function call, which has
37+
potentially been modified from the original input.
38+
"""
39+
for function_checker in FUNCTION_CALL_MODIFIERS:
40+
function_checker.check_and_modify_node(node)
41+
return node
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for modifying AST nodes"""
14+
from __future__ import absolute_import
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""A class to ensure that ``framework_version`` is defined when constructing framework classes."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from modifiers.modifier import Modifier
19+
20+
FRAMEWORK_DEFAULTS = {
21+
"Chainer": "4.1.0",
22+
"MXNet": "1.2.0",
23+
"PyTorch": "0.4.0",
24+
"SKLearn": "0.20.0",
25+
"TensorFlow": "1.11.0",
26+
}
27+
28+
FRAMEWORKS = list(FRAMEWORK_DEFAULTS.keys())
29+
# TODO: check for sagemaker.tensorflow.serving.Model
30+
FRAMEWORK_CLASSES = FRAMEWORKS + ["{}Model".format(fw) for fw in FRAMEWORKS]
31+
FRAMEWORK_MODULES = [fw.lower() for fw in FRAMEWORKS]
32+
33+
34+
class FrameworkVersionEnforcer(Modifier):
35+
def node_should_be_modified(self, node):
36+
"""Check if the ast.Call node instantiates a framework estimator or model,
37+
but doesn't specify the framework_version parameter.
38+
39+
This looks for the following formats:
40+
41+
- ``TensorFlow``
42+
- ``sagemaker.tensorflow.TensorFlow``
43+
44+
where "TensorFlow" can be Chainer, MXNet, PyTorch, SKLearn, or TensorFlow.
45+
46+
Args:
47+
node (ast.Call): a node that represents a function call. For more,
48+
see https://docs.python.org/3/library/ast.html#abstract-grammar.
49+
50+
Returns:
51+
bool: If the ``ast.Call`` is instantiating a framework class that
52+
should specify ``framework_version``, but doesn't.
53+
"""
54+
if self._is_framework_constructor(node):
55+
return not self._fw_version_in_keywords(node)
56+
57+
return False
58+
59+
def _is_framework_constructor(self, node):
60+
"""Check if the ``ast.Call`` node represents a call of the form
61+
<Framework> or sagemaker.<framework>.<Framework>.
62+
"""
63+
if isinstance(node.func, ast.Name):
64+
if node.func.id in FRAMEWORK_CLASSES:
65+
return True
66+
67+
if (
68+
isinstance(node.func, ast.Attribute)
69+
and node.func.attr in FRAMEWORK_CLASSES
70+
and isinstance(node.func.value, ast.Attribute)
71+
and node.func.value.attr in FRAMEWORK_MODULES
72+
and isinstance(node.func.value.value, ast.Name)
73+
and node.func.value.value.id == "sagemaker"
74+
):
75+
return True
76+
77+
return False
78+
79+
def _fw_version_in_keywords(self, node):
80+
"""Check if the ``ast.Call`` node's keywords contain ``framework_version``."""
81+
for kw in node.keywords:
82+
if kw.arg == "framework_version" and kw.value:
83+
return True
84+
return False
85+
86+
def modify_node(self, node):
87+
"""Modify the ``ast.Call`` node's keywords to include ``framework_version``.
88+
89+
The ``framework_version`` value is determined by the framework:
90+
91+
- Chainer: "4.1.0"
92+
- MXNet: "1.2.0"
93+
- PyTorch: "0.4.0"
94+
- SKLearn: "0.20.0"
95+
- TensorFlow: "1.11.0"
96+
97+
Args:
98+
node (ast.Call): a node that represents the constructor of a framework class.
99+
"""
100+
framework = self._framework_name_from_node(node)
101+
node.keywords.append(
102+
ast.keyword(arg="framework_version", value=ast.Str(s=FRAMEWORK_DEFAULTS[framework]))
103+
)
104+
105+
def _framework_name_from_node(self, node):
106+
"""Retrieve the framework name based on the function call.
107+
108+
Args:
109+
node (ast.Call): a node that represents the constructor of a framework class.
110+
This can represent either <Framework> or sagemaker.<framework>.<Framework>.
111+
112+
Returns:
113+
str: the (capitalized) framework name.
114+
"""
115+
if isinstance(node.func, ast.Name):
116+
framework = node.func.id
117+
elif isinstance(node.func, ast.Attribute):
118+
framework = node.func.attr
119+
120+
if framework.endswith("Model"):
121+
framework = framework[:framework.find("Model")]
122+
123+
return framework
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Abstract class for modifying AST nodes."""
14+
from __future__ import absolute_import
15+
16+
from abc import abstractmethod
17+
18+
19+
class Modifier(object):
20+
"""Abstract class to take in an AST node, check if it needs modification,
21+
and potentially modify the node.
22+
"""
23+
24+
def check_and_modify_node(self, node):
25+
"""Check an AST node, and modify it if applicable."""
26+
if self.node_should_be_modified(node):
27+
self.modify_node(node)
28+
29+
@abstractmethod
30+
def node_should_be_modified(self, node):
31+
"""Check if an AST node should be modified."""
32+
33+
@abstractmethod
34+
def modify_node(self, node):
35+
"""Modify an AST node."""

0 commit comments

Comments
 (0)