Skip to content

Commit eea124f

Browse files
ChaiBapchyaChoiByungWook
authored andcommitted
change: add p4d to smdataparallel supported instances (#538)
1 parent 47439f7 commit eea124f

File tree

2 files changed

+7
-2
lines changed

2 files changed

+7
-2
lines changed

src/sagemaker/fw_utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,12 @@
5050

5151
DEBUGGER_UNSUPPORTED_REGIONS = ("us-iso-east-1",)
5252
SINGLE_GPU_INSTANCE_TYPES = ("ml.p2.xlarge", "ml.p3.2xlarge")
53-
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = ("ml.p3.16xlarge", "ml.p3dn.24xlarge", "local_gpu")
53+
SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES = (
54+
"ml.p3.16xlarge",
55+
"ml.p3dn.24xlarge",
56+
"ml.p4d.24xlarge",
57+
"local_gpu",
58+
)
5459
SM_DATAPARALLEL_SUPPORTED_FRAMEWORK_VERSIONS = {
5560
"tensorflow": ["2.3.0", "2.3.1"],
5661
"pytorch": ["1.6.0"],

tests/unit/test_fw_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,7 @@ def test_validate_version_or_image_args_raises():
552552
def test_validate_smdistributed_not_raises():
553553
smdataparallel_enabled = {"smdistributed": {"dataparallel": {"enabled": True}}}
554554
smdataparallel_disabled = {"smdistributed": {"dataparallel": {"enabled": False}}}
555-
instance_types = ["ml.p3.16xlarge", "ml.p3dn.24xlarge"]
555+
instance_types = list(fw_utils.SM_DATAPARALLEL_SUPPORTED_INSTANCE_TYPES)
556556

557557
good_args = [
558558
(smdataparallel_enabled, "custom-container"),

0 commit comments

Comments
 (0)