@@ -853,18 +853,12 @@ def test_validate_smdataparallel_args_raises():
853
853
smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
854
854
855
855
# Cases {PT|TF2}
856
- # 1. None instance type
857
- # 2. incorrect instance type
858
- # 3. incorrect python version
859
- # 4. incorrect framework version
856
+ # 1. incorrect python version
857
+ # 2. incorrect framework version
860
858
861
859
bad_args = [
862
- (None , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
863
- ("ml.p3.2xlarge" , "tensorflow" , "2.3.1" , "py3" , smdataparallel_enabled ),
864
860
("ml.p3dn.24xlarge" , "tensorflow" , "2.3.1" , "py2" , smdataparallel_enabled ),
865
861
("ml.p3.16xlarge" , "tensorflow" , "1.3.1" , "py3" , smdataparallel_enabled ),
866
- (None , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
867
- ("ml.p3.2xlarge" , "pytorch" , "1.6.0" , "py3" , smdataparallel_enabled ),
868
862
("ml.p3dn.24xlarge" , "pytorch" , "1.6.0" , "py2" , smdataparallel_enabled ),
869
863
("ml.p3.16xlarge" , "pytorch" , "1.5.0" , "py3" , smdataparallel_enabled ),
870
864
]
@@ -966,74 +960,6 @@ def test_validate_smdataparallel_args_not_raises():
966
960
)
967
961
968
962
969
- def test_validate_pytorchddp_not_raises ():
970
- # Case 1: Framework is not PyTorch
971
- fw_utils .validate_pytorch_distribution (
972
- distribution = None ,
973
- framework_name = "tensorflow" ,
974
- framework_version = "2.9.1" ,
975
- py_version = "py3" ,
976
- image_uri = "custom-container" ,
977
- )
978
- # Case 2: Framework is PyTorch, but distribution is not PyTorchDDP
979
- pytorchddp_disabled = {"pytorchddp" : {"enabled" : False }}
980
- fw_utils .validate_pytorch_distribution (
981
- distribution = pytorchddp_disabled ,
982
- framework_name = "pytorch" ,
983
- framework_version = "1.10" ,
984
- py_version = "py3" ,
985
- image_uri = "custom-container" ,
986
- )
987
- # Case 3: Framework is PyTorch, Distribution is PyTorchDDP enabled, supported framework and py versions
988
- pytorchddp_enabled = {"pytorchddp" : {"enabled" : True }}
989
- pytorchddp_supported_fw_versions = [
990
- "1.10" ,
991
- "1.10.0" ,
992
- "1.10.2" ,
993
- "1.11" ,
994
- "1.11.0" ,
995
- "1.12" ,
996
- "1.12.0" ,
997
- "1.12.1" ,
998
- "1.13.1" ,
999
- "2.0.0" ,
1000
- "2.0.1" ,
1001
- "2.1.0" ,
1002
- "2.2.0" ,
1003
- ]
1004
- for framework_version in pytorchddp_supported_fw_versions :
1005
- fw_utils .validate_pytorch_distribution (
1006
- distribution = pytorchddp_enabled ,
1007
- framework_name = "pytorch" ,
1008
- framework_version = framework_version ,
1009
- py_version = "py3" ,
1010
- image_uri = "custom-container" ,
1011
- )
1012
-
1013
-
1014
- def test_validate_pytorchddp_raises ():
1015
- pytorchddp_enabled = {"pytorchddp" : {"enabled" : True }}
1016
- # Case 1: Unsupported framework version
1017
- with pytest .raises (ValueError ):
1018
- fw_utils .validate_pytorch_distribution (
1019
- distribution = pytorchddp_enabled ,
1020
- framework_name = "pytorch" ,
1021
- framework_version = "1.8" ,
1022
- py_version = "py3" ,
1023
- image_uri = None ,
1024
- )
1025
-
1026
- # Case 2: Unsupported Py version
1027
- with pytest .raises (ValueError ):
1028
- fw_utils .validate_pytorch_distribution (
1029
- distribution = pytorchddp_enabled ,
1030
- framework_name = "pytorch" ,
1031
- framework_version = "1.10" ,
1032
- py_version = "py2" ,
1033
- image_uri = None ,
1034
- )
1035
-
1036
-
1037
963
def test_validate_torch_distributed_not_raises ():
1038
964
# Case 1: Framework is PyTorch, but torch_distributed is not enabled
1039
965
torch_distributed_disabled = {"torch_distributed" : {"enabled" : False }}
0 commit comments