@@ -835,6 +835,52 @@ def validate_pytorch_distribution(
835
835
raise ValueError (err_msg )
836
836
837
837
838
+ def validate_distribution_instance (sagemaker_session , distribution , instance_type ):
839
+ """Check to prevent launching a modelparallel job on CPU only instances.
840
+
841
+ Args:
842
+ sagemaker_session (sagemaker.session.Session): Session object which
843
+ manages interactions with Amazon SageMaker APIs and any other
844
+ AWS services needed.
845
+ distribution (dict): A dictionary with information to enable distributed training.
846
+ distribution = {
847
+ "smdistributed": {
848
+ "modelparallel": {
849
+ "enabled": True,
850
+ "parameters": {
851
+ ...
852
+ },
853
+ },
854
+ },
855
+ ...
856
+ }
857
+ instance_type (str): A string representing the type of training instance selected.
858
+
859
+ Raises:
860
+ ValueError: when modelparallel is enabled, if the instance_type does not support GPU.
861
+ """
862
+ if "smdistributed" not in distribution :
863
+ # Distribution strategy other than smdistributed is selected
864
+ return
865
+
866
+ if "modelparallel" not in distribution ["smdistributed" ]:
867
+ # Strategy other than modelparallel is selected
868
+ return
869
+
870
+ if not distribution ["smdistributed" ]["modelparallel" ]["enabled" ]:
871
+ # Strategy modelparallel is not enabled
872
+ return
873
+
874
+ instance_desc = sagemaker_session .boto_session .client ("ec2" ).describe_instance_types (
875
+ InstanceTypes = [f"{ instance_type } " ]
876
+ )
877
+ if "GpuInfo" not in instance_desc ["InstanceTypes" ][0 ]:
878
+ raise ValueError (
879
+ f"modelparallel only runs on GPU-enabled instances. "
880
+ f"{ instance_type } does not support GPU."
881
+ )
882
+
883
+
838
884
def python_deprecation_warning (framework , latest_supported_version ):
839
885
"""Placeholder docstring"""
840
886
return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments