Skip to content

Commit e32ca32

Browse files
authored
update with aws zwei
update with aws zwei
2 parents 4999c58 + 57b2a22 commit e32ca32

File tree

10 files changed

+519
-1
lines changed

10 files changed

+519
-1
lines changed

tools/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools to assist with using the SageMake Python SDK."""
14+
from __future__ import absolute_import

tools/compatibility/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools to assist with compatibility between SageMaker Python SDK versions."""
14+
from __future__ import absolute_import

tools/compatibility/v2/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Tools to assist with upgrading to v2 of the SageMaker Python SDK."""
14+
from __future__ import absolute_import
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""An ast.NodeTransformer subclass for updating SageMaker Python SDK code."""
14+
from __future__ import absolute_import
15+
16+
import ast
17+
18+
from modifiers import framework_version
19+
20+
FUNCTION_CALL_MODIFIERS = [framework_version.FrameworkVersionEnforcer()]
21+
22+
23+
class ASTTransformer(ast.NodeTransformer):
24+
"""An ``ast.NodeTransformer`` subclass that walks the abstract syntax tree and
25+
modifies nodes to upgrade the given SageMaker Python SDK code.
26+
"""
27+
28+
def visit_Call(self, node):
29+
"""Visits an ``ast.Call`` node and returns a modified node, if needed.
30+
See https://docs.python.org/3/library/ast.html#ast.NodeTransformer.
31+
32+
Args:
33+
node (ast.Call): a node that represents a function call.
34+
35+
Returns:
36+
ast.Call: a node that represents a function call, which has
37+
potentially been modified from the original input.
38+
"""
39+
for function_checker in FUNCTION_CALL_MODIFIERS:
40+
function_checker.check_and_modify_node(node)
41+
return node

tools/compatibility/v2/files.py

