Skip to content

Feature: support torchrun for gpu instances #3672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 66 additions & 20 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,17 @@
]


TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]

TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]

TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]

TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
"1.11",
"1.11.0",
"1.12",
"1.12.0",
"1.12.1",
"1.13.1",
]

SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

Expand Down Expand Up @@ -1055,9 +1061,8 @@ def validate_torch_distributed_distribution(
Raises:
ValueError: if
`py_version` is not python3 or
`framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
`framework_version` is not compatible with instance types
"""

torch_distributed_enabled = False
if "torch_distributed" in distribution:
torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
Expand All @@ -1066,30 +1071,36 @@ def validate_torch_distributed_distribution(
return

err_msg = ""

if not image_uri:
# ignore framework_version and py_version if image_uri is set
# in case image_uri is not set, then both are mandatory
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
" torch_distributed.\n"
"Please specify one of the supported framework versions:"
f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
)
if "py3" not in py_version:
err_msg += (
f"Provided py_version {py_version} is not supported by torch_distributed.\n"
"Please specify py_version>=py3"
"Please specify py_version>=py3\n"
)

# Check instance compatibility
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if not match[1].startswith("trn"):
# Check instance and framework_version compatibility
if _is_gpu_instance(instance_type):
if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
f" torch_distributed for instance {instance_type}.\n"
"Please specify one of the supported framework versions:"
f"{TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS} \n"
)
elif _is_trainium_instance(instance_type):
if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
f" torch_distributed for instance {instance_type}.\n"
"Please specify one of the supported framework versions:"
f"{TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS} \n"
)
else:
err_msg += (
"torch_distributed is currently supported only for trainium instances.\n"
" Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n" # noqa E501 # pylint: disable=c0301
"for information regarding distributed training on non-trainium instances"
"Currently torch_distributed is supported only for GPU and Trainium instances.\n"
)

# Check entry point type
Expand All @@ -1103,6 +1114,41 @@ def validate_torch_distributed_distribution(
raise ValueError(err_msg)


def _is_gpu_instance(instance_type):
"""Returns bool indicating whether instance_type supports GPU

Args:
instance_type (str): Name of the instance_type to check against.

Returns:
bool: Whether or not the instance_type supports GPU
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if match[1].startswith("p") or match[1].startswith("g"):
return True
if instance_type == "local_gpu":
return True
return False


def _is_trainium_instance(instance_type):
"""Returns bool indicating whether instance_type is a Trainium instance

Args:
instance_type (str): Name of the instance_type to check against.

Returns:
bool: Whether or not the instance_type is a Trainium instance
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match and match[1].startswith("trn"):
return True
return False


def python_deprecation_warning(framework, latest_supported_version):
"""Placeholder docstring"""
return PYTHON_2_DEPRECATION_WARNING.format(
Expand Down
87 changes: 83 additions & 4 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1005,15 +1005,14 @@ def test_validate_pytorchddp_raises():


def test_validate_torch_distributed_not_raises():

# Case 1: Framework is PyTorch, but distribution is not torch_distributed
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
fw_utils.validate_torch_distributed_distribution(
instance_type="ml.trn1.2xlarge",
distribution=torch_distributed_disabled,
framework_version="1.11.0",
py_version="py3",
image_uri="custom-container",
image_uri=None,
entry_point="train.py",
)
# Case 2: Distribution is torch_distributed enabled, supported framework and py versions
Expand All @@ -1027,7 +1026,22 @@ def test_validate_torch_distributed_not_raises():
distribution=torch_distributed_enabled,
framework_version=framework_version,
py_version="py3",
image_uri="custom-container",
image_uri=None,
entry_point="train.py",
)

# Case 3: Distribution is torch_distributed enabled, supported framework and instances
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
torch_distributed_gpu_supported_fw_versions = [
"1.13.1",
]
for framework_version in torch_distributed_gpu_supported_fw_versions:
fw_utils.validate_torch_distributed_distribution(
instance_type="ml.p3.8xlarge",
distribution=torch_distributed_enabled,
framework_version=framework_version,
py_version="py3",
image_uri=None,
entry_point="train.py",
)

Expand Down Expand Up @@ -1067,6 +1081,17 @@ def test_validate_torch_distributed_raises():
entry_point="train.sh",
)

# Case 4: Unsupported framework version for gpu instances
with pytest.raises(ValueError):
fw_utils.validate_torch_distributed_distribution(
instance_type="ml.p3.8xlarge",
distribution=torch_distributed_enabled,
framework_version="1.11.0",
py_version="py3",
image_uri=None,
entry_point="train.py",
)


def test_validate_unsupported_distributions_trainium_raises():
with pytest.raises(ValueError):
Expand Down Expand Up @@ -1102,3 +1127,57 @@ def test_instance_type_supports_profiler():
assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
assert fw_utils._instance_type_supports_profiler("local") is False


def test_is_gpu_instance():
    # Map each instance type to the expected _is_gpu_instance verdict and
    # check them all in a single table-driven pass.
    expected = {
        "ml.p3.2xlarge": True,
        "ml.p3.8xlarge": True,
        "ml.p3.16xlarge": True,
        "ml.p3dn.24xlarge": True,
        "ml.p4d.24xlarge": True,
        "ml.p4de.24xlarge": True,
        "ml.g4dn.xlarge": True,
        "ml.g5.xlarge": True,
        "ml.g5.48xlarge": True,
        "local_gpu": True,
        "ml.t3.xlarge": False,
        "ml.m5.8xlarge": False,
        "ml.m5d.16xlarge": False,
        "ml.c5.9xlarge": False,
        "ml.r5.8xlarge": False,
    }
    for instance_type, is_gpu in expected.items():
        assert fw_utils._is_gpu_instance(instance_type) is is_gpu


def test_is_trainium_instance():
    # Map each instance type to the expected _is_trainium_instance verdict
    # and check them all in a single table-driven pass.
    expected = {
        "ml.trn1.2xlarge": True,
        "ml.trn1.32xlarge": True,
        "ml.t3.xlarge": False,
        "ml.m5.8xlarge": False,
        "ml.m5d.16xlarge": False,
        "ml.c5.9xlarge": False,
        "ml.r5.8xlarge": False,
        "ml.p3.2xlarge": False,
        "ml.p3.8xlarge": False,
        "ml.p3.16xlarge": False,
        "ml.p3dn.24xlarge": False,
        "ml.p4d.24xlarge": False,
        "ml.p4de.24xlarge": False,
        "ml.g4dn.xlarge": False,
        "ml.g5.xlarge": False,
        "ml.g5.48xlarge": False,
        "local_gpu": False,
    }
    for instance_type, is_trainium in expected.items():
        assert fw_utils._is_trainium_instance(instance_type) is is_trainium