Skip to content

Commit 57b2a22

Browse files
authored
change: add .ipynb file support for v2 migration script (#1508)
1 parent 88518e0 commit 57b2a22

File tree

2 files changed

+142
-15
lines changed

2 files changed

+142
-15
lines changed

tools/compatibility/v2/files.py

+99-5
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
"""Classes for updating code in files."""
1414
from __future__ import absolute_import
1515

16-
import os
16+
from abc import abstractmethod
17+
import json
1718
import logging
19+
import os
1820

1921
import pasta
2022

@@ -23,11 +25,11 @@
2325
LOGGER = logging.getLogger(__name__)
2426

2527

26-
class PyFileUpdater(object):
27-
"""A class for updating Python (``*.py``) files."""
28+
class FileUpdater(object):
29+
"""An abstract class for updating files."""
2830

2931
def __init__(self, input_path, output_path):
30-
"""Creates a ``PyFileUpdater`` for updating a Python file so that
32+
"""Creates a ``FileUpdater`` for updating a file so that
3133
it is compatible with v2 of the SageMaker Python SDK.
3234
3335
Args:
@@ -39,6 +41,17 @@ def __init__(self, input_path, output_path):
3941
self.input_path = input_path
4042
self.output_path = output_path
4143

44+
@abstractmethod
45+
def update(self):
46+
"""Reads the input file, updates the code so that it is
47+
compatible with v2 of the SageMaker Python SDK, and writes the
48+
updated code to an output file.
49+
"""
50+
51+
52+
class PyFileUpdater(FileUpdater):
53+
"""A class for updating Python (``*.py``) files."""
54+
4255
def update(self):
4356
"""Reads the input Python file, updates the code so that it is
4457
compatible with v2 of the SageMaker Python SDK, and writes the
@@ -60,7 +73,7 @@ def _update_ast(self, input_ast):
6073
return ASTTransformer().visit(input_ast)
6174

6275
def _read_input_file(self):
63-
"""Reads input file and parse as an abstract syntax tree (AST).
76+
"""Reads input file and parses it as an abstract syntax tree (AST).
6477
6578
Returns:
6679
ast.Module: AST representation of the input file.
@@ -84,3 +97,84 @@ def _write_output_file(self, output):
8497

8598
with open(self.output_path, "w") as output_file:
8699
output_file.write(pasta.dump(output))
100+
101+
102+
class JupyterNotebookFileUpdater(FileUpdater):
103+
"""A class for updating Jupyter notebook (``*.ipynb``) files.
104+
105+
For more on this file format, see
106+
https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat.
107+
"""
108+
109+
def update(self):
110+
"""Reads the input Jupyter notebook file, updates the code so that it is
111+
compatible with v2 of the SageMaker Python SDK, and writes the
112+
updated code to an output file.
113+
"""
114+
nb_json = self._read_input_file()
115+
for cell in nb_json["cells"]:
116+
if cell["cell_type"] == "code":
117+
updated_source = self._update_code_from_cell(cell)
118+
cell["source"] = updated_source
119+
120+
self._write_output_file(nb_json)
121+
122+
def _update_code_from_cell(self, cell):
123+
"""Updates the code from a code cell so that it is
124+
compatible with v2 of the SageMaker Python SDK.
125+
126+
Args:
127+
cell (dict): A dictionary representation of a code cell from
128+
a Jupyter notebook. For more info, see
129+
https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells.
130+
131+
Returns:
132+
list[str]: A list of strings containing the lines of updated code that
133+
can be used for the "source" attribute of a Jupyter notebook code cell.
134+
"""
135+
code = "".join(cell["source"])
136+
updated_ast = ASTTransformer().visit(pasta.parse(code))
137+
updated_code = pasta.dump(updated_ast)
138+
return self._code_str_to_source_list(updated_code)
139+
140+
def _code_str_to_source_list(self, code):
141+
"""Converts a string of code into a list for a Jupyter notebook code cell.
142+
143+
Args:
144+
code (str): Code to be converted.
145+
146+
Returns:
147+
list[str]: A list of strings containing the lines of code that
148+
can be used for the "source" attribute of a Jupyter notebook code cell.
149+
Each element of the list (i.e. line of code) contains a
150+
trailing newline character ("\n") except for the last element.
151+
"""
152+
source_list = ["{}\n".format(s) for s in code.split("\n")]
153+
source_list[-1] = source_list[-1].rstrip("\n")
154+
return source_list
155+
156+
def _read_input_file(self):
157+
"""Reads input file and parses it as JSON.
158+
159+
Returns:
160+
dict: JSON representation of the input file.
161+
"""
162+
with open(self.input_path) as input_file:
163+
return json.load(input_file)
164+
165+
def _write_output_file(self, output):
166+
"""Writes JSON to output file. Creates the directories for the output path, if needed.
167+
168+
Args:
169+
output (dict): JSON to save as the output file.
170+
"""
171+
output_dir = os.path.dirname(self.output_path)
172+
if output_dir and not os.path.exists(output_dir):
173+
os.makedirs(output_dir)
174+
175+
if os.path.exists(self.output_path):
176+
LOGGER.warning("Overwriting file %s", self.output_path)
177+
178+
with open(self.output_path, "w") as output_file:
179+
json.dump(output, output_file, indent=1)
180+
output_file.write("\n") # json.dump does not write trailing newline

tools/compatibility/v2/sagemaker_upgrade_v2.py

+43-10
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,63 @@
1414
from __future__ import absolute_import
1515

1616
import argparse
17+
import os
1718

1819
import files
1920

21+
_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater}
2022

21-
def _parse_and_validate_args():
23+
24+
def _update_file(input_file, output_file):
25+
"""Update a file to be compatible with v2 of the SageMaker Python SDK,
26+
and write the updated source to the output file.
27+
28+
Args:
29+
input_file (str): The path to the file to be updated.
30+
output_file (str): The output file destination.
31+
32+
Raises:
33+
ValueError: If the input and output filename extensions don't match,
34+
or if the file extensions are neither ".py" nor ".ipynb".
35+
"""
36+
input_file_ext = os.path.splitext(input_file)[1]
37+
output_file_ext = os.path.splitext(output_file)[1]
38+
39+
if input_file_ext != output_file_ext:
40+
raise ValueError(
41+
"Mismatched file extensions: input: {}, output: {}".format(
42+
input_file_ext, output_file_ext
43+
)
44+
)
45+
46+
if input_file_ext not in _EXT_TO_UPDATER_CLS:
47+
raise ValueError("Unrecognized file extension: {}".format(input_file_ext))
48+
49+
updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext]
50+
updater_cls(input_path=input_file, output_path=output_file).update()
51+
52+
53+
def _parse_args():
2254
"""Parses CLI arguments"""
2355
parser = argparse.ArgumentParser(
24-
description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK."
25-
"\nSimple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
56+
description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. "
57+
"Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
2658
)
2759
parser.add_argument(
28-
"--in-file", help="If converting a single file, the name of the file to convert"
60+
"--in-file",
61+
help="If converting a single file, the file to convert. The file's extension "
62+
"must be either '.py' or '.ipynb'.",
2963
)
3064
parser.add_argument(
3165
"--out-file",
32-
help="If converting a single file, the output file destination. If needed, "
33-
"directories in the output file path are created. If the output file already exists, "
34-
"it is overwritten.",
66+
help="If converting a single file, the output file destination. The file's extension "
67+
"must be either '.py' or '.ipynb'. If needed, directories in the output path are created. "
68+
"If the output file already exists, it is overwritten.",
3569
)
3670

3771
return parser.parse_args()
3872

3973

4074
if __name__ == "__main__":
41-
args = _parse_and_validate_args()
42-
43-
files.PyFileUpdater(input_path=args.in_file, output_path=args.out_file).update()
75+
args = _parse_args()
76+
_update_file(args.in_file, args.out_file)

0 commit comments

Comments
 (0)