diff --git a/src/sagemaker/estimator.py b/src/sagemaker/estimator.py
index 15763d844b..c4581f5a9f 100644
--- a/src/sagemaker/estimator.py
+++ b/src/sagemaker/estimator.py
@@ -44,6 +44,7 @@
     UploadedCode,
     _region_supports_debugger,
     _region_supports_profiler,
+    _instance_type_supports_profiler,
     get_mp_parameters,
     tar_and_upload_dir,
     validate_source_dir,
@@ -592,7 +593,9 @@ def __init__(
 
         self.max_retry_attempts = max_retry_attempts
 
-        if not _region_supports_profiler(self.sagemaker_session.boto_region_name):
+        if not _region_supports_profiler(
+            self.sagemaker_session.boto_region_name
+        ) or _instance_type_supports_profiler(self.instance_type):  # True => "trn" type
             self.disable_profiler = True
 
         self.profiler_rule_configs = None
diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index f86304e720..6b9653520e 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -1065,6 +1065,22 @@ def _region_supports_profiler(region_name):
     return region_name.lower() not in PROFILER_UNSUPPORTED_REGIONS
 
 
+def _instance_type_supports_profiler(instance_type):
+    """Returns True for "trn" instance types; despite the name, True means "disable profiler".
+
+    Args:
+        instance_type (str): Name of the instance_type to check against.
+
+    Returns:
+        bool: True for "trn" instance types (caller disables the profiler), False otherwise.
+    """
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match and match[1].startswith("trn"):
+            return True
+    return False
+
+
 def validate_version_or_image_args(framework_version, py_version, image_uri):
     """Checks if version or image arguments are specified.
 
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index 1badd1be0c..51b240f156 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -1040,3 +1040,9 @@ def test_validate_unsupported_distributions_trainium_raises():
             distribution=smdataparallel_enabled,
             instance_type="ml.trn1.32xlarge",
         )
+
+
+def test_instance_type_supports_profiler():
+    assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True  # "trn" type
+    assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
+    assert fw_utils._instance_type_supports_profiler("local") is False  # not an "ml." name