Skip to content

Commit 1566386

Browse files
authored
fix: disable Debugger defaults in unsupported regions (aws#1272)
1 parent 400176f commit 1566386

File tree

4 files changed

+32
-2
lines changed

4 files changed

+32
-2
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
parse_s3_url,
3939
UploadedCode,
4040
validate_source_dir,
41+
_region_supports_debugger,
4142
)
4243
from sagemaker.job import _Job
4344
from sagemaker.local import LocalSession
@@ -1674,7 +1675,9 @@ def _validate_and_set_debugger_configs(self):
16741675
"""
16751676
Set defaults for debugging
16761677
"""
1677-
if self.debugger_hook_config is None:
1678+
if self.debugger_hook_config is None and _region_supports_debugger(
1679+
self.sagemaker_session.boto_region_name
1680+
):
16781681
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
16791682
elif not self.debugger_hook_config:
16801683
self.debugger_hook_config = None

src/sagemaker/fw_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
"pytorch-serving": [1, 2, 0],
8585
}
8686

87+
DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"]
88+
8789

8890
def is_version_equal_or_higher(lowest_version, framework_version):
8991
"""Determine whether the ``framework_version`` is equal to or higher than
@@ -504,3 +506,16 @@ def python_deprecation_warning(framework, latest_supported_version):
504506
return PYTHON_2_DEPRECATION_WARNING.format(
505507
framework=framework, latest_supported_version=latest_supported_version
506508
)
509+
510+
511+
def _region_supports_debugger(region_name):
512+
"""Returns boolean indicating whether the region supports Amazon SageMaker Debugger.
513+
514+
Args:
515+
region_name (str): Name of the region to check against.
516+
517+
Returns:
518+
bool: Whether or not the region supports Amazon SageMaker Debugger.
519+
520+
"""
521+
return region_name.lower() not in DEBUGGER_UNSUPPORTED_REGIONS

src/sagemaker/tensorflow/estimator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,9 @@ def _validate_and_set_debugger_configs(self):
723723
)
724724
self.debugger_hook_config = None
725725
self.debugger_rule_configs = None
726-
elif self.debugger_hook_config is None:
726+
elif self.debugger_hook_config is None and fw._region_supports_debugger(
727+
self.sagemaker_session.boto_session.region_name
728+
):
727729
# Set defaults for debugging.
728730
self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
729731

tests/unit/test_fw_utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,3 +1035,13 @@ def test_model_code_key_prefix_with_all_none_fail():
10351035
with pytest.raises(TypeError) as error:
10361036
fw_utils.model_code_key_prefix(None, None, None)
10371037
assert "expected string" in str(error)
1038+
1039+
1040+
def test_region_supports_debugger_feature_returns_true_for_supported_regions():
1041+
assert fw_utils._region_supports_debugger("us-west-2") is True
1042+
assert fw_utils._region_supports_debugger("us-east-2") is True
1043+
1044+
1045+
def test_region_supports_debugger_feature_returns_false_for_unsupported_regions():
1046+
assert fw_utils._region_supports_debugger("us-gov-west-1") is False
1047+
assert fw_utils._region_supports_debugger("us-iso-east-1") is False

0 commit comments

Comments
 (0)