Skip to content

Commit 5a3a150

Browse files
knikuremufiAmazon
andauthored
feature: Disable profiler for Trainium instance type (#3442)
Co-authored-by: Mufaddal Rohawala <[email protected]>
1 parent 9713203 commit 5a3a150

File tree

3 files changed

+26
-1
lines changed

3 files changed

+26
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
UploadedCode,
4545
_region_supports_debugger,
4646
_region_supports_profiler,
47+
_instance_type_supports_profiler,
4748
get_mp_parameters,
4849
tar_and_upload_dir,
4950
validate_source_dir,
@@ -592,7 +593,9 @@ def __init__(
592593

593594
self.max_retry_attempts = max_retry_attempts
594595

595-
if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
596+
if not _region_supports_profiler(
597+
self.sagemaker_session.boto_region_name
598+
) or _instance_type_supports_profiler(self.instance_type):
596599
self.disable_profiler = True
597600

598601
self.profiler_rule_configs = None

src/sagemaker/fw_utils.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,22 @@ def _region_supports_profiler(region_name):
10741074
return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
10751075

10761076

1077+
def _instance_type_supports_profiler(instance_type):
1078+
"""Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
1079+
1080+
Args:
1081+
instance_type (str): Name of the instance_type to check against.
1082+
1083+
Returns:
1084+
bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1085+
"""
1086+
if isinstance(instance_type, str):
1087+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1088+
if match and match[1].startswith("trn"):
1089+
return True
1090+
return False
1091+
1092+
10771093
def validate_version_or_image_args(framework_version, py_version, image_uri):
10781094
"""Checks if version or image arguments are specified.
10791095

tests/unit/test_fw_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1040,3 +1040,9 @@ def test_validate_unsupported_distributions_trainium_raises():
10401040
distribution=smdataparallel_enabled,
10411041
instance_type="ml.trn1.32xlarge",
10421042
)
1043+
1044+
1045+
def test_instance_type_supports_profiler():
1046+
assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
1047+
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
1048+
assert fw_utils._instance_type_supports_profiler("local") is False

0 commit comments

Comments
 (0)