Skip to content

Commit 1af5dcf

Browse files
authored
feature: Add PT 1.11 support (#3097)
1 parent 5b8eb10 commit 1af5dcf

File tree

3 files changed

+141
-0
lines changed

3 files changed

+141
-0
lines changed

src/sagemaker/fw_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,9 @@
9393
"1.9.1",
9494
"1.10",
9595
"1.10.0",
96+
"1.10.2",
97+
"1.11",
98+
"1.11.0",
9699
],
97100
}
98101
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

src/sagemaker/image_uri_config/pytorch.json

Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -534,6 +534,72 @@
534534
"us-west-2": "763104351884"
535535
},
536536
"repository": "pytorch-inference"
537+
},
538+
"1.10.2": {
539+
"py_versions": [
540+
"py38"
541+
],
542+
"registries": {
543+
"af-south-1": "626614931356",
544+
"ap-east-1": "871362719292",
545+
"ap-northeast-1": "763104351884",
546+
"ap-northeast-2": "763104351884",
547+
"ap-northeast-3": "364406365360",
548+
"ap-south-1": "763104351884",
549+
"ap-southeast-1": "763104351884",
550+
"ap-southeast-2": "763104351884",
551+
"ca-central-1": "763104351884",
552+
"cn-north-1": "727897471807",
553+
"cn-northwest-1": "727897471807",
554+
"eu-central-1": "763104351884",
555+
"eu-north-1": "763104351884",
556+
"eu-west-1": "763104351884",
557+
"eu-west-2": "763104351884",
558+
"eu-west-3": "763104351884",
559+
"eu-south-1": "692866216735",
560+
"me-south-1": "217643126080",
561+
"sa-east-1": "763104351884",
562+
"us-east-1": "763104351884",
563+
"us-east-2": "763104351884",
564+
"us-gov-west-1": "442386744353",
565+
"us-iso-east-1": "886529160074",
566+
"us-west-1": "763104351884",
567+
"us-west-2": "763104351884"
568+
},
569+
"repository": "pytorch-inference"
570+
},
571+
"1.11.0": {
572+
"py_versions": [
573+
"py38"
574+
],
575+
"registries": {
576+
"af-south-1": "626614931356",
577+
"ap-east-1": "871362719292",
578+
"ap-northeast-1": "763104351884",
579+
"ap-northeast-2": "763104351884",
580+
"ap-northeast-3": "364406365360",
581+
"ap-south-1": "763104351884",
582+
"ap-southeast-1": "763104351884",
583+
"ap-southeast-2": "763104351884",
584+
"ca-central-1": "763104351884",
585+
"cn-north-1": "727897471807",
586+
"cn-northwest-1": "727897471807",
587+
"eu-central-1": "763104351884",
588+
"eu-north-1": "763104351884",
589+
"eu-west-1": "763104351884",
590+
"eu-west-2": "763104351884",
591+
"eu-west-3": "763104351884",
592+
"eu-south-1": "692866216735",
593+
"me-south-1": "217643126080",
594+
"sa-east-1": "763104351884",
595+
"us-east-1": "763104351884",
596+
"us-east-2": "763104351884",
597+
"us-gov-west-1": "442386744353",
598+
"us-iso-east-1": "886529160074",
599+
"us-west-1": "763104351884",
600+
"us-west-2": "763104351884"
601+
},
602+
"repository": "pytorch-inference"
537603
}
538604
}
539605
},
@@ -1025,6 +1091,72 @@
10251091
"us-west-2": "763104351884"
10261092
},
10271093
"repository": "pytorch-training"
1094+
},
1095+
"1.10.2": {
1096+
"py_versions": [
1097+
"py38"
1098+
],
1099+
"registries": {
1100+
"af-south-1": "626614931356",
1101+
"ap-east-1": "871362719292",
1102+
"ap-northeast-1": "763104351884",
1103+
"ap-northeast-2": "763104351884",
1104+
"ap-northeast-3": "364406365360",
1105+
"ap-south-1": "763104351884",
1106+
"ap-southeast-1": "763104351884",
1107+
"ap-southeast-2": "763104351884",
1108+
"ca-central-1": "763104351884",
1109+
"cn-north-1": "727897471807",
1110+
"cn-northwest-1": "727897471807",
1111+
"eu-central-1": "763104351884",
1112+
"eu-north-1": "763104351884",
1113+
"eu-west-1": "763104351884",
1114+
"eu-west-2": "763104351884",
1115+
"eu-west-3": "763104351884",
1116+
"eu-south-1": "692866216735",
1117+
"me-south-1": "217643126080",
1118+
"sa-east-1": "763104351884",
1119+
"us-east-1": "763104351884",
1120+
"us-east-2": "763104351884",
1121+
"us-gov-west-1": "442386744353",
1122+
"us-iso-east-1": "886529160074",
1123+
"us-west-1": "763104351884",
1124+
"us-west-2": "763104351884"
1125+
},
1126+
"repository": "pytorch-training"
1127+
},
1128+
"1.11.0": {
1129+
"py_versions": [
1130+
"py38"
1131+
],
1132+
"registries": {
1133+
"af-south-1": "626614931356",
1134+
"ap-east-1": "871362719292",
1135+
"ap-northeast-1": "763104351884",
1136+
"ap-northeast-2": "763104351884",
1137+
"ap-northeast-3": "364406365360",
1138+
"ap-south-1": "763104351884",
1139+
"ap-southeast-1": "763104351884",
1140+
"ap-southeast-2": "763104351884",
1141+
"ca-central-1": "763104351884",
1142+
"cn-north-1": "727897471807",
1143+
"cn-northwest-1": "727897471807",
1144+
"eu-central-1": "763104351884",
1145+
"eu-north-1": "763104351884",
1146+
"eu-west-1": "763104351884",
1147+
"eu-west-2": "763104351884",
1148+
"eu-west-3": "763104351884",
1149+
"eu-south-1": "692866216735",
1150+
"me-south-1": "217643126080",
1151+
"sa-east-1": "763104351884",
1152+
"us-east-1": "763104351884",
1153+
"us-east-2": "763104351884",
1154+
"us-gov-west-1": "442386744353",
1155+
"us-iso-east-1": "886529160074",
1156+
"us-west-1": "763104351884",
1157+
"us-west-2": "763104351884"
1158+
},
1159+
"repository": "pytorch-training"
10281160
}
10291161
}
10301162
}

