Skip to content

Commit 7612098

Browse files
ruhanprasadliujiaorr
authored and
root
committed
fix: add PT 2.2 support for smdistributed, pytorchddp, and torch_distributed distributions (aws#4480)
* Add support for smdistributed, pytorchddp, torch_distributed for PT 2.2
* formatting
* formatting

---------

Co-authored-by: liujiaor <[email protected]>
1 parent 226be67 commit 7612098

File tree

3 files changed

+3
-5
lines changed

3 files changed

+3
-5
lines changed

src/sagemaker/estimator.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -3285,10 +3285,7 @@ class Framework(EstimatorBase):
32853285
"""
32863286

32873287
_framework_name = None
3288-
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = (
3289-
"2.0.1-gpu-py310-cu121",
3290-
"2.0-gpu-py310-cu121",
3291-
)
3288+
UNSUPPORTED_DLC_IMAGE_FOR_SM_PARALLELISM = ("2.0.1-gpu-py310-cu121",)
32923289

32933290
def __init__(
32943291
self,

src/sagemaker/fw_utils.py

-1
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,6 @@
161161
"2.2.0",
162162
]
163163

164-
165164
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
166165
"1.13.1",
167166
"2.0.0",

tests/unit/test_fw_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -933,6 +933,7 @@ def test_validate_smdataparallel_args_not_raises():
933933
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
934934
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled),
935935
("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled),
936+
("ml.p3.16xlarge", "pytorch", "2.2.0", "py310", smdataparallel_enabled),
936937
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
937938
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
938939
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -957,6 +958,7 @@ def test_validate_smdataparallel_args_not_raises():
957958
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
958959
("ml.p3.16xlarge", "pytorch", "2.0.1", "py310", smdataparallel_enabled_custom_mpi),
959960
("ml.p3.16xlarge", "pytorch", "2.1.0", "py310", smdataparallel_enabled_custom_mpi),
961+
("ml.p3.16xlarge", "pytorch", "2.2.0", "py310", smdataparallel_enabled_custom_mpi),
960962
]
961963
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
962964
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)