@@ -933,6 +933,7 @@ def test_validate_smdataparallel_args_not_raises():
933
933
("ml.p3.16xlarge" , "pytorch" , "2.0.0" , "py310" , smdataparallel_enabled ),
934
934
("ml.p3.16xlarge" , "pytorch" , "2.0.1" , "py310" , smdataparallel_enabled ),
935
935
("ml.p3.16xlarge" , "pytorch" , "2.1.0" , "py310" , smdataparallel_enabled ),
936
+ ("ml.p3.16xlarge" , "pytorch" , "2.2.0" , "py310" , smdataparallel_enabled ),
936
937
("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py3" , smdataparallel_enabled_custom_mpi ),
937
938
("ml.p3.16xlarge" , "tensorflow" , "2.4.1" , "py37" , smdataparallel_enabled_custom_mpi ),
938
939
("ml.p3.16xlarge" , "tensorflow" , "2.4.3" , "py3" , smdataparallel_enabled_custom_mpi ),
@@ -957,6 +958,7 @@ def test_validate_smdataparallel_args_not_raises():
957
958
("ml.p3.16xlarge" , "pytorch" , "2.0.0" , "py310" , smdataparallel_enabled_custom_mpi ),
958
959
("ml.p3.16xlarge" , "pytorch" , "2.0.1" , "py310" , smdataparallel_enabled_custom_mpi ),
959
960
("ml.p3.16xlarge" , "pytorch" , "2.1.0" , "py310" , smdataparallel_enabled_custom_mpi ),
961
+ ("ml.p3.16xlarge" , "pytorch" , "2.2.0" , "py310" , smdataparallel_enabled_custom_mpi ),
960
962
]
961
963
for instance_type , framework_name , framework_version , py_version , distribution in good_args :
962
964
fw_utils ._validate_smdataparallel_args (
0 commit comments