Skip to content

change: add .ipynb file support for v2 migration script #1508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 19, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 99 additions & 5 deletions tools/compatibility/v2/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
"""Classes for updating code in files."""
from __future__ import absolute_import

import os
from abc import abstractmethod
import json
import logging
import os

import pasta

Expand All @@ -23,11 +25,11 @@
LOGGER = logging.getLogger(__name__)


class PyFileUpdater(object):
"""A class for updating Python (``*.py``) files."""
class FileUpdater(object):
"""An abstract class for updating files."""

def __init__(self, input_path, output_path):
"""Creates a ``PyFileUpdater`` for updating a Python file so that
"""Creates a ``FileUpdater`` for updating a file so that
it is compatible with v2 of the SageMaker Python SDK.

Args:
Expand All @@ -39,6 +41,17 @@ def __init__(self, input_path, output_path):
self.input_path = input_path
self.output_path = output_path

@abstractmethod
def update(self):
"""Reads the input file, updates the code so that it is
compatible with v2 of the SageMaker Python SDK, and writes the
updated code to an output file.
"""


class PyFileUpdater(FileUpdater):
"""A class for updating Python (``*.py``) files."""

def update(self):
"""Reads the input Python file, updates the code so that it is
compatible with v2 of the SageMaker Python SDK, and writes the
Expand All @@ -60,7 +73,7 @@ def _update_ast(self, input_ast):
return ASTTransformer().visit(input_ast)

def _read_input_file(self):
"""Reads input file and parse as an abstract syntax tree (AST).
"""Reads input file and parses it as an abstract syntax tree (AST).

Returns:
ast.Module: AST representation of the input file.
Expand All @@ -84,3 +97,84 @@ def _write_output_file(self, output):

with open(self.output_path, "w") as output_file:
output_file.write(pasta.dump(output))


class JupyterNotebookFileUpdater(FileUpdater):
    """A class for updating Jupyter notebook (``*.ipynb``) files.

    For more on this file format, see
    https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat.
    """

    def update(self):
        """Reads the input Jupyter notebook file, updates the code so that it is
        compatible with v2 of the SageMaker Python SDK, and writes the
        updated code to an output file.

        Only cells whose ``cell_type`` is ``"code"`` are touched. A code cell
        that cannot be parsed as Python (for example, one containing IPython
        magics such as ``%matplotlib inline`` or shell escapes such as
        ``!pip install ...``) is left unchanged and a warning is logged, so that
        a single unparseable cell does not abort the conversion of the whole
        notebook.
        """
        nb_json = self._read_input_file()
        for index, cell in enumerate(nb_json["cells"]):
            if cell["cell_type"] != "code":
                continue

            try:
                cell["source"] = self._update_code_from_cell(cell)
            except SyntaxError as error:
                # The cell is not plain Python; keep its source as-is.
                LOGGER.warning(
                    "Skipping code cell %d of %s: unable to parse it as Python (%s)",
                    index,
                    self.input_path,
                    error,
                )

        self._write_output_file(nb_json)

    def _update_code_from_cell(self, cell):
        """Updates the code from a code cell so that it is
        compatible with v2 of the SageMaker Python SDK.

        Args:
            cell (dict): A dictionary representation of a code cell from
                a Jupyter notebook. For more info, see
                https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells.

        Returns:
            list[str]: A list of strings containing the lines of updated code that
                can be used for the "source" attribute of a Jupyter notebook code cell.

        Raises:
            SyntaxError: If the cell's source cannot be parsed as Python.
        """
        # The elements of "source" are concatenated without a separator.
        code = "".join(cell["source"])
        updated_ast = ASTTransformer().visit(pasta.parse(code))
        updated_code = pasta.dump(updated_ast)
        return self._code_str_to_source_list(updated_code)

    def _code_str_to_source_list(self, code):
        """Converts a string of code into a list for a Jupyter notebook code cell.

        Args:
            code (str): Code to be converted.

        Returns:
            list[str]: A list of strings containing the lines of code that
                can be used for the "source" attribute of a Jupyter notebook code cell.
                Each element of the list (i.e. line of code) contains a
                trailing newline character ("\\n") except for the last element.
        """
        lines = code.split("\n")
        # Re-attach the newline that ``split`` consumed to every line but the last.
        return ["{}\n".format(line) for line in lines[:-1]] + [lines[-1]]

    def _read_input_file(self):
        """Reads input file and parses it as JSON.

        Returns:
            dict: JSON representation of the input file.
        """
        # NOTE(review): the file is opened with the platform's default text
        # encoding; confirm whether UTF-8 should be requested explicitly.
        with open(self.input_path) as input_file:
            return json.load(input_file)

    def _write_output_file(self, output):
        """Writes JSON to output file. Creates the directories for the output path, if needed.

        If the output file already exists, a warning is logged and the file
        is overwritten.

        Args:
            output (dict): JSON to save as the output file.
        """
        output_dir = os.path.dirname(self.output_path)
        # ``output_dir`` is empty when the output path has no directory part.
        if output_dir and not os.path.exists(output_dir):
            os.makedirs(output_dir)

        if os.path.exists(self.output_path):
            LOGGER.warning("Overwriting file %s", self.output_path)

        with open(self.output_path, "w") as output_file:
            json.dump(output, output_file, indent=1)
            output_file.write("\n")  # json.dump does not write trailing newline
53 changes: 43 additions & 10 deletions tools/compatibility/v2/sagemaker_upgrade_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,63 @@
from __future__ import absolute_import

import argparse
import os

import files

_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater}

def _parse_and_validate_args():

def _update_file(input_file, output_file):
"""Update a file to be compatible with v2 of the SageMaker Python SDK,
and write the updated source to the output file.

Args:
input_file (str): The path to the file to be updated.
output_file (str): The output file destination.

Raises:
ValueError: If the input and output filename extensions don't match,
or if the file extensions are neither ".py" nor ".ipynb".
"""
input_file_ext = os.path.splitext(input_file)[1]
output_file_ext = os.path.splitext(output_file)[1]

if input_file_ext != output_file_ext:
raise ValueError(
"Mismatched file extensions: input: {}, output: {}".format(
input_file_ext, output_file_ext
)
)

if input_file_ext not in _EXT_TO_UPDATER_CLS:
raise ValueError("Unrecognized file extension: {}".format(input_file_ext))

updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext]
updater_cls(input_path=input_file, output_path=output_file).update()


def _parse_args():
"""Parses CLI arguments"""
parser = argparse.ArgumentParser(
description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK."
"\nSimple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. "
"Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py"
)
parser.add_argument(
"--in-file", help="If converting a single file, the name of the file to convert"
"--in-file",
help="If converting a single file, the file to convert. The file's extension "
"must be either '.py' or '.ipynb'.",
)
parser.add_argument(
"--out-file",
help="If converting a single file, the output file destination. If needed, "
"directories in the output file path are created. If the output file already exists, "
"it is overwritten.",
help="If converting a single file, the output file destination. The file's extension "
"must be either '.py' or '.ipynb'. If needed, directories in the output path are created. "
"If the output file already exists, it is overwritten.",
)

return parser.parse_args()


if __name__ == "__main__":
args = _parse_and_validate_args()

files.PyFileUpdater(input_path=args.in_file, output_path=args.out_file).update()
args = _parse_args()
_update_file(args.in_file, args.out_file)