File tree: 3 files changed (+26 / -1 lines changed)
3 files changed +26
-1
lines changed Original file line number Diff line number Diff line change 44
44
UploadedCode ,
45
45
_region_supports_debugger ,
46
46
_region_supports_profiler ,
47
+ _instance_type_supports_profiler ,
47
48
get_mp_parameters ,
48
49
tar_and_upload_dir ,
49
50
validate_source_dir ,
@@ -592,7 +593,9 @@ def __init__(
592
593
593
594
self .max_retry_attempts = max_retry_attempts
594
595
595
- if not _region_supports_profiler (self .sagemaker_session .boto_region_name ):
596
+ if not _region_supports_profiler (
597
+ self .sagemaker_session .boto_region_name
598
+ ) or _instance_type_supports_profiler (self .instance_type ):
596
599
self .disable_profiler = True
597
600
598
601
self .profiler_rule_configs = None
Original file line number Diff line number Diff line change @@ -1074,6 +1074,22 @@ def _region_supports_profiler(region_name):
1074
1074
return region_name .lower () not in PROFILER_UNSUPPORTED_REGIONS
1075
1075
1076
1076
1077
+ def _instance_type_supports_profiler (instance_type ):
1078
+ """Returns bool indicating whether instance_type supports SageMaker Debugger profiling feature.
1079
+
1080
+ Args:
1081
+ instance_type (str): Name of the instance_type to check against.
1082
+
1083
+ Returns:
1084
+ bool: Whether or not the region supports Amazon SageMaker Debugger profiling feature.
1085
+ """
1086
+ if isinstance (instance_type , str ):
1087
+ match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
1088
+ if match and match [1 ].startswith ("trn" ):
1089
+ return True
1090
+ return False
1091
+
1092
+
1077
1093
def validate_version_or_image_args (framework_version , py_version , image_uri ):
1078
1094
"""Checks if version or image arguments are specified.
1079
1095
Original file line number Diff line number Diff line change @@ -1040,3 +1040,9 @@ def test_validate_unsupported_distributions_trainium_raises():
1040
1040
distribution = smdataparallel_enabled ,
1041
1041
instance_type = "ml.trn1.32xlarge" ,
1042
1042
)
1043
+
1044
+
1045
def test_instance_type_supports_profiler():
    """Check the helper's result for a Trainium, a non-Trainium and a local instance type."""
    expectations = {
        "ml.trn1.xlarge": True,
        "ml.m4.xlarge": False,
        "local": False,
    }
    for instance_type, expected in expectations.items():
        assert fw_utils._instance_type_supports_profiler(instance_type) is expected
You can’t perform that action at this time.
0 commit comments