Skip to content

Commit 90aa33b

Browse files
adtian2Andrew Tian
authored andcommitted
change: Upgrade smp to version 2.2 (aws#4479)
* upgrading smp to version 2.2 * fixing linting issue * fixing syntax error with multiline if statement * upgrading smp to version 2.2 * fixing linting issue * fixing syntax error with multiline if statement * fixing formatting --------- Co-authored-by: Andrew Tian <[email protected]>
1 parent c63c268 commit 90aa33b

File tree

4 files changed

+42
-4
lines changed

4 files changed

+42
-4
lines changed

src/sagemaker/fw_utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
"2.0.1",
142142
"2.1.0",
143143
"2.1.2",
144+
"2.2.0",
144145
],
145146
}
146147

@@ -160,7 +161,14 @@
160161
]
161162

162163

163-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0", "2.0.1", "2.1.0", "2.1.2"]
164+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = [
165+
"1.13.1",
166+
"2.0.0",
167+
"2.0.1",
168+
"2.1.0",
169+
"2.1.2",
170+
"2.2.0",
171+
]
164172

165173
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
166174
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [

src/sagemaker/image_uri_config/pytorch-smp.json

+27-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55
],
66
"version_aliases": {
77
"2.0": "2.0.1",
8-
"2.1": "2.1.2"
8+
"2.1": "2.1.2",
9+
"2.2": "2.2.0"
910
},
1011
"versions": {
1112
"2.0.1": {
@@ -57,6 +58,31 @@
5758
"us-west-2": "658645717510"
5859
},
5960
"repository": "smdistributed-modelparallel"
61+
},
62+
"2.2.0": {
63+
"py_versions": [
64+
"py310"
65+
],
66+
"registries": {
67+
"ap-northeast-1": "658645717510",
68+
"ap-northeast-2": "658645717510",
69+
"ap-northeast-3": "658645717510",
70+
"ap-south-1": "658645717510",
71+
"ap-southeast-1": "658645717510",
72+
"ap-southeast-2": "658645717510",
73+
"ca-central-1": "658645717510",
74+
"eu-central-1": "658645717510",
75+
"eu-north-1": "658645717510",
76+
"eu-west-1": "658645717510",
77+
"eu-west-2": "658645717510",
78+
"eu-west-3": "658645717510",
79+
"sa-east-1": "658645717510",
80+
"us-east-1": "658645717510",
81+
"us-east-2": "658645717510",
82+
"us-west-1": "658645717510",
83+
"us-west-2": "658645717510"
84+
},
85+
"repository": "smdistributed-modelparallel"
6086
}
6187
}
6288
}

src/sagemaker/image_uris.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,11 @@ def get_training_image_uri(
678678
if "modelparallel" in distribution["smdistributed"]:
679679
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
680680
framework = "pytorch-smp"
681-
if "p5" in instance_type or "2.1" in framework_version:
681+
if (
682+
"p5" in instance_type
683+
or "2.1" in framework_version
684+
or "2.2" in framework_version
685+
):
682686
container_version = "cu121"
683687
else:
684688
container_version = "cu118"

tests/unit/sagemaker/image_uris/test_smp_v2.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def test_smp_v2(load_config):
3535
for region in ACCOUNTS.keys():
3636
for instance_type in CONTAINER_VERSIONS.keys():
3737
cuda_vers = CONTAINER_VERSIONS[instance_type]
38-
if "2.1" in version:
38+
if "2.1" in version or "2.2" in version:
3939
cuda_vers = "cu121"
4040

4141
uri = image_uris.get_training_image_uri(

0 commit comments

Comments
 (0)