Skip to content

Commit 2af5d43

Browse files
navinsoni authored and JoseJuan98 committed
fix: Revert "change: add a check to prevent launching a modelparallel job on CPU only instances" (aws#3280)
1 parent cba4c20 commit 2af5d43

File tree

3 files changed

+0
-87
lines changed

3 files changed

+0
-87
lines changed

src/sagemaker/fw_utils.py

Lines changed: 0 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -835,52 +835,6 @@ def validate_pytorch_distribution(
835 835
raise ValueError(err_msg)
836 836

837 837

838-
def validate_distribution_instance(sagemaker_session, distribution, instance_type):
839-
"""Check to prevent launching a modelparallel job on CPU only instances.
840-
841-
Args:
842-
sagemaker_session (sagemaker.session.Session): Session object which
843-
manages interactions with Amazon SageMaker APIs and any other
844-
AWS services needed.
845-
distribution (dict): A dictionary with information to enable distributed training.
846-
distribution = {
847-
"smdistributed": {
848-
"modelparallel": {
849-
"enabled": True,
850-
"parameters": {
851-
...
852-
},
853-
},
854-
},
855-
...
856-
}
857-
instance_type (str): A string representing the type of training instance selected.
858-
859-
Raises:
860-
ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
861-
"""
862-
if "smdistributed" not in distribution:
863-
# Distribution strategy other than smdistributed is selected
864-
return
865-
866-
if "modelparallel" not in distribution["smdistributed"]:
867-
# Strategy other than modelparallel is selected
868-
return
869-
870-
if not distribution["smdistributed"]["modelparallel"]["enabled"]:
871-
# Strategy modelparallel is not enabled
872-
return
873-
874-
instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
875-
InstanceTypes=[f"{instance_type}"]
876-
)
877-
if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
878-
raise ValueError(
879-
f"modelparallel only runs on GPU-enabled instances. "
880-
f"{instance_type} does not support GPU."
881-
)
882-
883-
884 838
def python_deprecation_warning(framework, latest_supported_version):
885 839
"""Placeholder docstring"""
886 840
return PYTHON_2_DEPRECATION_WARNING.format(

src/sagemaker/pytorch/estimator.py

Lines changed: 0 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,6 @@
25 25
python_deprecation_warning,
26 26
validate_version_or_image_args,
27 27
validate_distribution,
28-
validate_distribution_instance,
29 28
)
30 29
from sagemaker.pytorch import defaults
31 30
from sagemaker.pytorch.model import PyTorchModel
@@ -221,12 +220,6 @@ def __init__(
221 220
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
222 221
)
223 222
if distribution is not None:
224-
instance_type = self._get_instance_type()
225-
# remove "ml." prefix
226-
if instance_type[:3] == "ml.":
227-
instance_type = instance_type[3:]
228-
validate_distribution_instance(self.sagemaker_session, distribution, instance_type)
229-
230 223
distribution = validate_distribution(
231 224
distribution,
232 225
self.instance_groups,

tests/unit/test_fw_utils.py

Lines changed: 0 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -49,15 +49,6 @@ def sagemaker_session():
49 49
session_mock.sagemaker_client.describe_training_job = Mock(
50 50
return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
51 51
)
52-
session_mock.boto_session.client("ec2").describe_instance_types = Mock(
53-
return_value={
54-
"InstanceTypes": [
55-
{
56-
"CpuInfo": {},
57-
},
58-
],
59-
}
60-
)
61 52
return session_mock
62 53

63 54

@@ -742,31 +733,6 @@ def test_validate_smdistributed_not_raises():
742 733
)
743 734

744 735

745-
def test_validate_distribution_instance_no_smdistributed(sagemaker_session):
746-
distribution = {}
747-
instance_type = "mock_type"
748-
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
749-
750-
751-
def test_validate_distribution_instance_no_modelparallel(sagemaker_session):
752-
distribution = {"smdistributed": {}}
753-
instance_type = "mock_type"
754-
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
755-
756-
757-
def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session):
758-
distribution = {"smdistributed": {"modelparallel": {"enabled": False}}}
759-
instance_type = "mock_type"
760-
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
761-
762-
763-
def test_validate_distribution_instance_raise(sagemaker_session):
764-
distribution = {"smdistributed": {"modelparallel": {"enabled": True}}}
765-
instance_type = "mock_type"
766-
with pytest.raises(ValueError):
767-
fw_utils.validate_distribution_instance(sagemaker_session, distribution, instance_type)
768-
769-
770 736
def test_validate_smdistributed_raises():
771 737
bad_args = [
772 738
{"smdistributed": "dummy"},

0 commit comments

Comments
 (0)