diff --git a/tools/compatibility/v2/files.py b/tools/compatibility/v2/files.py index 055e30a1c5..b385274093 100644 --- a/tools/compatibility/v2/files.py +++ b/tools/compatibility/v2/files.py @@ -13,8 +13,10 @@ """Classes for updating code in files.""" from __future__ import absolute_import -import os +from abc import abstractmethod +import json import logging +import os import pasta @@ -23,11 +25,11 @@ LOGGER = logging.getLogger(__name__) -class PyFileUpdater(object): - """A class for updating Python (``*.py``) files.""" +class FileUpdater(object): + """An abstract class for updating files.""" def __init__(self, input_path, output_path): - """Creates a ``PyFileUpdater`` for updating a Python file so that + """Creates a ``FileUpdater`` for updating a file so that it is compatible with v2 of the SageMaker Python SDK. Args: @@ -39,6 +41,17 @@ def __init__(self, input_path, output_path): self.input_path = input_path self.output_path = output_path + @abstractmethod + def update(self): + """Reads the input file, updates the code so that it is + compatible with v2 of the SageMaker Python SDK, and writes the + updated code to an output file. + """ + + +class PyFileUpdater(FileUpdater): + """A class for updating Python (``*.py``) files.""" + def update(self): """Reads the input Python file, updates the code so that it is compatible with v2 of the SageMaker Python SDK, and writes the @@ -60,7 +73,7 @@ def _update_ast(self, input_ast): return ASTTransformer().visit(input_ast) def _read_input_file(self): - """Reads input file and parse as an abstract syntax tree (AST). + """Reads input file and parses it as an abstract syntax tree (AST). Returns: ast.Module: AST representation of the input file. @@ -84,3 +97,84 @@ def _write_output_file(self, output): with open(self.output_path, "w") as output_file: output_file.write(pasta.dump(output)) + + +class JupyterNotebookFileUpdater(FileUpdater): + """A class for updating Jupyter notebook (``*.ipynb``) files. + + For more on this file format, see + https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat. + """ + + def update(self): + """Reads the input Jupyter notebook file, updates the code so that it is + compatible with v2 of the SageMaker Python SDK, and writes the + updated code to an output file. + """ + nb_json = self._read_input_file() + for cell in nb_json["cells"]: + if cell["cell_type"] == "code": + updated_source = self._update_code_from_cell(cell) + cell["source"] = updated_source + + self._write_output_file(nb_json) + + def _update_code_from_cell(self, cell): + """Updates the code from a code cell so that it is + compatible with v2 of the SageMaker Python SDK. + + Args: + cell (dict): A dictionary representation of a code cell from + a Jupyter notebook. For more info, see + https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells. + + Returns: + list[str]: A list of strings containing the lines of updated code that + can be used for the "source" attribute of a Jupyter notebook code cell. + """ + code = "".join(cell["source"]) + updated_ast = ASTTransformer().visit(pasta.parse(code)) + updated_code = pasta.dump(updated_ast) + return self._code_str_to_source_list(updated_code) + + def _code_str_to_source_list(self, code): + """Converts a string of code into a list for a Jupyter notebook code cell. + + Args: + code (str): Code to be converted. + + Returns: + list[str]: A list of strings containing the lines of code that + can be used for the "source" attribute of a Jupyter notebook code cell. + Each element of the list (i.e. line of code) contains a + trailing newline character ("\n") except for the last element. + """ + source_list = ["{}\n".format(s) for s in code.split("\n")] + source_list[-1] = source_list[-1].rstrip("\n") + return source_list + + def _read_input_file(self): + """Reads input file and parses it as JSON. + + Returns: + dict: JSON representation of the input file. + """ + with open(self.input_path) as input_file: + return json.load(input_file) + + def _write_output_file(self, output): + """Writes JSON to output file. Creates the directories for the output path, if needed. + + Args: + output (dict): JSON to save as the output file. + """ + output_dir = os.path.dirname(self.output_path) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + + if os.path.exists(self.output_path): + LOGGER.warning("Overwriting file %s", self.output_path) + + with open(self.output_path, "w") as output_file: + json.dump(output, output_file, indent=1) + output_file.write("\n") # json.dump does not write trailing newline diff --git a/tools/compatibility/v2/sagemaker_upgrade_v2.py b/tools/compatibility/v2/sagemaker_upgrade_v2.py index 04b5fad876..2238775e1a 100644 --- a/tools/compatibility/v2/sagemaker_upgrade_v2.py +++ b/tools/compatibility/v2/sagemaker_upgrade_v2.py @@ -14,30 +14,63 @@ from __future__ import absolute_import import argparse +import os import files +_EXT_TO_UPDATER_CLS = {".py": files.PyFileUpdater, ".ipynb": files.JupyterNotebookFileUpdater} -def _parse_and_validate_args(): + +def _update_file(input_file, output_file): + """Update a file to be compatible with v2 of the SageMaker Python SDK, + and write the updated source to the output file. + + Args: + input_file (str): The path to the file to be updated. + output_file (str): The output file destination. + + Raises: + ValueError: If the input and output filename extensions don't match, + or if the file extensions are neither ".py" nor ".ipynb". + """ + input_file_ext = os.path.splitext(input_file)[1] + output_file_ext = os.path.splitext(output_file)[1] + + if input_file_ext != output_file_ext: + raise ValueError( + "Mismatched file extensions: input: {}, output: {}".format( + input_file_ext, output_file_ext + ) + ) + + if input_file_ext not in _EXT_TO_UPDATER_CLS: + raise ValueError("Unrecognized file extension: {}".format(input_file_ext)) + + updater_cls = _EXT_TO_UPDATER_CLS[input_file_ext] + updater_cls(input_path=input_file, output_path=output_file).update() + + +def _parse_args(): """Parses CLI arguments""" parser = argparse.ArgumentParser( - description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK." - "\nSimple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" + description="A tool to convert files to be compatible with v2 of the SageMaker Python SDK. " + "Simple usage: sagemaker_upgrade_v2.py --in-file foo.py --out-file bar.py" ) parser.add_argument( - "--in-file", help="If converting a single file, the name of the file to convert" + "--in-file", + help="If converting a single file, the file to convert. The file's extension " + "must be either '.py' or '.ipynb'.", ) parser.add_argument( "--out-file", - help="If converting a single file, the output file destination. If needed, " - "directories in the output file path are created. If the output file already exists, " - "it is overwritten.", + help="If converting a single file, the output file destination. The file's extension " + "must be either '.py' or '.ipynb'. If needed, directories in the output path are created. " + "If the output file already exists, it is overwritten.", ) return parser.parse_args() if __name__ == "__main__": - args = _parse_and_validate_args() - - files.PyFileUpdater(input_path=args.in_file, output_path=args.out_file).update() + args = _parse_args() + _update_file(args.in_file, args.out_file)