diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index 1f2cc9214e..fa8f35af0c 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -148,11 +148,17 @@
 ]
 
-TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11", "1.11.0"]
-
+TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
 TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
-
+TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
+    "1.11",
+    "1.11.0",
+    "1.12",
+    "1.12.0",
+    "1.12.1",
+    "1.13.1",
+]
 SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
@@ -1055,9 +1061,8 @@ def validate_torch_distributed_distribution(
 
     Raises:
         ValueError: if `py_version` is not python3 or
-            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS
+            `framework_version` is not compatible with the provided instance type
     """
-
     torch_distributed_enabled = False
     if "torch_distributed" in distribution:
         torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
@@ -1066,30 +1071,36 @@
         return
 
     err_msg = ""
+
     if not image_uri:
         # ignore framework_version and py_version if image_uri is set
         # in case image_uri is not set, then both are mandatory
-        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
-            err_msg += (
-                f"Provided framework_version {framework_version} is not supported by"
-                " torch_distributed.\n"
-                "Please specify one of the supported framework versions:"
-                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
-            )
         if "py3" not in py_version:
             err_msg += (
                 f"Provided py_version {py_version} is not supported by torch_distributed.\n"
-                "Please specify py_version>=py3"
+                "Please specify py_version>=py3\n"
             )
 
-    # Check instance compatibility
-    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
-    if match:
-        if not match[1].startswith("trn"):
+    # Check instance and framework_version compatibility
+    if _is_gpu_instance(instance_type):
+        if framework_version not in TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                f" torch_distributed for instance {instance_type}.\n"
+                "Please specify one of the supported framework versions:"
+                f" {TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS}\n"
+            )
+    elif _is_trainium_instance(instance_type):
+        if framework_version not in TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS:
+            err_msg += (
+                f"Provided framework_version {framework_version} is not supported by"
+                f" torch_distributed for instance {instance_type}.\n"
+                "Please specify one of the supported framework versions:"
+                f" {TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS}\n"
+            )
+    else:
             err_msg += (
-                "torch_distributed is currently supported only for trainium instances.\n"
-                " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \n"  # noqa E501 # pylint: disable=c0301
-                "for information regarding distributed training on non-trainium instances"
+                "Currently torch_distributed is supported only for GPU and Trainium instances.\n"
             )
 
     # Check entry point type
@@ -1103,6 +1114,41 @@
     raise ValueError(err_msg)
 
 
+def _is_gpu_instance(instance_type):
+    """Returns bool indicating whether instance_type supports GPU
+
+    Args:
+        instance_type (str): Name of the instance_type to check against.
+
+    Returns:
+        bool: Whether or not the instance_type supports GPU
+    """
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match:
+            if match[1].startswith("p") or match[1].startswith("g"):
+                return True
+        if instance_type == "local_gpu":
+            return True
+    return False
+
+
+def _is_trainium_instance(instance_type):
+    """Returns bool indicating whether instance_type is a Trainium instance
+
+    Args:
+        instance_type (str): Name of the instance_type to check against.
+
+    Returns:
+        bool: Whether or not the instance_type is a Trainium instance
+    """
+    if isinstance(instance_type, str):
+        match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
+        if match and match[1].startswith("trn"):
+            return True
+    return False
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index 8645f05159..8de237d1d6 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -1005,15 +1005,14 @@ def test_validate_pytorchddp_raises():
 
 
 def test_validate_torch_distributed_not_raises():
-
-    # Case 1: Framework is PyTorch, but distribution is not torch_distributed
+    # Case 1: Framework is PyTorch, but torch_distributed is not enabled
     torch_distributed_disabled = {"torch_distributed": {"enabled": False}}
     fw_utils.validate_torch_distributed_distribution(
         instance_type="ml.trn1.2xlarge",
         distribution=torch_distributed_disabled,
         framework_version="1.11.0",
         py_version="py3",
-        image_uri="custom-container",
+        image_uri=None,
         entry_point="train.py",
     )
     # Case 2: Distribution is torch_distributed enabled, supported framework and py versions
@@ -1027,7 +1026,22 @@
         distribution=torch_distributed_enabled,
         framework_version=framework_version,
         py_version="py3",
-        image_uri="custom-container",
+        image_uri=None,
+        entry_point="train.py",
+    )
+
+    # Case 3: Distribution is torch_distributed enabled, supported framework version on GPU instances
+    torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
+    torch_distributed_gpu_supported_fw_versions = [
+        "1.13.1",
+    ]
+    for framework_version in torch_distributed_gpu_supported_fw_versions:
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.p3.8xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version=framework_version,
+            py_version="py3",
+            image_uri=None,
         entry_point="train.py",
     )
 
@@ -1067,6 +1081,17 @@
             entry_point="train.sh",
         )
 
+    # Case 4: Unsupported framework version for GPU instances
+    with pytest.raises(ValueError):
+        fw_utils.validate_torch_distributed_distribution(
+            instance_type="ml.p3.8xlarge",
+            distribution=torch_distributed_enabled,
+            framework_version="1.11.0",
+            py_version="py3",
+            image_uri=None,
+            entry_point="train.py",
+        )
+
 
 def test_validate_unsupported_distributions_trainium_raises():
     with pytest.raises(ValueError):
@@ -1102,3 +1127,57 @@ def test_instance_type_supports_profiler():
     assert fw_utils._instance_type_supports_profiler("ml.trn1.xlarge") is True
     assert fw_utils._instance_type_supports_profiler("ml.m4.xlarge") is False
     assert fw_utils._instance_type_supports_profiler("local") is False
+
+
+def test_is_gpu_instance():
+    gpu_instance_types = [
+        "ml.p3.2xlarge",
+        "ml.p3.8xlarge",
+        "ml.p3.16xlarge",
+        "ml.p3dn.24xlarge",
+        "ml.p4d.24xlarge",
+        "ml.p4de.24xlarge",
+        "ml.g4dn.xlarge",
"ml.g5.xlarge", + "ml.g5.48xlarge", + "local_gpu", + ] + non_gpu_instance_types = [ + "ml.t3.xlarge", + "ml.m5.8xlarge", + "ml.m5d.16xlarge", + "ml.c5.9xlarge", + "ml.r5.8xlarge", + ] + for gpu_type in gpu_instance_types: + assert fw_utils._is_gpu_instance(gpu_type) is True + for non_gpu_type in non_gpu_instance_types: + assert fw_utils._is_gpu_instance(non_gpu_type) is False + + +def test_is_trainium_instance(): + trainium_instance_types = [ + "ml.trn1.2xlarge", + "ml.trn1.32xlarge", + ] + non_trainum_instance_types = [ + "ml.t3.xlarge", + "ml.m5.8xlarge", + "ml.m5d.16xlarge", + "ml.c5.9xlarge", + "ml.r5.8xlarge", + "ml.p3.2xlarge", + "ml.p3.8xlarge", + "ml.p3.16xlarge", + "ml.p3dn.24xlarge", + "ml.p4d.24xlarge", + "ml.p4de.24xlarge", + "ml.g4dn.xlarge", + "ml.g5.xlarge", + "ml.g5.48xlarge", + "local_gpu", + ] + for tr_type in trainium_instance_types: + assert fw_utils._is_trainium_instance(tr_type) is True + for non_tr_type in non_trainum_instance_types: + assert fw_utils._is_trainium_instance(non_tr_type) is False