File tree 2 files changed +7
-2
lines changed
2 files changed +7
-2
lines changed Original file line number Diff line number Diff line change 50
50
51
51
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1" ,)
52
52
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge" , "ml.p3.2xlarge" )
53
- SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = ("ml.p3.16xlarge" , "ml.p3dn.24xlarge" , "local_gpu" )
53
+ SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
54
+ "ml.p3.16xlarge" ,
55
+ "ml.p3dn.24xlarge" ,
56
+ "ml.p4d.24xlarge" ,
57
+ "local_gpu" ,
58
+ )
54
59
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
55
60
"tensorflow" : ["2.3.0" , "2.3.1" ],
56
61
"pytorch" : ["1.6.0" ],
Original file line number Diff line number Diff line change @@ -552,7 +552,7 @@ def test_validate_version_or_image_args_raises():
552
552
def test_validate_smdistributed_not_raises ():
553
553
smdataparallel_enabled = {"smdistributed" : {"dataparallel" : {"enabled" : True }}}
554
554
smdataparallel_disabled = {"smdistributed" : {"dataparallel" : {"enabled" : False }}}
555
- instance_types = [ "ml.p3.16xlarge" , "ml.p3dn.24xlarge" ]
555
+ instance_types = list ( fw_utils . SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES )
556
556
557
557
good_args = [
558
558
(smdataparallel_enabled , "custom-container" ),
You can’t perform that action at this time.
0 commit comments