Skip to content

Commit 5207ddc

Browse files
author
Yongyan Rao
committed
change: add a check to prevent launching a modelparallel job on CPU only instances.
1 parent b4f05b8 commit 5207ddc

File tree

2 files changed

+53
-0
lines changed

2 files changed

+53
-0
lines changed

src/sagemaker/fw_utils.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,52 @@ def validate_distribution(
700700
return distribution
701701

702702

703+
def validate_distribution_instance(sagemaker_session, distribution, instance_type):
704+
"""Check to prevent launching a modelparallel job on CPU only instances.
705+
706+
Args:
707+
sagemaker_session (sagemaker.session.Session): Session object which
708+
manages interactions with Amazon SageMaker APIs and any other
709+
AWS services needed.
710+
distribution (dict): A dictionary with information to enable distributed training.
711+
distribution = {
712+
"smdistributed": {
713+
"modelparallel": {
714+
"enabled": True,
715+
"parameters": {
716+
...
717+
},
718+
},
719+
},
720+
...
721+
}
722+
instance_type (str): A string representing the type of training instance selected.
723+
724+
Raises:
725+
ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
726+
"""
727+
if "smdistributed" not in distribution:
728+
# Distribution strategy other than smdistributed is selected
729+
return
730+
731+
if "modelparallel" not in distribution["smdistributed"]:
732+
# Strategy other than modelparallel is selected
733+
return
734+
735+
if not distribution["smdistributed"]["modelparallel"]["enabled"]:
736+
# Strategy modelparallel is not enabled
737+
return
738+
739+
instance_desc = sagemaker_session.boto_session.client("ec2").describe_instance_types(
740+
InstanceTypes=[f"{instance_type}"]
741+
)
742+
if "GpuInfo" not in instance_desc["InstanceTypes"][0]:
743+
raise ValueError(
744+
f"modelparallel only runs on GPU-enabled instances. "
745+
f"{instance_type} does not support GPU."
746+
)
747+
748+
703749
def python_deprecation_warning(framework, latest_supported_version):
704750
"""Placeholder docstring"""
705751
return PYTHON_2_DEPRECATION_WARNING.format(

src/sagemaker/pytorch/estimator.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
python_deprecation_warning,
2525
validate_version_or_image_args,
2626
validate_distribution,
27+
validate_distribution_instance,
2728
)
2829
from sagemaker.pytorch import defaults
2930
from sagemaker.pytorch.model import PyTorchModel
@@ -203,6 +204,12 @@ def __init__(
203204
entry_point, source_dir, hyperparameters, image_uri=image_uri, **kwargs
204205
)
205206
if distribution is not None:
207+
instance_type = self._get_instance_type()
208+
# remove "ml." prefix
209+
if instance_type[:3] == "ml.":
210+
instance_type = instance_type[3:]
211+
validate_distribution_instance(self.sagemaker_session, distribution, instance_type)
212+
206213
distribution = validate_distribution(
207214
distribution,
208215
self.instance_groups,

0 commit comments

Comments
 (0)