diff --git a/src/sagemaker/fw_utils.py b/src/sagemaker/fw_utils.py index 613bbd3742..ef99454a45 100644 --- a/src/sagemaker/fw_utils.py +++ b/src/sagemaker/fw_utils.py @@ -835,52 +835,6 @@ def validate_pytorch_distribution( raise ValueError(err_msg) -def validate_distribution_instance(sagemaker_session, distribution, instance_type): - """Check to prevent launching a modelparallel job on CPU only instances. - - Args: - sagemaker_session (sagemaker.session.Session): Session object which - manages interactions with Amazon SageMaker APIs and any other - AWS services needed. - distribution (dict): A dictionary with information to enable distributed training. - distribution = { - "smdistributed": { - "modelparallel": { - "enabled": True, - "parameters": { - ... - }, - }, - }, - ... - } - instance_type (str): A string representing the type of training instance selected. - - Raises: - ValueError: when modelparallel is enabled, if the instance_type does not support GPU. - """ - if "smdistributed" not in distribution: - # Distribution strategy other than smdistributed is selected - return - - if "modelparallel" not in distribution["smdistributed"]: - # Strategy other than modelparallel is selected - return - - if not distribution["smdistributed"]["modelparallel"]["enabled"]: - # Strategy modelparallel is not enabled - return - - instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types( - InstanceTypes=[f"{instance_type}"] - ) - if "GpuInfo" not in instance_desc["InstanceTypes"][0]: - raise ValueError( - f"modelparallel only runs on GPU-enabled instances. " - f"{instance_type} does not support GPU." - ) - - def python_deprecation_warning(framework, latest_supported_version): """Placeholder docstring""" return PYTHON_2_DEPRECATION_WARNING.format( diff --git a/src/sagemaker/pytorch/estimator.py b/src/sagemaker/pytorch/estimator.py index 622e79084c..153d4656d4 100644 --- a/src/sagemaker/pytorch/estimator.py +++ b/src/sagemaker/pytorch/estimator.py @@ -25,7 +25,6 @@ python_deprecation_warning, validate_version_or_image_args, validate_distribution, - validate_distribution_instance, ) from sagemaker.pytorch import defaults from sagemaker.pytorch.model import PyTorchModel @@ -221,12 +220,6 @@ def __init__( entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs ) if distribution is not None: - instance_type = self._get_instance_type() - # remove "ml." prefix - if instance_type[:3] == "ml.": - instance_type = instance_type[3:] - validate_distribution_instance(self.sagemaker_session, distribution, instance_type) - distribution = validate_distribution( distribution, self.instance_groups, diff --git a/tests/unit/test_fw_utils.py b/tests/unit/test_fw_utils.py index 5ecf196731..018255cf47 100644 --- a/tests/unit/test_fw_utils.py +++ b/tests/unit/test_fw_utils.py @@ -49,15 +49,6 @@ def sagemaker_session(): session_mock.sagemaker_client.describe_training_job = Mock( return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}} ) - session_mock.boto_session.client("ec2").describe_instance_types = Mock( - return_value={ - "InstanceTypes": [ - { - "CpuInfo": {}, - }, - ], - } - ) return session_mock @@ -742,31 +733,6 @@ def test_validate_smdistributed_not_raises(): ) -def test_validate_distribution_instance_no_smdistributed(sagemaker_session): - distribution = {} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_no_modelparallel(sagemaker_session): - distribution = {"smdistributed": {}} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session): - distribution = {"smdistributed": {"modelparallel": {"enabled": False}}} - instance_type = "mock_type" - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - -def test_validate_distribution_instance_raise(sagemaker_session): - distribution = {"smdistributed": {"modelparallel": {"enabled": True}}} - instance_type = "mock_type" - with pytest.raises(ValueError): - fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type) - - def test_validate_smdistributed_raises(): bad_args = [ {"smdistributed": "dummy"},