Skip to content

Commit d95bd1d

Browse files
committed
change: Update for PT 2.5.1, SMP 2.8.0
1 parent 30fe0ee commit d95bd1d

File tree

4 files changed

+47
-13
lines changed

4 files changed

+47
-13
lines changed

src/sagemaker/fw_utils.py

+1
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,7 @@
155155
"2.3.0",
156156
"2.3.1",
157157
"2.4.1",
158+
"2.5.1",
158159
]
159160

160161
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]

src/sagemaker/image_uri_config/pytorch-smp.json

+27-1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,8 @@
99
"2.2": "2.3.1",
1010
"2.2.0": "2.3.1",
1111
"2.3.1": "2.5.0",
12-
"2.4.1": "2.7.0"
12+
"2.4.1": "2.7.0",
13+
"2.5.1": "2.8.0"
1314
},
1415
"versions": {
1516
"2.0.1": {
@@ -186,6 +187,31 @@
186187
"us-west-2": "658645717510"
187188
},
188189
"repository": "smdistributed-modelparallel"
190+
},
191+
"2.8.0": {
192+
"py_versions": [
193+
"py311"
194+
],
195+
"registries": {
196+
"ap-northeast-1": "658645717510",
197+
"ap-northeast-2": "658645717510",
198+
"ap-northeast-3": "658645717510",
199+
"ap-south-1": "658645717510",
200+
"ap-southeast-1": "658645717510",
201+
"ap-southeast-2": "658645717510",
202+
"ca-central-1": "658645717510",
203+
"eu-central-1": "658645717510",
204+
"eu-north-1": "658645717510",
205+
"eu-west-1": "658645717510",
206+
"eu-west-2": "658645717510",
207+
"eu-west-3": "658645717510",
208+
"sa-east-1": "658645717510",
209+
"us-east-1": "658645717510",
210+
"us-east-2": "658645717510",
211+
"us-west-1": "658645717510",
212+
"us-west-2": "658645717510"
213+
},
214+
"repository": "smdistributed-modelparallel"
189215
}
190216
}
191217
}

src/sagemaker/image_uris.py

+10-6
Original file line number | Diff line number | Diff line change
@@ -701,12 +701,16 @@ def get_training_image_uri(
701701
if "modelparallel" in distribution["smdistributed"]:
702702
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
703703
framework = "pytorch-smp"
704-
if (
705-
"p5" in instance_type
706-
or "2.1" in framework_version
707-
or "2.2" in framework_version
708-
or "2.3" in framework_version
709-
or "2.4" in framework_version
704+
supported_smp_pt_versions_cu124 = ("2.5",)
705+
supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4")
706+
if any(
707+
pt_version in framework_version
708+
for pt_version in supported_smp_pt_versions_cu124
709+
):
710+
container_version = "cu124"
711+
elif "p5" in instance_type or any(
712+
pt_version in framework_version
713+
for pt_version in supported_smp_pt_versions_cu121
710714
):
711715
container_version = "cu121"
712716
else:

tests/unit/sagemaker/image_uris/test_smp_v2.py

+9-6
Original file line number | Diff line number | Diff line change
@@ -36,15 +36,18 @@ def test_smp_v2(load_config):
3636
for region in ACCOUNTS.keys():
3737
for instance_type in CONTAINER_VERSIONS.keys():
3838
cuda_vers = CONTAINER_VERSIONS[instance_type]
39-
if (
40-
"2.1" in version
41-
or "2.2" in version
42-
or "2.3" in version
43-
or "2.4" in version
39+
supported_smp_pt_versions_cu124 = ("2.5",)
40+
supported_smp_pt_versions_cu121 = ("2.1", "2.2", "2.3", "2.4")
41+
if any(
42+
pt_version in version for pt_version in supported_smp_pt_versions_cu124
43+
):
44+
cuda_vers = "cu124"
45+
elif any(
46+
pt_version in version for pt_version in supported_smp_pt_versions_cu121
4447
):
4548
cuda_vers = "cu121"
4649

47-
if "2.3.1" == version or "2.4.1" == version:
50+
if version in ("2.3.1", "2.4.1", "2.5.1"):
4851
py_version = "py311"
4952

5053
uri = image_uris.get_training_image_uri(

0 commit comments

Comments
 (0)