29
29
30
30
logger = logging .getLogger ("sagemaker" )
31
31
32
- # TODO: consider creating a function for generating this command before removing this constant
33
- _SCRIPT_MODE_TENSORBOARD_WARNING = (
34
- "Tensorboard is not supported with script mode. You can run the following "
35
- "command: tensorboard --logdir %s --host localhost --port 6006 This can be "
36
- "run from anywhere with access to the S3 URI used as the logdir."
37
- )
38
-
39
32
40
33
class TensorFlow (Framework ):
41
34
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
42
35
43
36
__framework_name__ = "tensorflow"
44
- _SCRIPT_MODE_REPO_NAME = "tensorflow-scriptmode"
37
+ _ECR_REPO_NAME = "tensorflow-scriptmode"
45
38
46
39
LATEST_VERSION = defaults .LATEST_VERSION
47
40
48
41
_LATEST_1X_VERSION = "1.15.2"
49
42
50
43
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version .Version ("1.10.0" )
51
- _LOWEST_SCRIPT_MODE_ONLY_VERSION = version .Version ("1.13.1" )
52
-
53
44
_HIGHEST_PYTHON_2_VERSION = version .Version ("2.1.0" )
54
45
55
46
def __init__ (
@@ -59,7 +50,6 @@ def __init__(
59
50
model_dir = None ,
60
51
image_name = None ,
61
52
distributions = None ,
62
- script_mode = True ,
63
53
** kwargs
64
54
):
65
55
"""Initialize a ``TensorFlow`` estimator.
@@ -82,6 +72,8 @@ def __init__(
82
72
* *Local Mode with local sources (file:// instead of s3://)* - \
83
73
``/opt/ml/shared/model``
84
74
75
+ To disable having ``model_dir`` passed to your training script,
76
+ set ``model_dir=False``.
85
77
image_name (str): If specified, the estimator will use this image for training and
86
78
hosting, instead of selecting the appropriate SageMaker official image based on
87
79
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -114,8 +106,6 @@ def __init__(
114
106
}
115
107
}
116
108
117
- script_mode (bool): Whether or not to use the Script Mode TensorFlow images
118
- (default: True).
119
109
**kwargs: Additional kwargs passed to the Framework constructor.
120
110
121
111
.. tip::
@@ -154,7 +144,6 @@ def __init__(
154
144
self .model_dir = model_dir
155
145
self .distributions = distributions or {}
156
146
157
- self ._script_mode_enabled = script_mode
158
147
self ._validate_args (py_version = py_version , framework_version = self .framework_version )
159
148
160
149
def _validate_args (self , py_version , framework_version ):
@@ -171,30 +160,29 @@ def _validate_args(self, py_version, framework_version):
171
160
)
172
161
raise AttributeError (msg )
173
162
174
- if (not self ._script_mode_enabled ) and self ._only_script_mode_supported ():
175
- logger .warning (
176
- "Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
163
+ if self ._only_legacy_mode_supported () and self .image_name is None :
164
+ legacy_image_uri = fw .create_image_uri (
165
+ self .sagemaker_session .boto_region_name ,
166
+ "tensorflow" ,
167
+ self .train_instance_type ,
168
+ self .framework_version ,
169
+ self .py_version ,
177
170
)
178
- self ._script_mode_enabled = True
179
171
180
- if self ._only_legacy_mode_supported ():
181
172
# TODO: add link to docs to explain how to use legacy mode with v2
182
- logger .warning (
183
- "TF %s supports only legacy mode. If you were using any legacy mode parameters "
173
+ msg = (
174
+ "TF {} supports only legacy mode. Please supply the image URI directly with "
175
+ "'image_name={}' and set 'model_dir=False'. If you are using any legacy parameters "
184
176
"(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
185
- "make sure to pass them directly as hyperparameters instead." ,
186
- self .framework_version ,
187
- )
188
- self . _script_mode_enabled = False
177
+ "make sure to pass them directly as hyperparameters instead."
178
+ ). format ( self .framework_version , legacy_image_uri )
179
+
180
+ raise ValueError ( msg )
189
181
190
182
def _only_legacy_mode_supported(self):
    """Tell whether the configured framework version is limited to legacy mode.

    Returns:
        bool: True when ``self.framework_version``, parsed as a version, is
        less than or equal to ``_HIGHEST_LEGACY_MODE_ONLY_VERSION``.
    """
    # Parse once so the comparison is done on version objects, not raw strings.
    requested_version = version.Version(self.framework_version)
    return requested_version <= self._HIGHEST_LEGACY_MODE_ONLY_VERSION
193
185
194
- def _only_script_mode_supported (self ):
195
- """Placeholder docstring"""
196
- return version .Version (self .framework_version ) >= self ._LOWEST_SCRIPT_MODE_ONLY_VERSION
197
-
198
186
def _only_python_3_supported(self):
    """Tell whether the configured framework version is past the last Python 2 release.

    Returns:
        bool: True when ``self.framework_version``, parsed as a version, is
        strictly greater than ``_HIGHEST_PYTHON_2_VERSION``.
    """
    # Strict comparison: the boundary version itself does not count as Python 3 only.
    requested_version = version.Version(self.framework_version)
    return requested_version > self._HIGHEST_PYTHON_2_VERSION
@@ -214,10 +202,6 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
214
202
job_details , model_channel_name
215
203
)
216
204
217
- model_dir = init_params ["hyperparameters" ].pop ("model_dir" , None )
218
- if model_dir is not None :
219
- init_params ["model_dir" ] = model_dir
220
-
221
205
image_name = init_params .pop ("image" )
222
206
framework , py_version , tag , script_mode = fw .framework_name_from_image (image_name )
223
207
if not framework :
@@ -226,8 +210,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
226
210
init_params ["image_name" ] = image_name
227
211
return init_params
228
212
229
- if script_mode is None :
230
- init_params ["script_mode" ] = False
213
+ model_dir = init_params ["hyperparameters" ].pop ("model_dir" , None )
214
+ if model_dir :
215
+ init_params ["model_dir" ] = model_dir
216
+ elif script_mode is None :
217
+ init_params ["model_dir" ] = False
231
218
232
219
init_params ["py_version" ] = py_version
233
220
@@ -239,6 +226,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
239
226
"1.4" if tag == "1.0" else fw .framework_version_from_tag (tag )
240
227
)
241
228
229
+ # Legacy images are required to be passed in explicitly.
230
+ if not script_mode :
231
+ init_params ["image_name" ] = image_name
232
+
242
233
training_job_name = init_params ["base_job_name" ]
243
234
if framework != cls .__framework_name__ :
244
235
raise ValueError (
@@ -309,27 +300,26 @@ def hyperparameters(self):
309
300
hyperparameters = super (TensorFlow , self ).hyperparameters ()
310
301
additional_hyperparameters = {}
311
302
312
- if self ._script_mode_enabled :
313
- mpi_enabled = False
314
-
315
- if "parameter_server" in self .distributions :
316
- ps_enabled = self .distributions ["parameter_server" ].get ("enabled" , False )
317
- additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = ps_enabled
303
+ if "parameter_server" in self .distributions :
304
+ ps_enabled = self .distributions ["parameter_server" ].get ("enabled" , False )
305
+ additional_hyperparameters [self .LAUNCH_PS_ENV_NAME ] = ps_enabled
318
306
319
- if "mpi" in self .distributions :
320
- mpi_dict = self .distributions ["mpi" ]
321
- mpi_enabled = mpi_dict .get ("enabled" , False )
322
- additional_hyperparameters [self .LAUNCH_MPI_ENV_NAME ] = mpi_enabled
307
+ mpi_enabled = False
308
+ if "mpi" in self .distributions :
309
+ mpi_dict = self .distributions ["mpi" ]
310
+ mpi_enabled = mpi_dict .get ("enabled" , False )
311
+ additional_hyperparameters [self .LAUNCH_MPI_ENV_NAME ] = mpi_enabled
323
312
324
- if mpi_dict .get ("processes_per_host" ):
325
- additional_hyperparameters [self .MPI_NUM_PROCESSES_PER_HOST ] = mpi_dict .get (
326
- "processes_per_host"
327
- )
328
-
329
- additional_hyperparameters [self .MPI_CUSTOM_MPI_OPTIONS ] = mpi_dict .get (
330
- "custom_mpi_options" , ""
313
+ if mpi_dict .get ("processes_per_host" ):
314
+ additional_hyperparameters [self .MPI_NUM_PROCESSES_PER_HOST ] = mpi_dict .get (
315
+ "processes_per_host"
331
316
)
332
317
318
+ additional_hyperparameters [self .MPI_CUSTOM_MPI_OPTIONS ] = mpi_dict .get (
319
+ "custom_mpi_options" , ""
320
+ )
321
+
322
+ if self .model_dir is not False :
333
323
self .model_dir = self .model_dir or self ._default_s3_path ("model" , mpi = mpi_enabled )
334
324
additional_hyperparameters ["model_dir" ] = self .model_dir
335
325
@@ -375,16 +365,13 @@ def train_image(self):
375
365
if self .image_name :
376
366
return self .image_name
377
367
378
- if self ._script_mode_enabled :
379
- return fw .create_image_uri (
380
- self .sagemaker_session .boto_region_name ,
381
- self ._SCRIPT_MODE_REPO_NAME ,
382
- self .train_instance_type ,
383
- self .framework_version ,
384
- self .py_version ,
385
- )
386
-
387
- return super (TensorFlow , self ).train_image ()
368
+ return fw .create_image_uri (
369
+ self .sagemaker_session .boto_region_name ,
370
+ self ._ECR_REPO_NAME ,
371
+ self .train_instance_type ,
372
+ self .framework_version ,
373
+ self .py_version ,
374
+ )
388
375
389
376
def transformer (
390
377
self ,
0 commit comments