Skip to content

Commit 81ad62b

Browse files
authored
change: add class to read Python scripts and update code for v2 (#1497)
1 parent f7f0ac6 commit 81ad62b

File tree

1 file changed

+86
-0
lines changed

1 file changed

+86
-0
lines changed

tools/compatibility/v2/files.py

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Classes for updating code in files."""
14+
from __future__ import absolute_import
15+
16+
import os
17+
import logging
18+
19+
import pasta
20+
21+
from ast_transformer import ASTTransformer
22+
23+
LOGGER = logging.getLogger(__name__)
24+
25+
26+
class PyFileUpdater(object):
27+
"""A class for updating Python (``*.py``) files."""
28+
29+
def __init__(self, input_path, output_path):
30+
"""Creates a ``PyFileUpdater`` for updating a Python file so that
31+
it is compatible with v2 of the SageMaker Python SDK.
32+
33+
Args:
34+
input_path (str): Location of the input file.
35+
output_path (str): Desired location for the output file.
36+
If the directories don't already exist, then they are created.
37+
If a file exists at ``output_path``, then it is overwritten.
38+
"""
39+
self.input_path = input_path
40+
self.output_path = output_path
41+
42+
def update(self):
43+
"""Reads the input Python file, updates the code so that it is
44+
compatible with v2 of the SageMaker Python SDK, and writes the
45+
updated code to an output file.
46+
"""
47+
output = self._update_ast(self._read_input_file())
48+
self._write_output_file(output)
49+
50+
def _update_ast(self, input_ast):
51+
"""Updates an abstract syntax tree (AST) so that it is compatible
52+
with v2 of the SageMaker Python SDK.
53+
54+
Args:
55+
input_ast (ast.Module): AST to be updated for use with Python SDK v2.
56+
57+
Returns:
58+
ast.Module: Updated AST that is compatible with Python SDK v2.
59+
"""
60+
return ASTTransformer().visit(input_ast)
61+
62+
def _read_input_file(self):
63+
"""Reads input file and parse as an abstract syntax tree (AST).
64+
65+
Returns:
66+
ast.Module: AST representation of the input file.
67+
"""
68+
with open(self.input_path) as input_file:
69+
return pasta.parse(input_file.read())
70+
71+
def _write_output_file(self, output):
72+
"""Writes abstract syntax tree (AST) to output file.
73+
Creates the directories for the output path, if needed.
74+
75+
Args:
76+
output (ast.Module): AST to save as the output file.
77+
"""
78+
output_dir = os.path.dirname(self.output_path)
79+
if output_dir and not os.path.exists(output_dir):
80+
os.makedirs(output_dir)
81+
82+
if os.path.exists(self.output_path):
83+
LOGGER.warning("Overwriting file {}".format(self.output_path))
84+
85+
with open(self.output_path, "w") as output_file:
86+
output_file.write(pasta.dump(output))

0 commit comments

Comments
 (0)