diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py
index ef99454a45..613bbd3742 100644
--- a/src/sagemaker/fw_utils.py
+++ b/src/sagemaker/fw_utils.py
@@ -835,6 +835,68 @@ def validate_pytorch_distribution(
         raise ValueError(err_msg)
 
 
+def validate_distribution_instance(sagemaker_session, distribution, instance_type):
+    """Check to prevent launching a modelparallel job on CPU only instances.
+
+    The instance type is described through the EC2 ``DescribeInstanceTypes`` API and
+    counts as GPU-enabled when its description contains ``GpuInfo``. No API call is
+    made unless ``smdistributed.modelparallel`` is enabled.
+
+    Args:
+        sagemaker_session (sagemaker.session.Session): Session object which
+            manages interactions with Amazon SageMaker APIs and any other
+            AWS services needed.
+        distribution (dict): A dictionary with information to enable distributed training.
+            distribution = {
+                "smdistributed": {
+                    "modelparallel": {
+                        "enabled": True,
+                        "parameters": {
+                            ...
+                        },
+                    },
+                },
+                ...
+            }
+        instance_type (str): A string representing the type of training instance selected,
+            with or without the SageMaker "ml." prefix.
+
+    Raises:
+        ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
+    """
+    if "smdistributed" not in distribution:
+        # Distribution strategy other than smdistributed is selected
+        return
+
+    smdistributed = distribution["smdistributed"]
+    if not isinstance(smdistributed, dict) or "modelparallel" not in smdistributed:
+        # Strategy other than modelparallel is selected
+        return
+
+    modelparallel = smdistributed["modelparallel"]
+    if not isinstance(modelparallel, dict) or not modelparallel.get("enabled", False):
+        # Strategy modelparallel is not enabled. A malformed section is not this
+        # check's concern and must not surface as a KeyError/TypeError here.
+        return
+
+    if instance_type.startswith("local"):
+        # Local mode ("local", "local_gpu") does not name an EC2 instance type
+        return
+
+    # EC2 names instance types without the SageMaker "ml." prefix
+    if instance_type.startswith("ml."):
+        instance_type = instance_type[3:]
+
+    instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
+        InstanceTypes=[instance_type]
+    )
+    if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
+        raise ValueError(
+            "modelparallel only runs on GPU-enabled instances. "
+            f"{instance_type} does not support GPU."
+        )
+
+
 def python_deprecation_warning(framework, latest_supported_version):
     """Placeholder docstring"""
     return PYTHON_2_DEPRECATION_WARNING.format(
diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py
index 153d4656d4..622e79084c 100644
--- a/src/sagemaker/pytorch/estimator.py
+++ b/src/sagemaker/pytorch/estimator.py
@@ -25,6 +25,7 @@
     python_deprecation_warning,
     validate_version_or_image_args,
     validate_distribution,
+    validate_distribution_instance,
 )
 from sagemaker.pytorch import defaults
 from sagemaker.pytorch.model import PyTorchModel
@@ -220,6 +221,12 @@ def __init__(
             entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
         )
         if distribution is not None:
+            # Fail fast on CPU-only instances before the remaining checks; the
+            # helper accepts the "ml."-prefixed instance type as is.
+            validate_distribution_instance(
+                self.sagemaker_session, distribution, self._get_instance_type()
+            )
+
             distribution = validate_distribution(
                 distribution,
                 self.instance_groups,
diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py
index 018255cf47..5ecf196731 100644
--- a/tests/unit/test_fw_utils.py
+++ b/tests/unit/test_fw_utils.py
@@ -49,6 +49,15 @@ def sagemaker_session():
     session_mock.sagemaker_client.describe_training_job = Mock(
         return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
     )
+    session_mock.boto_session.client("ec2").describe_instance_types = Mock(
+        return_value={
+            "InstanceTypes": [
+                {
+                    "CpuInfo": {},
+                },
+            ],
+        }
+    )
     return session_mock
 
 
@@ -733,6 +742,57 @@ def test_validate_smdistributed_not_raises():
         )
 
 
+def test_validate_distribution_instance_no_smdistributed(sagemaker_session):
+    distribution = {}
+    instance_type = "mock_type"
+    fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
+
+
+def test_validate_distribution_instance_no_modelparallel(sagemaker_session):
+    distribution = {"smdistributed": {}}
+    instance_type = "mock_type"
+    fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
+
+
+def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session):
+    distribution = {"smdistributed": {"modelparallel": {"enabled": False}}}
+    instance_type = "mock_type"
+    fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
+
+
+def test_validate_distribution_instance_malformed_modelparallel(sagemaker_session):
+    # Malformed sections are not this check's concern and must not crash it.
+    for distribution in (
+        {"smdistributed": "dummy"},
+        {"smdistributed": {"modelparallel": "dummy"}},
+        {"smdistributed": {"modelparallel": {}}},
+    ):
+        fw_utils.validate_distribution_instance(sagemaker_session, distribution, "mock_type")
+
+
+def test_validate_distribution_instance_local_mode():
+    distribution = {"smdistributed": {"modelparallel": {"enabled": True}}}
+    session = Mock()
+    fw_utils.validate_distribution_instance(session, distribution, "local_gpu")
+    session.boto_session.client.assert_not_called()
+
+
+def test_validate_distribution_instance_gpu():
+    distribution = {"smdistributed": {"modelparallel": {"enabled": True}}}
+    session = Mock()
+    describe = session.boto_session.client.return_value.describe_instance_types
+    describe.return_value = {"InstanceTypes": [{"GpuInfo": {}}]}
+    fw_utils.validate_distribution_instance(session, distribution, "ml.p3.2xlarge")
+    describe.assert_called_once_with(InstanceTypes=["p3.2xlarge"])
+
+
+def test_validate_distribution_instance_raise(sagemaker_session):
+    distribution = {"smdistributed": {"modelparallel": {"enabled": True}}}
+    instance_type = "mock_type"
+    with pytest.raises(ValueError, match="mock_type does not support GPU"):
+        fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
+
+
 def test_validate_smdistributed_raises():
     bad_args = [
         {"smdistributed": "dummy"},