Skip to content

Commit 7d30d8c

Browse files
yongyanraoYongyan Rao
and
Yongyan Rao
authored
change: add a check to prevent launching a modelparallel job on CPU only instances (#3262)
Co-authored-by: Yongyan Rao <[email protected]>
1 parent 2d59111 commit 7d30d8c

File tree

3 files changed

+87
-0
lines changed

3 files changed

+87
-0
lines changed

src/sagemaker/fw_utils.py

+46
Original file line numberDiff line numberDiff line change
@@ -835,6 +835,52 @@ def validate_pytorch_distribution(
835835
raise ValueError(err_msg)
836836

837837

838+
def validate_distribution_instance(sagemaker_session, distribution, instance_type):
    """Check to prevent launching a modelparallel job on CPU only instances.

    The check only applies when ``distribution["smdistributed"]["modelparallel"]``
    is a dict whose ``"enabled"`` entry is truthy. Any other configuration
    (missing keys, a disabled strategy, or a malformed structure) returns
    silently so that this check never fails with ``KeyError``/``TypeError``
    on input that it is not responsible for validating.

    Args:
        sagemaker_session (sagemaker.session.Session): Session object which
            manages interactions with Amazon SageMaker APIs and any other
            AWS services needed. Its ``boto_session`` is used to create an
            EC2 client for looking up the instance type.
        distribution (dict): A dictionary with information to enable distributed training.
            distribution = {
                "smdistributed": {
                    "modelparallel": {
                        "enabled": True,
                        "parameters": {
                            ...
                        },
                    },
                },
                ...
            }
        instance_type (str): A string representing the type of training instance selected.
            It is passed verbatim to EC2 ``describe_instance_types``, so it should be an
            EC2 instance type name (i.e. without the SageMaker ``"ml."`` prefix).

    Raises:
        ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
    """
    if not isinstance(distribution, dict):
        # Nothing to inspect; structural validation is not this function's job.
        return

    smdistributed = distribution.get("smdistributed")
    if not isinstance(smdistributed, dict):
        # Distribution strategy other than smdistributed is selected (or malformed)
        return

    modelparallel = smdistributed.get("modelparallel")
    if not isinstance(modelparallel, dict):
        # Strategy other than modelparallel is selected (or malformed)
        return

    if not modelparallel.get("enabled", False):
        # Strategy modelparallel is not enabled (a missing flag counts as disabled)
        return

    ec2_client = sagemaker_session.boto_session.client("ec2")
    instance_desc = ec2_client.describe_instance_types(InstanceTypes=[instance_type])
    # The instance is treated as GPU-capable only if its description has a "GpuInfo" entry.
    if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
        raise ValueError(
            "modelparallel only runs on GPU-enabled instances. "
            f"{instance_type} does not support GPU."
        )
882+
883+
838884
def python_deprecation_warning(framework, latest_supported_version):
839885
"""Placeholder docstring"""
840886
return PYTHON_2_DEPRECATION_WARNING.format(

src/sagemaker/pytorch/estimator.py

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
python_deprecation_warning,
2626
validate_version_or_image_args,
2727
validate_distribution,
28+
validate_distribution_instance,
2829
)
2930
from sagemaker.pytorch import defaults
3031
from sagemaker.pytorch.model import PyTorchModel
@@ -220,6 +221,12 @@ def __init__(
220221
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
221222
)
222223
if distribution is not None:
224+
instance_type = self._get_instance_type()
225+
# remove "ml." prefix
226+
if instance_type[:3] == "ml.":
227+
instance_type = instance_type[3:]
228+
validate_distribution_instance(self.sagemaker_session, distribution, instance_type)
229+
223230
distribution = validate_distribution(
224231
distribution,
225232
self.instance_groups,

tests/unit/test_fw_utils.py

+34
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,15 @@ def sagemaker_session():
4949
session_mock.sagemaker_client.describe_training_job = Mock(
5050
return_value={"ModelArtifacts": {"S3ModelArtifacts": "s3://m/m.tar.gz"}}
5151
)
52+
session_mock.boto_session.client("ec2").describe_instance_types = Mock(
53+
return_value={
54+
"InstanceTypes": [
55+
{
56+
"CpuInfo": {},
57+
},
58+
],
59+
}
60+
)
5261
return session_mock
5362

5463

@@ -733,6 +742,31 @@ def test_validate_smdistributed_not_raises():
733742
)
734743

735744

745+
def test_validate_distribution_instance_no_smdistributed(sagemaker_session):
    """A distribution without an "smdistributed" entry must pass without raising."""
    fw_utils.validate_distribution_instance(sagemaker_session, {}, "mock_type")
749+
750+
751+
def test_validate_distribution_instance_no_modelparallel(sagemaker_session):
    """An "smdistributed" entry without "modelparallel" must pass without raising."""
    fw_utils.validate_distribution_instance(
        sagemaker_session, {"smdistributed": {}}, "mock_type"
    )
755+
756+
757+
def test_validate_distribution_instance_disabled_modelparallel(sagemaker_session):
    """A modelparallel entry with "enabled" set to False must pass without raising."""
    disabled_modelparallel = {"smdistributed": {"modelparallel": {"enabled": False}}}
    fw_utils.validate_distribution_instance(
        sagemaker_session, disabled_modelparallel, "mock_type"
    )
761+
762+
763+
def test_validate_distribution_instance_raise(sagemaker_session):
    """Enabled modelparallel must raise ValueError for the instance the fixture describes."""
    enabled_modelparallel = {"smdistributed": {"modelparallel": {"enabled": True}}}
    with pytest.raises(ValueError):
        fw_utils.validate_distribution_instance(
            sagemaker_session, enabled_modelparallel, "mock_type"
        )
768+
769+
736770
def test_validate_smdistributed_raises():
737771
bad_args = [
738772
{"smdistributed": "dummy"},

0 commit comments

Comments
 (0)