Skip to content

Commit d69a8bd

Browse files
yl-to authored and JoseJuan98 committed
Feature: support torchrun for gpu instances (aws#3672)
1 parent c17081c commit d69a8bd

File tree

2 files changed

+149
-24
lines changed

2 files changed

+149
-24
lines changed

src/sagemaker/fw_utils.py

+66-20
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,17 @@
148148
]
149149

150150

151-
TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
152-
151+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
153152

154153
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
155-
154+
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
155+
"1.11",
156+
"1.11.0",
157+
"1.12",
158+
"1.12.0",
159+
"1.12.1",
160+
"1.13.1",
161+
]
156162

157163
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
158164

@@ -1055,9 +1061,8 @@ def validate_torch_distributed_distribution(
10551061
Raises:
10561062
ValueError: if
10571063
`py_version` is not python3 or
1058-
`framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
1064+
`framework_version` is not compatible with instance types
10591065
"""
1060-
10611066
torch_distributed_enabled = False
10621067
if "torch_distributed" in distribution:
10631068
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
@@ -1066,30 +1071,36 @@ def validate_torch_distributed_distribution(
10661071
return
10671072

10681073
err_msg = ""
1074+
10691075
if not image_uri:
10701076
# ignore framework_version and py_version if image_uri is set
10711077
# in case image_uri is not set, then both are mandatory
1072-
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
1073-
err_msg += (
1074-
f"Provided framework_version {framework_version} is not supported by"
1075-
" torch_distributed.\n"
1076-
"Please specify one of the supported framework versions:"
1077-
f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
1078-
)
10791078
if "py3" not in py_version:
10801079
err_msg += (
10811080
f"Provided py_version {py_version} is not supported by torch_distributed.\n"
1082-
"Please specify py_version>=py3"
1081+
"Please specify py_version>=py3\n"
10831082
)
10841083

1085-
# Check instance compatibility
1086-
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1087-
if match:
1088-
if not match[1].startswith("trn"):
1084+
# Check instance and framework_version compatibility
1085+
if _is_gpu_instance(instance_type):
1086+
if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS:
1087+
err_msg += (
1088+
f"Provided framework_version {framework_version} is not supported by"
1089+
f" torch_distributed for instance {instance_type}.\n"
1090+
"Please specify one of the supported framework versions:"
1091+
f"{TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS} \n"
1092+
)
1093+
elif _is_trainium_instance(instance_type):
1094+
if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS:
1095+
err_msg += (
1096+
f"Provided framework_version {framework_version} is not supported by"
1097+
f" torch_distributed for instance {instance_type}.\n"
1098+
"Please specify one of the supported framework versions:"
1099+
f"{TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS} \n"
1100+
)
1101+
else:
10891102
err_msg += (
1090-
"torch_distributed is currently supported only for trainium instances.\n"
1091-
" Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n" # noqa E501 # pylint: disable=c0301
1092-
"for information regarding distributed training on non-trainium instances"
1103+
"Currently torch_distributed is supported only for GPU and Trainium instances.\n"
10931104
)
10941105

10951106
# Check entry point type
@@ -1103,6 +1114,41 @@ def validate_torch_distributed_distribution(
11031114
raise ValueError(err_msg)
11041115

11051116

1117+
def _is_gpu_instance(instance_type):
1118+
"""Returns bool indicating whether instance_type supports GPU
1119+
1120+
Args:
1121+
instance_type (str): Name of the instance_type to check against.
1122+
1123+
Returns:
1124+
bool: Whether or not the instance_type supports GPU
1125+
"""
1126+
if isinstance(instance_type, str):
1127+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1128+
if match:
1129+
if match[1].startswith("p") or match[1].startswith("g"):
1130+
return True
1131+
if instance_type == "local_gpu":
1132+
return True
1133+
return False
1134+
1135+
1136+
def _is_trainium_instance(instance_type):
1137+
"""Returns bool indicating whether instance_type is a Trainium instance
1138+
1139+
Args:
1140+
instance_type (str): Name of the instance_type to check against.
1141+
1142+
Returns:
1143+
bool: Whether or not the instance_type is a Trainium instance
1144+
"""
1145+
if isinstance(instance_type, str):
1146+
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
1147+
if match and match[1].startswith("trn"):
1148+
return True
1149+
return False
1150+
1151+
11061152
def python_deprecation_warning(framework, latest_supported_version):
11071153
"""Placeholder docstring"""
11081154
return PYTHON_2_DEPRECATION_WARNING.format(

tests/unit/test_fw_utils.py

+83-4
Original file line numberDiff line numberDiff line change
@@ -1005,15 +1005,14 @@ def test_validate_pytorchddp_raises():
10051005

10061006

10071007
def test_validate_torch_distributed_not_raises():
1008-
1009-
# Case 1: Framework is PyTorch, but distribution is not torch_distributed
1008+
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
10101009
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
10111010
fw_utils.validate_torch_distributed_distribution(
10121011
instance_type="ml.trn1.2xlarge",
10131012
distribution=torch_distributed_disabled,
10141013
framework_version="1.11.0",
10151014
py_version="py3",
1016-
image_uri="custom-container",
1015+
image_uri=None,
10171016
entry_point="train.py",
10181017
)
10191018
# Case 2: Distribution is torch_distributed enabled, supported framework and py versions
@@ -1027,7 +1026,22 @@ def test_validate_torch_distributed_not_raises():
10271026
distribution=torch_distributed_enabled,
10281027
framework_version=framework_version,
10291028
py_version="py3",
1030-
image_uri="custom-container",
1029+
image_uri=None,
1030+
entry_point="train.py",
1031+
)
1032+
1033+
# Case 3: Distribution is torch_distributed enabled, supported framework and instances
1034+
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
1035+
torch_distributed_gpu_supported_fw_versions = [
1036+
"1.13.1",
1037+
]
1038+
for framework_version in torch_distributed_gpu_supported_fw_versions:
1039+
fw_utils.validate_torch_distributed_distribution(
1040+
instance_type="ml.p3.8xlarge",
1041+
distribution=torch_distributed_enabled,
1042+
framework_version=framework_version,
1043+
py_version="py3",
1044+
image_uri=None,
10311045
entry_point="train.py",
10321046
)
10331047

@@ -1067,6 +1081,17 @@ def test_validate_torch_distributed_raises():
10671081
entry_point="train.sh",
10681082
)
10691083

1084+
# Case 4: Unsupported framework version for gpu instances
1085+
with pytest.raises(ValueError):
1086+
fw_utils.validate_torch_distributed_distribution(
1087+
instance_type="ml.p3.8xlarge",
1088+
distribution=torch_distributed_enabled,
1089+
framework_version="1.11.0",
1090+
py_version="py3",
1091+
image_uri=None,
1092+
entry_point="train.py",
1093+
)
1094+
10701095

10711096
def test_validate_unsupported_distributions_trainium_raises():
10721097
with pytest.raises(ValueError):
@@ -1102,3 +1127,57 @@ def test_instance_type_supports_profiler():
11021127
assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
11031128
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
11041129
assert fw_utils._instance_type_supports_profiler("local") is False
1130+
1131+
1132+
def test_is_gpu_instance():
    """Verify fw_utils._is_gpu_instance classifies instance types correctly."""
    expectations = {
        "ml.p3.2xlarge": True,
        "ml.p3.8xlarge": True,
        "ml.p3.16xlarge": True,
        "ml.p3dn.24xlarge": True,
        "ml.p4d.24xlarge": True,
        "ml.p4de.24xlarge": True,
        "ml.g4dn.xlarge": True,
        "ml.g5.xlarge": True,
        "ml.g5.48xlarge": True,
        "local_gpu": True,
        "ml.t3.xlarge": False,
        "ml.m5.8xlarge": False,
        "ml.m5d.16xlarge": False,
        "ml.c5.9xlarge": False,
        "ml.r5.8xlarge": False,
    }
    for instance_type, is_gpu in expectations.items():
        assert fw_utils._is_gpu_instance(instance_type) is is_gpu
1156+
1157+
1158+
def test_is_trainium_instance():
    """Verify fw_utils._is_trainium_instance accepts only trn* instance types."""
    expectations = {
        "ml.trn1.2xlarge": True,
        "ml.trn1.32xlarge": True,
        "ml.t3.xlarge": False,
        "ml.m5.8xlarge": False,
        "ml.m5d.16xlarge": False,
        "ml.c5.9xlarge": False,
        "ml.r5.8xlarge": False,
        "ml.p3.2xlarge": False,
        "ml.p3.8xlarge": False,
        "ml.p3.16xlarge": False,
        "ml.p3dn.24xlarge": False,
        "ml.p4d.24xlarge": False,
        "ml.p4de.24xlarge": False,
        "ml.g4dn.xlarge": False,
        "ml.g5.xlarge": False,
        "ml.g5.48xlarge": False,
        "local_gpu": False,
    }
    for instance_type, is_trainium in expectations.items():
        assert fw_utils._is_trainium_instance(instance_type) is is_trainium

0 commit comments

Comments (0)