13
13
"""Classes for updating code in files."""
14
14
from __future__ import absolute_import
15
15
16
- import os
16
+ from abc import abstractmethod
17
+ import json
17
18
import logging
19
+ import os
18
20
19
21
import pasta
20
22
23
25
LOGGER = logging .getLogger (__name__ )
24
26
25
27
26
- class PyFileUpdater (object ):
27
- """A class for updating Python (``*.py``) files."""
28
+ class FileUpdater (object ):
29
+ """An abstract class for updating files."""
28
30
29
31
def __init__ (self , input_path , output_path ):
30
- """Creates a ``PyFileUpdater `` for updating a Python file so that
32
+ """Creates a ``FileUpdater `` for updating a file so that
31
33
it is compatible with v2 of the SageMaker Python SDK.
32
34
33
35
Args:
@@ -39,6 +41,17 @@ def __init__(self, input_path, output_path):
39
41
self .input_path = input_path
40
42
self .output_path = output_path
41
43
44
+ @abstractmethod
45
+ def update (self ):
46
+ """Reads the input file, updates the code so that it is
47
+ compatible with v2 of the SageMaker Python SDK, and writes the
48
+ updated code to an output file.
49
+ """
50
+
51
+
52
+ class PyFileUpdater (FileUpdater ):
53
+ """A class for updating Python (``*.py``) files."""
54
+
42
55
def update (self ):
43
56
"""Reads the input Python file, updates the code so that it is
44
57
compatible with v2 of the SageMaker Python SDK, and writes the
@@ -60,7 +73,7 @@ def _update_ast(self, input_ast):
60
73
return ASTTransformer ().visit (input_ast )
61
74
62
75
def _read_input_file (self ):
63
- """Reads input file and parse as an abstract syntax tree (AST).
76
+ """Reads input file and parses it as an abstract syntax tree (AST).
64
77
65
78
Returns:
66
79
ast.Module: AST representation of the input file.
@@ -84,3 +97,84 @@ def _write_output_file(self, output):
84
97
85
98
with open (self .output_path , "w" ) as output_file :
86
99
output_file .write (pasta .dump (output ))
100
+
101
+
102
+ class JupyterNotebookFileUpdater (FileUpdater ):
103
+ """A class for updating Jupyter notebook (``*.ipynb``) files.
104
+
105
+ For more on this file format, see
106
+ https://ipython.org/ipython-doc/dev/notebook/nbformat.html#nbformat.
107
+ """
108
+
109
+ def update (self ):
110
+ """Reads the input Jupyter notebook file, updates the code so that it is
111
+ compatible with v2 of the SageMaker Python SDK, and writes the
112
+ updated code to an output file.
113
+ """
114
+ nb_json = self ._read_input_file ()
115
+ for cell in nb_json ["cells" ]:
116
+ if cell ["cell_type" ] == "code" :
117
+ updated_source = self ._update_code_from_cell (cell )
118
+ cell ["source" ] = updated_source
119
+
120
+ self ._write_output_file (nb_json )
121
+
122
+ def _update_code_from_cell (self , cell ):
123
+ """Updates the code from a code cell so that it is
124
+ compatible with v2 of the SageMaker Python SDK.
125
+
126
+ Args:
127
+ cell (dict): A dictionary representation of a code cell from
128
+ a Jupyter notebook. For more info, see
129
+ https://ipython.org/ipython-doc/dev/notebook/nbformat.html#code-cells.
130
+
131
+ Returns:
132
+ list[str]: A list of strings containing the lines of updated code that
133
+ can be used for the "source" attribute of a Jupyter notebook code cell.
134
+ """
135
+ code = "" .join (cell ["source" ])
136
+ updated_ast = ASTTransformer ().visit (pasta .parse (code ))
137
+ updated_code = pasta .dump (updated_ast )
138
+ return self ._code_str_to_source_list (updated_code )
139
+
140
+ def _code_str_to_source_list (self , code ):
141
+ """Converts a string of code into a list for a Jupyter notebook code cell.
142
+
143
+ Args:
144
+ code (str): Code to be converted.
145
+
146
+ Returns:
147
+ list[str]: A list of strings containing the lines of code that
148
+ can be used for the "source" attribute of a Jupyter notebook code cell.
149
+ Each element of the list (i.e. line of code) contains a
150
+ trailing newline character ("\n ") except for the last element.
151
+ """
152
+ source_list = ["{}\n " .format (s ) for s in code .split ("\n " )]
153
+ source_list [- 1 ] = source_list [- 1 ].rstrip ("\n " )
154
+ return source_list
155
+
156
+ def _read_input_file (self ):
157
+ """Reads input file and parses it as JSON.
158
+
159
+ Returns:
160
+ dict: JSON representation of the input file.
161
+ """
162
+ with open (self .input_path ) as input_file :
163
+ return json .load (input_file )
164
+
165
+ def _write_output_file (self , output ):
166
+ """Writes JSON to output file. Creates the directories for the output path, if needed.
167
+
168
+ Args:
169
+ output (dict): JSON to save as the output file.
170
+ """
171
+ output_dir = os .path .dirname (self .output_path )
172
+ if output_dir and not os .path .exists (output_dir ):
173
+ os .makedirs (output_dir )
174
+
175
+ if os .path .exists (self .output_path ):
176
+ LOGGER .warning ("Overwriting file %s" , self .output_path )
177
+
178
+ with open (self .output_path , "w" ) as output_file :
179
+ json .dump (output , output_file , indent = 1 )
180
+ output_file .write ("\n " ) # json.dump does not write trailing newline
0 commit comments