Lines changed: 180 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,180 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for updating code in files."""
14+
from __future__ import absolute_import
15+
16+
from abc import abstractmethod
17+
import json
18+
import logging
19+
import os
20+
21+
import pasta
22+
23+
from ast_transformer import ASTTransformer
24+
25+
LOGGER = logging.getLogger(__name__)
26+
27+
28+
class FileUpdater(object):
29+
"""An abstract class for updating files."""
30+
31+
def __init__(self, input_path, output_path):
32+
"""Creates a ``FileUpdater`` for updating a file so that
33+
it is compatible with v2 of the SageMaker Python SDK.
34+
35+
Args:
36+
input_path (str): Location of the input file.
37+
output_path (str): Desired location for the output file.
38+
If the directories don't already exist, then they are created.
39+
If a file exists at ``output_path``, then it is overwritten.
40+
"""
41+
self.input_path = input_path
42+
self.output_path = output_path
43+
44+
@abstractmethod
45+
def update(self):
46+
"""Reads the input file, updates the code so that it is
47+
compatible with v2 of the SageMaker Python SDK, and writes the
48+
updated code to an output file.
49+
"""
50+
51+
52+
class PyFileUpdater(FileUpdater):
53+
"""A class for updating Python (``*.py``) files."""
54+
55+
def update(self):
56+
"""Reads the input Python file, updates the code so that it is
57+
compatible with v2 of the SageMaker Python SDK, and writes the
58+
updated code to an output file.
59+
"""
60+
output = self._update_ast(self._read_input_file())
61+
self._write_output_file(output)
62+
63+
def _update_ast(self, input_ast):
64+
"""Updates an abstract syntax tree (AST) so that it is compatible
65+
with v2 of the SageMaker Python SDK.
66+
67+
Args:
68+
input_ast (ast.Module): AST to be updated for use with Python SDK v2.
69+
70+
Returns:
71+
ast.Module: Updated AST that is compatible with Python SDK v2.
72+
"""
73+
return ASTTransformer().visit(input_ast)
74+
75+
def _read_input_file(self):
76+
"""Reads input file and parses it as an abstract syntax tree (AST).
77+
78+
Returns:
79+
ast.Module: AST representation of the input file.
80+
"""
81+
with open(self.input_path) as input_file:
82+
return pasta.parse(input_file.read())
83+
84+
def _write_output_file(self, output):
85+
"""Writes abstract syntax tree (AST) to output file.
86+
Creates the directories for the output path, if needed.
87+
88+
Args:
89+
output (ast.Module): AST to save as the output file.
90+
"""
91+
output_dir = os.path.dirname(self.output_path)
92+
if output_dir and not os.path.exists(output_dir):
93+
os.makedirs(output_dir)
94+
95+
if os.path.exists(self.output_path):
96+
LOGGER.warning("Overwriting file %s", self.output_path)
97+
98+
with open(self.output_path, "w") as output_file:
99+
output_file.write(pasta.dump(output))
100+
101+
102+
class JupyterNotebookFileUpdater(FileUpdater):
103+
"""A class for updating Jupyter notebook (``*.ipynb``) files.
104+
105+
For more on this file format, see
106+
https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat.
107+
"""
108+
109+
def update(self):
110+
"""Reads the input Jupyter notebook file, updates the code so that it is
111+
compatible with v2 of the SageMaker Python SDK, and writes the
112+
updated code to an output file.
113+
"""
114+
nb_json = self._read_input_file()
115+
for cell in nb_json["cells"]:
116+
if cell["cell_type"] == "code":
117+
updated_source = self._update_code_from_cell(cell)
118+
cell["source"] = updated_source
119+
120+
self._write_output_file(nb_json)
121+
122+
def _update_code_from_cell(self, cell):
123+
"""Updates the code from a code cell so that it is
124+
compatible with v2 of the SageMaker Python SDK.
125+
126+
Args:
127+
cell (dict): A dictionary representation of a code cell from
128+
a Jupyter notebook. For more info, see
129+
https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells.
130+
131+
Returns:
132+
list[str]: A list of strings containing the lines of updated code that
133+
can be used for the "source" attribute of a Jupyter notebook code cell.
134+
"""
135+
code = "".join(cell["source"])
136+
updated_ast = ASTTransformer().visit(pasta.parse(code))
137+
updated_code = pasta.dump(updated_ast)
138+
return self._code_str_to_source_list(updated_code)
139+
140+
def _code_str_to_source_list(self, code):
141+
"""Converts a string of code into a list for a Jupyter notebook code cell.
142+
143+
Args:
144+
code (str): Code to be converted.
145+
146+
Returns:
147+
list[str]: A list of strings containing the lines of code that
148+
can be used for the "source" attribute of a Jupyter notebook code cell.
149+
Each element of the list (i.e. line of code) contains a
150+
trailing newline character ("\n") except for the last element.
151+
"""
152+
source_list = ["{}\n".format(s) for s in code.split("\n")]
153+
source_list[-1] = source_list[-1].rstrip("\n")
154+
return source_list
155+
156+
def _read_input_file(self):
157+
"""Reads input file and parses it as JSON.
158+
159+
Returns:
160+
dict: JSON representation of the input file.
161+
"""
162+
with open(self.input_path) as input_file:
163+
return json.load(input_file)
164+
165+
def _write_output_file(self, output):
166+
"""Writes JSON to output file. Creates the directories for the output path, if needed.
167+
168+
Args:
169+
output (dict): JSON to save as the output file.
170+
"""
171+
output_dir = os.path.dirname(self.output_path)
172+
if output_dir and not os.path.exists(output_dir):
173+
os.makedirs(output_dir)
174+
175+
if os.path.exists(self.output_path):
176+
LOGGER.warning("Overwriting file %s", self.output_path)
177+
178+
with open(self.output_path, "w") as output_file:
179+
json.dump(output, output_file, indent=1)
180+
output_file.write("\n") # json.dump does not write trailing newline
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for modifying AST nodes"""
14+
from __future__ import absolute_import

0 commit comments

Comments
 (0)