Skip to content

Commit d0eb4a2

Browse files
authored
breaking: force image_uri to be passed for legacy TF images (#1539)
This change also reorganizes the TF unit tests a bit, and updates the tf_version fixture to include recent versions.
1 parent 0c5392f commit d0eb4a2

15 files changed

+1034
-1023
lines changed

src/sagemaker/tensorflow/estimator.py

+49 -62
Original file line number | Diff line number | Diff line change
@@ -29,27 +29,18 @@
2929

3030
logger = logging.getLogger("sagemaker")
3131

32-
# TODO: consider creating a function for generating this command before removing this constant
33-
_SCRIPT_MODE_TENSORBOARD_WARNING = (
34-
"Tensorboard is not supported with script mode. You can run the following "
35-
"command: tensorboard --logdir %s --host localhost --port 6006 This can be "
36-
"run from anywhere with access to the S3 URI used as the logdir."
37-
)
38-
3932

4033
class TensorFlow(Framework):
4134
"""Handle end-to-end training and deployment of user-provided TensorFlow code."""
4235

4336
__framework_name__ = "tensorflow"
44-
_SCRIPT_MODE_REPO_NAME = "tensorflow-scriptmode"
37+
_ECR_REPO_NAME = "tensorflow-scriptmode"
4538

4639
LATEST_VERSION = defaults.LATEST_VERSION
4740

4841
_LATEST_1X_VERSION = "1.15.2"
4942

5043
_HIGHEST_LEGACY_MODE_ONLY_VERSION = version.Version("1.10.0")
51-
_LOWEST_SCRIPT_MODE_ONLY_VERSION = version.Version("1.13.1")
52-
5344
_HIGHEST_PYTHON_2_VERSION = version.Version("2.1.0")
5445

5546
def __init__(
@@ -59,7 +50,6 @@ def __init__(
5950
model_dir=None,
6051
image_name=None,
6152
distributions=None,
62-
script_mode=True,
6353
**kwargs
6454
):
6555
"""Initialize a ``TensorFlow`` estimator.
@@ -82,6 +72,8 @@ def __init__(
8272
* *Local Mode with local sources (file:// instead of s3://)* - \
8373
``/opt/ml/shared/model``
8474
75+
To disable having ``model_dir`` passed to your training script,
76+
set ``model_dir=False``.
8577
image_name (str): If specified, the estimator will use this image for training and
8678
hosting, instead of selecting the appropriate SageMaker official image based on
8779
framework_version and py_version. It can be an ECR url or dockerhub image and tag.
@@ -114,8 +106,6 @@ def __init__(
114106
}
115107
}
116108
117-
script_mode (bool): Whether or not to use the Script Mode TensorFlow images
118-
(default: True).
119109
**kwargs: Additional kwargs passed to the Framework constructor.
120110
121111
.. tip::
@@ -154,7 +144,6 @@ def __init__(
154144
self.model_dir = model_dir
155145
self.distributions = distributions or {}
156146

157-
self._script_mode_enabled = script_mode
158147
self._validate_args(py_version=py_version, framework_version=self.framework_version)
159148

160149
def _validate_args(self, py_version, framework_version):
@@ -171,30 +160,29 @@ def _validate_args(self, py_version, framework_version):
171160
)
172161
raise AttributeError(msg)
173162

174-
if (not self._script_mode_enabled) and self._only_script_mode_supported():
175-
logger.warning(
176-
"Legacy mode is deprecated in versions 1.13 and higher. Using script mode instead."
163+
if self._only_legacy_mode_supported() and self.image_name is None:
164+
legacy_image_uri = fw.create_image_uri(
165+
self.sagemaker_session.boto_region_name,
166+
"tensorflow",
167+
self.train_instance_type,
168+
self.framework_version,
169+
self.py_version,
177170
)
178-
self._script_mode_enabled = True
179171

180-
if self._only_legacy_mode_supported():
181172
# TODO: add link to docs to explain how to use legacy mode with v2
182-
logger.warning(
183-
"TF %s supports only legacy mode. If you were using any legacy mode parameters "
173+
msg = (
174+
"TF {} supports only legacy mode. Please supply the image URI directly with "
175+
"'image_name={}' and set 'model_dir=False'. If you are using any legacy parameters "
184176
"(training_steps, evaluation_steps, checkpoint_path, requirements_file), "
185-
"make sure to pass them directly as hyperparameters instead.",
186-
self.framework_version,
187-
)
188-
self._script_mode_enabled = False
177+
"make sure to pass them directly as hyperparameters instead."
178+
).format(self.framework_version, legacy_image_uri)
179+
180+
raise ValueError(msg)
189181

190182
def _only_legacy_mode_supported(self):
191183
"""Return True when ``framework_version`` is no newer than ``_HIGHEST_LEGACY_MODE_ONLY_VERSION``."""
192184
return version.Version(self.framework_version) <= self._HIGHEST_LEGACY_MODE_ONLY_VERSION
193185

194-
def _only_script_mode_supported(self):
195-
"""Placeholder docstring"""
196-
return version.Version(self.framework_version) >= self._LOWEST_SCRIPT_MODE_ONLY_VERSION
197-
198186
def _only_python_3_supported(self):
199187
"""Return True when ``framework_version`` is newer than ``_HIGHEST_PYTHON_2_VERSION``."""
200188
return version.Version(self.framework_version) > self._HIGHEST_PYTHON_2_VERSION
@@ -214,10 +202,6 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
214202
job_details, model_channel_name
215203
)
216204

217-
model_dir = init_params["hyperparameters"].pop("model_dir", None)
218-
if model_dir is not None:
219-
init_params["model_dir"] = model_dir
220-
221205
image_name = init_params.pop("image")
222206
framework, py_version, tag, script_mode = fw.framework_name_from_image(image_name)
223207
if not framework:
@@ -226,8 +210,11 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
226210
init_params["image_name"] = image_name
227211
return init_params
228212

229-
if script_mode is None:
230-
init_params["script_mode"] = False
213+
model_dir = init_params["hyperparameters"].pop("model_dir", None)
214+
if model_dir:
215+
init_params["model_dir"] = model_dir
216+
elif script_mode is None:
217+
init_params["model_dir"] = False
231218

232219
init_params["py_version"] = py_version
233220

@@ -239,6 +226,10 @@ def _prepare_init_params_from_job_description(cls, job_details, model_channel_na
239226
"1.4" if tag == "1.0" else fw.framework_version_from_tag(tag)
240227
)
241228

229+
# Legacy images are required to be passed in explicitly.
230+
if not script_mode:
231+
init_params["image_name"] = image_name
232+
242233
training_job_name = init_params["base_job_name"]
243234
if framework != cls.__framework_name__:
244235
raise ValueError(
@@ -309,27 +300,26 @@ def hyperparameters(self):
309300
hyperparameters = super(TensorFlow, self).hyperparameters()
310301
additional_hyperparameters = {}
311302

312-
if self._script_mode_enabled:
313-
mpi_enabled = False
314-
315-
if "parameter_server" in self.distributions:
316-
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
317-
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
303+
if "parameter_server" in self.distributions:
304+
ps_enabled = self.distributions["parameter_server"].get("enabled", False)
305+
additional_hyperparameters[self.LAUNCH_PS_ENV_NAME] = ps_enabled
318306

319-
if "mpi" in self.distributions:
320-
mpi_dict = self.distributions["mpi"]
321-
mpi_enabled = mpi_dict.get("enabled", False)
322-
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
307+
mpi_enabled = False
308+
if "mpi" in self.distributions:
309+
mpi_dict = self.distributions["mpi"]
310+
mpi_enabled = mpi_dict.get("enabled", False)
311+
additional_hyperparameters[self.LAUNCH_MPI_ENV_NAME] = mpi_enabled
323312

324-
if mpi_dict.get("processes_per_host"):
325-
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
326-
"processes_per_host"
327-
)
328-
329-
additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
330-
"custom_mpi_options", ""
313+
if mpi_dict.get("processes_per_host"):
314+
additional_hyperparameters[self.MPI_NUM_PROCESSES_PER_HOST] = mpi_dict.get(
315+
"processes_per_host"
331316
)
332317

318+
additional_hyperparameters[self.MPI_CUSTOM_MPI_OPTIONS] = mpi_dict.get(
319+
"custom_mpi_options", ""
320+
)
321+
322+
if self.model_dir is not False:
333323
self.model_dir = self.model_dir or self._default_s3_path("model", mpi=mpi_enabled)
334324
additional_hyperparameters["model_dir"] = self.model_dir
335325

@@ -375,16 +365,13 @@ def train_image(self):
375365
if self.image_name:
376366
return self.image_name
377367

378-
if self._script_mode_enabled:
379-
return fw.create_image_uri(
380-
self.sagemaker_session.boto_region_name,
381-
self._SCRIPT_MODE_REPO_NAME,
382-
self.train_instance_type,
383-
self.framework_version,
384-
self.py_version,
385-
)
386-
387-
return super(TensorFlow, self).train_image()
368+
return fw.create_image_uri(
369+
self.sagemaker_session.boto_region_name,
370+
self._ECR_REPO_NAME,
371+
self.train_instance_type,
372+
self.framework_version,
373+
self.py_version,
374+
)
388375

389376
def transformer(
390377
self,

tests/conftest.py

+15
Original file line number | Diff line number | Diff line change
@@ -194,6 +194,21 @@ def xgboost_version(request):
194194
"1.9.0",
195195
"1.10",
196196
"1.10.0",
197+
"1.11",
198+
"1.11.0",
199+
"1.12",
200+
"1.12.0",
201+
"1.13",
202+
"1.14",
203+
"1.14.0",
204+
"1.15",
205+
"1.15.0",
206+
"1.15.2",
207+
"2.0",
208+
"2.0.0",
209+
"2.0.1",
210+
"2.1",
211+
"2.1.0",
197212
],
198213
)
199214
def tf_version(request):

tests/integ/test_airflow_config.py

-1
Original file line number | Diff line number | Diff line change
@@ -561,7 +561,6 @@ def test_tf_airflow_config_uploads_data_source_to_s3(sagemaker_session, cpu_inst
561561
train_instance_count=SINGLE_INSTANCE_COUNT,
562562
train_instance_type=cpu_instance_type,
563563
sagemaker_session=sagemaker_session,
564-
script_mode=True,
565564
framework_version=TensorFlow.LATEST_VERSION,
566565
py_version=PYTHON_VERSION,
567566
metric_definitions=[

tests/integ/test_horovod.py

-2
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,6 @@ def test_horovod_local_mode(sagemaker_local_session, instances, processes, tmpdi
5858
train_instance_type="local",
5959
sagemaker_session=sagemaker_local_session,
6060
py_version=integ.PYTHON_VERSION,
61-
script_mode=True,
6261
output_path=output_path,
6362
framework_version="1.12",
6463
distributions={"mpi": {"enabled": True, "processes_per_host": processes}},
@@ -106,7 +105,6 @@ def _create_and_fit_estimator(sagemaker_session, instance_type, tmpdir):
106105
train_instance_type=instance_type,
107106
sagemaker_session=sagemaker_session,
108107
py_version=integ.PYTHON_VERSION,
109-
script_mode=True,
110108
framework_version="1.12",
111109
distributions={"mpi": {"enabled": True}},
112110
)

tests/integ/test_tf_script_mode.py renamed to tests/integ/test_tf.py

-5
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,6 @@ def test_mnist_with_checkpoint_config(
5959
train_instance_count=1,
6060
train_instance_type=instance_type,
6161
sagemaker_session=sagemaker_session,
62-
script_mode=True,
6362
framework_version=tf_full_version,
6463
py_version=py_version,
6564
metric_definitions=[{"Name": "train:global_steps", "Regex": r"global_step\/sec:\s(.*)"}],
@@ -104,7 +103,6 @@ def test_server_side_encryption(sagemaker_session, tf_full_version, py_version):
104103
train_instance_count=1,
105104
train_instance_type="ml.c5.xlarge",
106105
sagemaker_session=sagemaker_session,
107-
script_mode=True,
108106
framework_version=tf_full_version,
109107
py_version=py_version,
110108
code_location=output_path,
@@ -141,7 +139,6 @@ def test_mnist_distributed(sagemaker_session, instance_type, tf_full_version, py
141139
train_instance_type=instance_type,
142140
sagemaker_session=sagemaker_session,
143141
py_version=py_version,
144-
script_mode=True,
145142
framework_version=tf_full_version,
146143
distributions=PARAMETER_SERVER_DISTRIBUTION,
147144
)
@@ -166,7 +163,6 @@ def test_mnist_async(sagemaker_session, cpu_instance_type, tf_full_version, py_v
166163
train_instance_type="ml.c5.4xlarge",
167164
py_version=tests.integ.PYTHON_VERSION,
168165
sagemaker_session=sagemaker_session,
169-
script_mode=True,
170166
# testing py-sdk functionality, no need to run against all TF versions
171167
framework_version=TensorFlow.LATEST_VERSION,
172168
tags=TAGS,
@@ -209,7 +205,6 @@ def test_deploy_with_input_handlers(sagemaker_session, instance_type, tf_full_ve
209205
train_instance_type=instance_type,
210206
py_version=py_version,
211207
sagemaker_session=sagemaker_session,
212-
script_mode=True,
213208
framework_version=tf_full_version,
214209
tags=TAGS,
215210
)

tests/integ/test_tf_efs_fsx.py

+2 -6
Original file line number | Diff line number | Diff line change
@@ -65,7 +65,6 @@ def test_mnist_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
6565
train_instance_count=1,
6666
train_instance_type=cpu_instance_type,
6767
sagemaker_session=sagemaker_session,
68-
script_mode=True,
6968
framework_version=TensorFlow.LATEST_VERSION,
7069
py_version=PY_VERSION,
7170
subnets=subnets,
@@ -105,7 +104,6 @@ def test_mnist_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
105104
train_instance_count=1,
106105
train_instance_type=cpu_instance_type,
107106
sagemaker_session=sagemaker_session,
108-
script_mode=True,
109107
framework_version=TensorFlow.LATEST_VERSION,
110108
py_version=PY_VERSION,
111109
subnets=subnets,
@@ -130,7 +128,7 @@ def test_mnist_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
130128
tests.integ.test_region() not in tests.integ.EFS_TEST_ENABLED_REGION,
131129
reason="EFS integration tests need to be fixed before running in all regions.",
132130
)
133-
def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
131+
def test_tuning_tf_efs(efs_fsx_setup, sagemaker_session, cpu_instance_type):
134132
role = efs_fsx_setup["role_name"]
135133
subnets = [efs_fsx_setup["subnet_id"]]
136134
security_group_ids = efs_fsx_setup["security_group_ids"]
@@ -140,7 +138,6 @@ def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session, cpu_instanc
140138
role=role,
141139
train_instance_count=1,
142140
train_instance_type=cpu_instance_type,
143-
script_mode=True,
144141
sagemaker_session=sagemaker_session,
145142
py_version=PY_VERSION,
146143
framework_version=TensorFlow.LATEST_VERSION,
@@ -178,7 +175,7 @@ def test_tuning_tf_script_mode_efs(efs_fsx_setup, sagemaker_session, cpu_instanc
178175
tests.integ.test_region() not in tests.integ.EFS_TEST_ENABLED_REGION,
179176
reason="EFS integration tests need to be fixed before running in all regions.",
180177
)
181-
def test_tuning_tf_script_mode_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
178+
def test_tuning_tf_lustre(efs_fsx_setup, sagemaker_session, cpu_instance_type):
182179
role = efs_fsx_setup["role_name"]
183180
subnets = [efs_fsx_setup["subnet_id"]]
184181
security_group_ids = efs_fsx_setup["security_group_ids"]
@@ -188,7 +185,6 @@ def test_tuning_tf_script_mode_lustre(efs_fsx_setup, sagemaker_session, cpu_inst
188185
role=role,
189186
train_instance_count=1,
190187
train_instance_type=cpu_instance_type,
191-
script_mode=True,
192188
sagemaker_session=sagemaker_session,
193189
py_version=PY_VERSION,
194190
framework_version=TensorFlow.LATEST_VERSION,

tests/integ/test_transformer.py

-1
Original file line number | Diff line number | Diff line change
@@ -352,7 +352,6 @@ def test_transform_tf_kms_network_isolation(sagemaker_session, cpu_instance_type
352352
train_instance_count=1,
353353
train_instance_type=cpu_instance_type,
354354
framework_version=TensorFlow.LATEST_VERSION,
355-
script_mode=True,
356355
py_version=PYTHON_VERSION,
357356
sagemaker_session=sagemaker_session,
358357
)

tests/integ/test_tuner.py

-1
Original file line number | Diff line number | Diff line change
@@ -599,7 +599,6 @@ def test_tuning_tf_script_mode(sagemaker_session, cpu_instance_type, tf_full_ver
599599
role="SageMakerRole",
600600
train_instance_count=1,
601601
train_instance_type=cpu_instance_type,
602-
script_mode=True,
603602
sagemaker_session=sagemaker_session,
604603
py_version=PYTHON_VERSION,
605604
framework_version=tf_full_version,

0 commit comments

Comments
 (0)