Skip to content

Feature: support torchrun for gpu instances #3672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 44 additions & 7 deletions src/sagemaker/fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,15 @@
]


TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS = [
"1.11",
"1.11.0",
"1.12",
"1.12.0",
"1.12.1",
"1.13.1",
]
TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]


TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
Expand Down Expand Up @@ -1069,12 +1077,13 @@ def validate_torch_distributed_distribution(
if not image_uri:
# ignore framework_version and py_version if image_uri is set
# in case image_uri is not set, then both are mandatory
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or \
framework_version not in TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
f"Provided framework_version {framework_version} is not supported by"
" torch_distributed.\n"
"Please specify one of the supported framework versions:"
f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
f"{TORCH_DISTRIBUTED_TRAINIUM_SUPPORTED_FRAMEWORK_VERSIONS} \n"
)
if "py3" not in py_version:
err_msg += (
Expand All @@ -1083,13 +1092,22 @@ def validate_torch_distributed_distribution(
)

# Check instance compatibility
if not _is_gpu_instance(instance_type):
err_msg += (
"torch_distributed is supported only for GPU instances.\n"
)

# Check version compatibility for GPU instance
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if not match[1].startswith("trn"):
# Non-Trainium GPU instance but version earlier than 1.13.1
if not match[1].startswith("trn") and \
framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
err_msg += (
"torch_distributed is currently supported only for trainium instances.\n"
" Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n" # noqa E501 # pylint: disable=c0301
"for information regarding distributed training on non-trainium instances"
f"Provided framework_version {framework_version} is not supported by"
f" torch_distributed for instance {instance_type}.\n"
"Please specify one of the supported framework versions:"
f"{TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
)

# Check entry point type
Expand All @@ -1103,6 +1121,25 @@ def validate_torch_distributed_distribution(
raise ValueError(err_msg)


def _is_gpu_instance(instance_type):
"""Returns bool indicating whether instance_type supports GPU

Args:
instance_type (str): Name of the instance_type to check against.

Returns:
bool: Whether or not the instance_type supports GPU
"""
if isinstance(instance_type, str):
match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
if match:
if (match[1].startswith("trn") or match[1].startswith("p") or match[1].startswith("g")):
return True
if instance_type == "local_gpu":
return True
return False


def python_deprecation_warning(framework, latest_supported_version):
"""Placeholder docstring"""
return PYTHON_2_DEPRECATION_WARNING.format(
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/test_fw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,3 +1102,28 @@ def test_instance_type_supports_profiler():
assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
assert fw_utils._instance_type_supports_profiler("local") is False

def test_is_gpu_instance():
    """_is_gpu_instance classifies P/G/Trainium and local_gpu as GPU-capable."""
    expectations = {
        "ml.p3.2xlarge": True,
        "ml.p3.8xlarge": True,
        "ml.p3.16xlarge": True,
        "ml.p3dn.24xlarge": True,
        "ml.p4d.24xlarge": True,
        "ml.p4de.24xlarge": True,
        "ml.g4dn.xlarge": True,
        "ml.g5.xlarge": True,
        "ml.g5.48xlarge": True,
        "local_gpu": True,
        "ml.t3.xlarge": False,
        "ml.m5.8xlarge": False,
        "ml.m5d.16xlarge": False,
        "ml.c5.9xlarge": False,
        "ml.r5.8xlarge": False,
    }
    for instance_type, expected in expectations.items():
        assert fw_utils._is_gpu_instance(instance_type) is expected