15
15
16
16
import ast
17
17
18
+ import boto3
18
19
import six
19
20
21
+ from sagemaker .cli .compatibility .v2 .modifiers import framework_version
20
22
from sagemaker .cli .compatibility .v2 .modifiers .modifier import Modifier
23
+ from sagemaker import fw_utils
21
24
22
25
23
26
class TensorFlowLegacyModeConstructorUpgrader (Modifier ):
24
- """A class to turn legacy mode parameters into hyperparameters when
25
- instantiating a TensorFlow estimator.
27
+ """A class to turn legacy mode parameters into hyperparameters, disable the ``model_dir``
28
+ hyperparameter, and set the image URI when instantiating a TensorFlow estimator.
26
29
"""
27
30
28
31
LEGACY_MODE_PARAMETERS = (
@@ -32,6 +35,18 @@ class TensorFlowLegacyModeConstructorUpgrader(Modifier):
32
35
"training_steps" ,
33
36
)
34
37
38
+ def __init__ (self ):
39
+ """Initializes a ``TensorFlowLegacyModeConstructorUpgrader``."""
40
+ self ._region = None
41
+
42
+ @property
43
+ def region (self ):
44
+ """Returns the AWS region for constructing an ECR image URI."""
45
+ if self ._region is None :
46
+ self ._region = boto3 .Session ().region_name
47
+
48
+ return self ._region
49
+
35
50
def node_should_be_modified (self , node ):
36
51
"""Checks if the ``ast.Call`` node instantiates a TensorFlow estimator with legacy mode.
37
52
@@ -89,7 +104,7 @@ def _is_legacy_mode(self, node):
89
104
90
105
def modify_node (self , node ):
91
106
"""Modifies the ``ast.Call`` node's keywords to turn TensorFlow legacy mode parameters
92
- into hyperparameters and set ``script_mode =False``.
107
+ into hyperparameters and sets ``model_dir =False``.
93
108
94
109
The parameters that are converted into hyperparameters:
95
110
@@ -105,9 +120,11 @@ def modify_node(self, node):
105
120
additional_hps = {}
106
121
kw_to_remove = [] # remove keyword args after so that none are skipped during iteration
107
122
123
+ add_image_uri = True
124
+
108
125
for kw in node .keywords :
109
- if kw .arg == "script_mode" :
110
- # remove here because is set to False later regardless of current value
126
+ if kw .arg in ( "script_mode" , "model_dir" ) :
127
+ # model_dir is removed so that it can be set to False later
111
128
kw_to_remove .append (kw )
112
129
if kw .arg == "hyperparameters" and kw .value :
113
130
base_hps = dict (zip (kw .value .keys , kw .value .values ))
@@ -116,11 +133,17 @@ def modify_node(self, node):
116
133
hp_key = self ._hyperparameter_key_for_param (kw .arg )
117
134
additional_hps [hp_key ] = kw .value
118
135
kw_to_remove .append (kw )
136
+ if kw .arg == "image_name" :
137
+ add_image_uri = False
119
138
120
139
self ._remove_keywords (node , kw_to_remove )
121
140
self ._add_updated_hyperparameters (node , base_hps , additional_hps )
122
141
123
- node .keywords .append (ast .keyword (arg = "script_mode" , value = ast .NameConstant (value = False )))
142
+ if add_image_uri :
143
+ image_uri = self ._image_uri_from_args (node .keywords )
144
+ node .keywords .append (ast .keyword (arg = "image_name" , value = ast .Str (s = image_uri )))
145
+
146
+ node .keywords .append (ast .keyword (arg = "model_dir" , value = ast .NameConstant (value = False )))
124
147
125
148
def _hyperparameter_key_for_param (self , arg ):
126
149
"""Returns an ``ast.Str`` for a hyperparameter key replacing a legacy mode parameter."""
@@ -148,6 +171,21 @@ def _to_ast_keyword(self, hps):
148
171
149
172
return None
150
173
174
+ def _image_uri_from_args (self , keywords ):
175
+ """Returns a legacy TensorFlow image URI based on the estimator arguments."""
176
+ tf_version = framework_version .FRAMEWORK_DEFAULTS ["TensorFlow" ]
177
+ instance_type = "ml.m4.xlarge" # CPU default (exact type doesn't matter)
178
+
179
+ for kw in keywords :
180
+ if kw .arg == "framework_version" :
181
+ tf_version = kw .value .s
182
+ if kw .arg == "train_instance_type" :
183
+ instance_type = kw .value .s
184
+
185
+ return fw_utils .create_image_uri (
186
+ self .region , "tensorflow" , instance_type , tf_version , "py2"
187
+ )
188
+
151
189
152
190
class TensorBoardParameterRemover (Modifier ):
153
191
"""A class for removing the ``run_tensorboard_locally`` parameter from ``fit()``."""
0 commit comments