Skip to content

Commit f038668

Browse files
knakadlaurenyu
authored andcommitted
feature: allow disabling debugger_hook_config (#1194)
1 parent 33e96af commit f038668

File tree

2 files changed

+36
-2
lines changed

2 files changed

+36
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def _prepare_for_training(self, job_name=None):
327327
self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())
328328

329329
# Prepare rules and debugger configs for training.
330-
if self.rules and not self.debugger_hook_config:
330+
if self.rules and self.debugger_hook_config is None:
331331
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
332332
# If an object was provided without an S3 URI is not provided, default it for the customer.
333333
if self.debugger_hook_config and not self.debugger_hook_config.s3_output_path:
@@ -378,7 +378,7 @@ def _prepare_collection_configs(self):
378378
for rule in self.rules:
379379
self.collection_configs.update(rule.collection_configs)
380380
# Add the CollectionConfigs from DebuggerHookConfig to the set.
381-
if self.debugger_hook_config is not None:
381+
if self.debugger_hook_config:
382382
self.collection_configs.update(self.debugger_hook_config.collection_configs or [])
383383

384384
def latest_job_debugger_artifacts_path(self):
@@ -1676,6 +1676,8 @@ def _validate_and_set_debugger_configs(self):
16761676
"""
16771677
if self.debugger_hook_config is None:
16781678
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
1679+
elif not self.debugger_hook_config:
1680+
self.debugger_hook_config = None
16791681

16801682
def _stage_user_code_in_s3(self):
16811683
"""Upload the user training script to s3 and return the location.

tests/integ/test_debugger.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,38 @@ def test_mxnet_with_all_rules_and_configs(sagemaker_session, mxnet_full_version,
435435
_wait_and_assert_that_no_rule_jobs_errored(training_job=mx.latest_training_job)
436436

437437

438+
def test_mxnet_with_debugger_hook_config_disabled(
439+
sagemaker_session, mxnet_full_version, cpu_instance_type
440+
):
441+
with timeout(minutes=TRAINING_DEFAULT_TIMEOUT_MINUTES):
442+
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist_gluon.py")
443+
data_path = os.path.join(DATA_DIR, "mxnet_mnist")
444+
445+
mx = MXNet(
446+
entry_point=script_path,
447+
role="SageMakerRole",
448+
framework_version=mxnet_full_version,
449+
py_version=PYTHON_VERSION,
450+
train_instance_count=1,
451+
train_instance_type=cpu_instance_type,
452+
sagemaker_session=sagemaker_session,
453+
debugger_hook_config=False,
454+
)
455+
456+
train_input = mx.sagemaker_session.upload_data(
457+
path=os.path.join(data_path, "train"), key_prefix="integ-test-data/mxnet_mnist/train"
458+
)
459+
test_input = mx.sagemaker_session.upload_data(
460+
path=os.path.join(data_path, "test"), key_prefix="integ-test-data/mxnet_mnist/test"
461+
)
462+
463+
mx.fit({"train": train_input, "test": test_input})
464+
465+
job_description = mx.latest_training_job.describe()
466+
467+
assert job_description.get("DebugHookConfig") is None
468+
469+
438470
def _get_custom_rule(session):
439471
script_path = os.path.join(DATA_DIR, "mxnet_mnist", "my_custom_rule.py")
440472

0 commit comments

Comments
 (0)