Skip to content

Commit 1c60bca

Browse files
rahul003 authored and ddavydenko committed
fix: disable DebuggerHook and Rules for TF distributions (#290)
* fix: Disable hook for ps or mpi in TF

Co-Authored-By: Denis Davydenko <[email protected]>
1 parent 8a992e6 commit 1c60bca

File tree

3 files changed

+47
-7
lines changed

3 files changed

+47
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1656,7 +1656,12 @@ def _prepare_for_training(self, job_name=None):
16561656
self._hyperparameters[JOB_NAME_PARAM_NAME] = self._current_job_name
16571657
self._hyperparameters[SAGEMAKER_REGION_PARAM_NAME] = self.sagemaker_session.boto_region_name
16581658

1659-
# Set defaults for debugging.
1659+
self._validate_and_set_debugger_configs()
1660+
1661+
def _validate_and_set_debugger_configs(self):
1662+
"""
1663+
Set defaults for debugging
1664+
"""
16601665
if self.debugger_hook_config is None:
16611666
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
16621667

src/sagemaker/tensorflow/estimator.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import threading
2323
import time
2424

25+
from sagemaker.debugger import DebuggerHookConfig
2526
from sagemaker.estimator import Framework
2627
import sagemaker.fw_utils as fw
2728
from sagemaker.tensorflow.defaults import TF_VERSION
@@ -696,6 +697,31 @@ def _script_mode_enabled(self):
696697
"""Placeholder docstring"""
697698
return self.py_version == "py3" or self.script_mode
698699

700+
def _validate_and_set_debugger_configs(self):
    """Reconcile the debugger settings with the configured distribution.

    When either the ``parameter_server`` or the ``mpi`` entry of
    ``self.distributions`` is enabled, both ``debugger_hook_config`` and
    ``debugger_rule_configs`` are reset to ``None`` (these distributions
    are not supported in smdebug 0.4.13), and an informational message is
    logged if either of them had been set.

    Otherwise, a default ``DebuggerHookConfig`` pointing at
    ``self.output_path`` is installed when no hook config is present.
    """
    # Both lookups are evaluated up front, exactly one after the other.
    uses_parameter_server = "parameter_server" in self.distributions and self.distributions[
        "parameter_server"
    ].get("enabled", False)
    uses_mpi = "mpi" in self.distributions and self.distributions["mpi"].get("enabled", False)

    if not (uses_parameter_server or uses_mpi):
        # Supported setup: only fill in the default when nothing was provided.
        if self.debugger_hook_config is None:
            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
        return

    # Unsupported setup: tell the user only if a debugger setting is being dropped.
    nothing_configured = (
        self.debugger_hook_config is None and self.debugger_rule_configs is None
    )
    if not nothing_configured:
        logger.info(
            "Amazon SageMaker Debugger does not currently support "
            "Parameter Server and MPI distributions"
        )
    self.debugger_hook_config = None
    self.debugger_rule_configs = None
724+
699725
def train_image(self):
700726
"""Placeholder docstring"""
701727
if self.image_name:

tests/unit/test_tf_estimator.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,14 @@ def _hyperparameters(script_mode=False, horovod=False):
123123

124124

125125
def _create_train_job(
126-
tf_version, script_mode=False, horovod=False, repo_name=IMAGE_REPO_NAME, py_version="py2"
126+
tf_version,
127+
script_mode=False,
128+
horovod=False,
129+
ps=False,
130+
repo_name=IMAGE_REPO_NAME,
131+
py_version="py2",
127132
):
128-
return {
133+
conf = {
129134
"image": _get_full_cpu_image_uri(tf_version, repo=repo_name, py_version=py_version),
130135
"input_mode": "File",
131136
"input_config": [
@@ -153,11 +158,15 @@ def _create_train_job(
153158
"vpc_config": None,
154159
"metric_definitions": None,
155160
"experiment_config": None,
156-
"debugger_hook_config": {
161+
}
162+
163+
if not ps and not horovod:
164+
conf["debugger_hook_config"] = {
157165
"CollectionConfigurations": [],
158166
"S3OutputPath": "s3://{}/".format(BUCKET_NAME),
159-
},
160-
}
167+
}
168+
169+
return conf
161170

162171

163172
def _build_tf(
@@ -1116,7 +1125,7 @@ def test_tf_script_mode_ps(time, strftime, sagemaker_session):
11161125
assert call_names == ["train", "logs_for_job"]
11171126

11181127
expected_train_args = _create_train_job(
1119-
"1.11", script_mode=True, repo_name=SM_IMAGE_REPO_NAME, py_version="py3"
1128+
"1.11", script_mode=True, ps=True, repo_name=SM_IMAGE_REPO_NAME, py_version="py3"
11201129
)
11211130
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
11221131
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_PS_ENV_NAME] = json.dumps(True)

0 commit comments

Comments
 (0)