Commit c56d79d

feature: support torchrun for gpu instances
1 parent ae70340 commit c56d79d

File tree

2 files changed: +69 -7 lines

2 files changed

+69
-7
lines changed

src/sagemaker/fw_utils.py (+44 -7)
@@ -148,7 +148,15 @@
 ]
 
 
-TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
+TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS = [
+    "1.11",
+    "1.11.0",
+    "1.12",
+    "1.12.0",
+    "1.12.1",
+    "1.13.1",
+]
+TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
 
 
 TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
@@ -1069,12 +1077,13 @@ def validate_torch_distributed_distribution(
     if not image_uri:
         # ignore framework_version and py_version if image_uri is set
         # in case image_uri is not set, then both are mandatory
-        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
+        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or \
+                framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS:
             err_msg += (
                 f"Provided framework_version {framework_version} is not supported by"
                 " torch_distributed.\n"
                 "Please specify one of the supported framework versions:"
-                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
+                f"{TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS} \n"
             )
         if "py3" not in py_version:
             err_msg += (
@@ -1083,13 +1092,22 @@ def validate_torch_distributed_distribution(
             )
 
     # Check instance compatibility
+    if not _is_gpu_instance(instance_type):
+        err_msg += (
+            "torch_distributed is supported only for GPU instances.\n"
+        )
+
+    # Check version compatibility for GPU instance
     match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
     if match:
-        if not match[1].startswith("trn"):
+        # Non-Trainium GPU instance but version earlier than 1.13.1
+        if not match[1].startswith("trn") and \
+                framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
             err_msg += (
-                "torch_distributed is currently supported only for trainium instances.\n"
-                " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n"  # noqa E501 # pylint: disable=c0301
-                "for information regarding distributed training on non-trainium instances"
+                f"Provided framework_version {framework_version} is not supported by"
+                f" torch_distributed for instance {instance_type}.\n"
+                "Please specify one of the supported framework versions:"
+                f"{TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
             )
 
     # Check entry point type
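
[Editor's note: for reference, a quick standalone sketch, not part of the commit, of what the instance-family regex used above captures. The sample instance types are illustrative.]

import re

# Same pattern as in validate_torch_distributed_distribution and _is_gpu_instance:
# it pulls the instance family out of an "ml.*" instance type name.
pattern = r"^ml[\._]([a-z\d]+)\.?\w*$"
for instance_type in ["ml.p4d.24xlarge", "ml.g5.xlarge", "ml.trn1.2xlarge", "ml.m5.8xlarge", "local_gpu"]:
    match = re.match(pattern, instance_type)
    print(instance_type, "->", match[1] if match else None)
# ml.p4d.24xlarge -> p4d
# ml.g5.xlarge -> g5
# ml.trn1.2xlarge -> trn1
# ml.m5.8xlarge -> m5
# local_gpu -> None  (does not match, which is why _is_gpu_instance special-cases it)
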
@@ -1103,6 +1121,25 @@ def validate_torch_distributed_distribution(
         raise ValueError(err_msg)
 
 
+def _is_gpu_instance(instance_type):
+    """Returns bool indicating whether instance_type supports GPU
+
+    Args:
+        instance_type (str): Name of the instance_type to check against.
+
+    Returns:
+        bool: Whether or not the instance_type supports GPU
+    """
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match:
+            if (match[1].startswith("trn") or match[1].startswith("p") or match[1].startswith("g")):
+                return True
+        if instance_type == "local_gpu":
+            return True
+    return False
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
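
[Editor's note: taken together, these changes let the torch_distributed (torchrun) launcher be requested for GPU instances, not only Trainium. A minimal sketch of how the new path would be exercised through the PyTorch estimator; the entry point, role ARN, S3 path, and instance choice are illustrative placeholders, not values from this commit.]

from sagemaker.pytorch import PyTorch

# Illustrative values only: train.py, the role ARN and the S3 path are placeholders.
estimator = PyTorch(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    framework_version="1.13.1",       # the GPU allow-list added above is ["1.13.1"]
    py_version="py39",
    instance_type="ml.p4d.24xlarge",  # "p" family, so _is_gpu_instance returns True
    instance_count=2,
    distribution={"torch_distributed": {"enabled": True}},
)
estimator.fit("s3://example-bucket/training-data")
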

tests/unit/test_fw_utils.py (+25)
@@ -1102,3 +1102,28 @@ def test_instance_type_supports_profiler():
     assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
     assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
     assert fw_utils._instance_type_supports_profiler("local") is False
+
+def test_is_gpu_instance():
+    gpu_instance_types = [
+        "ml.p3.2xlarge",
+        "ml.p3.8xlarge",
+        "ml.p3.16xlarge",
+        "ml.p3dn.24xlarge",
+        "ml.p4d.24xlarge",
+        "ml.p4de.24xlarge",
+        "ml.g4dn.xlarge",
+        "ml.g5.xlarge",
+        "ml.g5.48xlarge",
+        "local_gpu"
+    ]
+    non_gpu_instance_types = [
+        "ml.t3.xlarge",
+        "ml.m5.8xlarge",
+        "ml.m5d.16xlarge",
+        "ml.c5.9xlarge",
+        "ml.r5.8xlarge",
+    ]
+    for gpu_type in gpu_instance_types:
+        assert fw_utils._is_gpu_instance(gpu_type) is True
+    for non_gpu_type in non_gpu_instance_types:
+        assert fw_utils._is_gpu_instance(non_gpu_type) is False
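
[Editor's note: assuming the repository's standard pytest setup, the new unit test can be run on its own with: python -m pytest tests/unit/test_fw_utils.py::test_is_gpu_instance]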
