134
134
"1.12.0" ,
135
135
]
136
136
137
# PyTorch framework versions for which the "torch_distributed" distribution
# strategy (torchrun-based launcher) is supported.
TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = [
    "1.11.0"
]

# Distribution strategies supported on AWS Trainium (trn*) instances.
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = [
    "torch_distributed"
]

# Strategies accepted under the "smdistributed" distribution key.
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]
138
146
139
147
@@ -767,6 +775,10 @@ def validate_distribution(
767
775
f"Invalid training instance group { train_instance_group .instance_group_name } !"
768
776
)
769
777
instance_type = train_instance_group .instance_type
778
+ validate_supported_distributions (
779
+ instance_type = instance_type ,
780
+ distribution = distribution ,
781
+ )
770
782
validate_smdistributed (
771
783
instance_type = instance_type ,
772
784
framework_name = framework_name ,
@@ -782,6 +794,14 @@ def validate_distribution(
782
794
py_version = py_version ,
783
795
image_uri = image_uri ,
784
796
)
797
+ validate_torch_distributed_distribution (
798
+ instance_type = instance_type ,
799
+ distribution = distribution ,
800
+ framework_name = framework_name ,
801
+ framework_version = framework_version ,
802
+ py_version = py_version ,
803
+ image_uri = image_uri ,
804
+ )
785
805
warn_if_parameter_server_with_multi_gpu (
786
806
training_instance_type = instance_type , distribution = distribution
787
807
)
@@ -793,6 +813,10 @@ def validate_distribution(
793
813
instance_type = renamed_kwargs (
794
814
"train_instance_type" , "instance_type" , kwargs .get ("instance_type" ), kwargs
795
815
)
816
+ validate_supported_distributions (
817
+ instance_type = instance_type ,
818
+ distribution = distribution ,
819
+ )
796
820
validate_smdistributed (
797
821
instance_type = instance_type ,
798
822
framework_name = framework_name ,
@@ -808,11 +832,52 @@ def validate_distribution(
808
832
py_version = py_version ,
809
833
image_uri = image_uri ,
810
834
)
835
+ validate_torch_distributed_distribution (
836
+ instance_type = instance_type ,
837
+ distribution = distribution ,
838
+ framework_name = framework_name ,
839
+ framework_version = framework_version ,
840
+ py_version = py_version ,
841
+ image_uri = image_uri ,
842
+ )
811
843
warn_if_parameter_server_with_multi_gpu (
812
844
training_instance_type = instance_type , distribution = distribution
813
845
)
814
846
return distribution
815
847
848
def validate_supported_distributions(instance_type, distribution):
    """Check if the provided distribution strategy is supported for the instance_type.

    Only Trainium (``trn*``) instances are restricted: they accept at most one
    distribution strategy, and it must be ``torch_distributed``. All other
    instance types pass through without validation.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.

    Raises:
        ValueError: if a Trainium instance is used with a distribution strategy other
            than ``torch_distributed``, or with multiple distribution strategies.
    """
    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    if not (match and match[1].startswith("trn")):
        # Non-Trainium instances are not restricted here.
        return

    # BUGFIX: dict views are not subscriptable; materialize to a list before indexing.
    keys = list(distribution.keys())
    err_msg = ""
    if len(keys) == 1:
        distribution_strategy = keys[0]
        if distribution_strategy != "torch_distributed":
            err_msg += (
                f"Provided distribution strategy {distribution_strategy} is not supported by"
                " Trainium instances yet.\n"
                "Please specify one of the following supported distribution strategies:"
                f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} \n"
            )
    elif len(keys) > 1:
        # BUGFIX: added the missing newline separator between the two sentences.
        err_msg += (
            "Multiple distribution strategies are not supported for Trainium instances yet.\n"
            "Please specify one of the following supported distribution strategies:"
            f" {TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES} "
        )

    if err_msg:
        raise ValueError(err_msg)
816
881
817
882
def validate_pytorch_distribution (
818
883
distribution , framework_name , framework_version , py_version , image_uri
@@ -870,6 +935,73 @@ def validate_pytorch_distribution(
870
935
if err_msg :
871
936
raise ValueError (err_msg )
872
937
938
def validate_torch_distributed_distribution(
    instance_type, distribution, framework_name, framework_version, py_version, image_uri
):
    """Check if torch_distributed distribution strategy is correctly invoked by the user.

    Args:
        instance_type (str): A string representing the type of training instance selected.
        distribution (dict): A dictionary with information to enable distributed training.
            (Defaults to None if distributed training is not enabled.) For example:

            .. code:: python

                {
                    "torch_distributed": {
                        "enabled": True
                    }
                }
        framework_name (str): A string representing the name of framework selected.
        framework_version (str): A string representing the framework version selected.
        py_version (str): A string representing the python version selected.
        image_uri (str): A string representing a Docker image URI.

    Raises:
        ValueError: if
            `py_version` is not python3 or
            `framework_version` is not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS or
            the instance is not a Trainium (trn*) instance
    """
    if framework_name and framework_name != "pytorch":
        # We need to validate only for PyTorch framework
        return

    torch_distributed_enabled = False
    # Guard against distribution=None (documented default when distributed
    # training is not enabled).
    if distribution and "torch_distributed" in distribution:
        torch_distributed_enabled = distribution.get("torch_distributed").get("enabled", False)
    if not torch_distributed_enabled:
        # A distribution strategy other than torch_distributed is selected
        return

    err_msg = ""
    if not image_uri:
        # ignore framework_version and py_version if image_uri is set
        # in case image_uri is not set, then both are mandatory
        if framework_version not in TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS:
            err_msg += (
                f"Provided framework_version {framework_version} is not supported by"
                " torch_distributed.\n"
                "Please specify one of the supported framework versions:"
                f" {TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS} \n"
            )
        if "py3" not in py_version:
            err_msg += (
                f"Provided py_version {py_version} is not supported by torch_distributed.\n"
                "Please specify py_version>=py3"
            )

    # Check instance compatibility. BUGFIX: the previous early `return` for
    # Trainium instances silently discarded any framework_version/py_version
    # errors accumulated above; now all errors are raised together.
    match = re.match(r"^ml[\._]([a-z\d]+)\.?\w*$", instance_type)
    if not (match and match[1].startswith("trn")):
        # BUGFIX: the message was previously built with a backslash line
        # continuation inside the string literal, embedding a long run of raw
        # indentation whitespace into the user-facing error text.
        err_msg += (
            "torch_distributed is currently supported only for trainium instances."
            " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/"
            "using_pytorch.html#distributed-pytorch-training"
            " for information regarding distributed training on non-trainium instances"
        )

    if err_msg:
        raise ValueError(err_msg)
873
1005
874
1006
def python_deprecation_warning (framework , latest_supported_version ):
875
1007
"""Placeholder docstring"""
0 commit comments