tests/unit/test_fw_utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,11 @@ def test_validate_smdataparallel_args_not_raises():
700700
("ml.p3.16xlarge", "pytorch", "1.8", "py3", smdataparallel_enabled),
701701
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled),
702702
("ml.p3.16xlarge", "pytorch", "1.9", "py38", smdataparallel_enabled),
703+
("ml.p3.16xlarge", "pytorch", "1.10.0", "py38", smdataparallel_enabled),
704+
("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled),
703705
("ml.p3.16xlarge", "pytorch", "1.10", "py38", smdataparallel_enabled),
706+
("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled),
707+
("ml.p3.16xlarge", "pytorch", "1.11", "py38", smdataparallel_enabled),
704708
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
705709
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
706710
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -713,6 +717,8 @@ def test_validate_smdataparallel_args_not_raises():
713717
("ml.p3.16xlarge", "tensorflow", "2.8.0", "py39", smdataparallel_enabled_custom_mpi),
714718
("ml.p3.16xlarge", "pytorch", "1.8.0", "py3", smdataparallel_enabled_custom_mpi),
715719
("ml.p3.16xlarge", "pytorch", "1.9.1", "py38", smdataparallel_enabled_custom_mpi),
720+
("ml.p3.16xlarge", "pytorch", "1.10.2", "py38", smdataparallel_enabled_custom_mpi),
721+
("ml.p3.16xlarge", "pytorch", "1.11.0", "py38", smdataparallel_enabled_custom_mpi),
716722
]
717723
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
718724
fw_utils._validate_smdataparallel_args(

0 commit comments

Comments
 (0)