134
134
"1.12.0" ,
135
135
]
136
136
137
- TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = [
138
- "1.11" ,
139
- "1.11.0"
140
- ]
137
+ TORCH_DISTRIBUTED_SUPPORTED_FRAMEWORK_VERSIONS = ["1.11" , "1.11.0" ]
141
138
142
- TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = [
143
- "torch_distributed"
144
- ]
139
+ TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed" ]
145
140
146
141
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel" , "modelparallel" ]
147
142
@@ -710,7 +705,14 @@ def _validate_smdataparallel_args(
710
705
711
706
712
707
def validate_distribution (
713
- distribution , instance_groups , framework_name , framework_version , py_version , image_uri , entry_point , kwargs
708
+ distribution ,
709
+ instance_groups ,
710
+ framework_name ,
711
+ framework_version ,
712
+ py_version ,
713
+ image_uri ,
714
+ entry_point ,
715
+ kwargs ,
714
716
):
715
717
"""Check if distribution strategy is correctly invoked by the user.
716
718
@@ -850,9 +852,8 @@ def validate_distribution(
850
852
)
851
853
return distribution
852
854
853
- def validate_distribution_for_instance_type (
854
- instance_type , distribution
855
- ):
855
+
856
+ def validate_distribution_for_instance_type (instance_type , distribution ):
856
857
"""Check if the provided distribution strategy is supported for the instance_type
857
858
858
859
Args:
@@ -869,11 +870,11 @@ def validate_distribution_for_instance_type(
869
870
distribution_strategy = keys [0 ]
870
871
if distribution_strategy != "torch_distributed" :
871
872
err_msg += (
872
- f"Provided distribution strategy { distribution_strategy } is not supported for"
873
- " Trainium instances.\n "
874
- "Please specify one of the following supported distribution strategies:"
875
- f" { TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES } \n "
876
- )
873
+ f"Provided distribution strategy { distribution_strategy } is not supported for"
874
+ " Trainium instances.\n "
875
+ "Please specify one of the following supported distribution strategies:"
876
+ f" { TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES } \n "
877
+ )
877
878
elif len (keys ) > 1 :
878
879
err_msg += (
879
880
f"Multiple distribution strategies are not supported for Trainium instances.\n "
@@ -884,6 +885,7 @@ def validate_distribution_for_instance_type(
884
885
if err_msg :
885
886
raise ValueError (err_msg )
886
887
888
+
887
889
def validate_pytorch_distribution (
888
890
distribution , framework_name , framework_version , py_version , image_uri
889
891
):
@@ -940,8 +942,15 @@ def validate_pytorch_distribution(
940
942
if err_msg :
941
943
raise ValueError (err_msg )
942
944
945
+
943
946
def validate_torch_distributed_distribution (
944
- instance_type , distribution , framework_name , framework_version , py_version , image_uri , entry_point ,
947
+ instance_type ,
948
+ distribution ,
949
+ framework_name ,
950
+ framework_version ,
951
+ py_version ,
952
+ image_uri ,
953
+ entry_point ,
945
954
):
946
955
"""Check if torch_distributed distribution strategy is correctly invoked by the user.
947
956
@@ -1003,20 +1012,22 @@ def validate_torch_distributed_distribution(
1003
1012
return
1004
1013
else :
1005
1014
err_msg += (
1006
- f"torch_distributed is currently supported only for trainium instances."
1007
- " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \
1015
+ f"torch_distributed is currently supported only for trainium instances."
1016
+ " Please refer https://sagemaker.readthedocs.io/en/stable/frameworks/pytorch/using_pytorch.html#distributed-pytorch-training \
1008
1017
for information regarding distributed training on non-trainium instances"
1009
1018
)
1010
1019
1011
1020
# Check entry point type
1012
1021
if not entry_point .endswith (".py" ):
1013
- err_msg += ("Unsupported entry point type for torch_distributed.\n "
1014
- "Only python programs (*.py) are supported."
1022
+ err_msg += (
1023
+ "Unsupported entry point type for torch_distributed.\n "
1024
+ "Only python programs (*.py) are supported."
1015
1025
)
1016
-
1026
+
1017
1027
if err_msg :
1018
1028
raise ValueError (err_msg )
1019
1029
1030
+
1020
1031
def python_deprecation_warning (framework , latest_supported_version ):
1021
1032
"""Placeholder docstring"""
1022
1033
return PYTHON_2_DEPRECATION_WARNING .format (
0 commit comments