Skip to content

Commit c217e1c

Browse files
Mike Schneider and Shibo Xing authored
feature: Add PyTorch 2.0 to SDK (#3795)
* Add PT 2.0 inference to image_uri_config
* Added PT 2.0 SM updates
* feat: add 2.0 for torch_distributed in training
* fix: remove china and govclouds
* update py version for PT 2.0 tests
* add missing regions for inference.
* was still missing the needed regions..

---------

Co-authored-by: Shibo Xing <[email protected]>
1 parent d986e3e commit c217e1c

File tree

5 files changed: +130 -8 lines changed

src/sagemaker/fw_utils.py

+4 -1
Original file line number | Diff line number | Diff line change
@@ -135,6 +135,7 @@
135135
"1.12.0",
136136
"1.12.1",
137137
"1.13.1",
138+
"2.0.0",
138139
],
139140
}
140141

@@ -148,10 +149,11 @@
148149
"1.12.0",
149150
"1.12.1",
150151
"1.13.1",
152+
"2.0.0",
151153
]
152154

153155

154-
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1"]
156+
TORCH_DISTRIBUTED_GPU_SUPPORTED_FRAMEWORK_VERSIONS = ["1.13.1", "2.0.0"]
155157

156158
TRAINIUM_SUPPORTED_DISTRIBUTION_STRATEGIES = ["torch_distributed"]
157159
TRAINIUM_SUPPORTED_TORCH_DISTRIBUTED_FRAMEWORK_VERSIONS = [
@@ -161,6 +163,7 @@
161163
"1.12.0",
162164
"1.12.1",
163165
"1.13.1",
166+
"2.0.0",
164167
]
165168

166169
SMDISTRIBUTED_SUPPORTED_STRATEGIES = ["dataparallel", "modelparallel"]

src/sagemaker/image_uri_config/pytorch.json

+115 -3
Original file line number | Diff line number | Diff line change
@@ -77,7 +77,8 @@
7777
"1.10": "1.10.2",
7878
"1.11": "1.11.0",
7979
"1.12": "1.12.1",
80-
"1.13": "1.13.1"
80+
"1.13": "1.13.1",
81+
"2.0": "2.0.0"
8182
},
8283
"versions": {
8384
"0.4.0": {
@@ -838,6 +839,43 @@
838839
"us-west-2": "763104351884"
839840
},
840841
"repository": "pytorch-inference"
842+
},
843+
"2.0.0": {
844+
"py_versions": [
845+
"py310"
846+
],
847+
"registries": {
848+
"af-south-1": "626614931356",
849+
"ap-east-1": "871362719292",
850+
"ap-northeast-1": "763104351884",
851+
"ap-northeast-2": "763104351884",
852+
"ap-northeast-3": "364406365360",
853+
"ap-south-1": "763104351884",
854+
"ap-southeast-1": "763104351884",
855+
"ap-southeast-2": "763104351884",
856+
"ap-southeast-3": "907027046896",
857+
"ap-southeast-4": "457447274322",
858+
"ca-central-1": "763104351884",
859+
"cn-north-1": "727897471807",
860+
"cn-northwest-1": "727897471807",
861+
"eu-central-1": "763104351884",
862+
"eu-north-1": "763104351884",
863+
"eu-west-1": "763104351884",
864+
"eu-west-2": "763104351884",
865+
"eu-west-3": "763104351884",
866+
"eu-south-1": "692866216735",
867+
"me-south-1": "217643126080",
868+
"sa-east-1": "763104351884",
869+
"us-east-1": "763104351884",
870+
"us-east-2": "763104351884",
871+
"us-gov-east-1": "446045086412",
872+
"us-gov-west-1": "442386744353",
873+
"us-iso-east-1": "886529160074",
874+
"us-isob-east-1": "094389454867",
875+
"us-west-1": "763104351884",
876+
"us-west-2": "763104351884"
877+
},
878+
"repository": "pytorch-inference"
841879
}
842880
}
843881
},
@@ -846,7 +884,8 @@
846884
"cpu"
847885
],
848886
"version_aliases": {
849-
"1.12": "1.12.1"
887+
"1.12": "1.12.1",
888+
"2.0": "2.0.0"
850889
},
851890
"versions": {
852891
"1.12.1": {
@@ -889,6 +928,41 @@
889928
},
890929
"repository": "pytorch-inference-graviton",
891930
"container_version": {"cpu": "ubuntu20.04"}
931+
},
932+
"2.0.0": {
933+
"py_versions": [
934+
"py310"
935+
],
936+
"registries": {
937+
"af-south-1": "626614931356",
938+
"ap-east-1": "871362719292",
939+
"ap-northeast-1": "763104351884",
940+
"ap-northeast-2": "763104351884",
941+
"ap-northeast-3": "364406365360",
942+
"ap-south-1": "763104351884",
943+
"ap-south-2": "772153158452",
944+
"ap-southeast-1": "763104351884",
945+
"ap-southeast-2": "763104351884",
946+
"ap-southeast-3": "907027046896",
947+
"ap-southeast-4": "457447274322",
948+
"ca-central-1": "763104351884",
949+
"eu-central-1": "763104351884",
950+
"eu-central-2": "380420809688",
951+
"eu-north-1": "763104351884",
952+
"eu-west-1": "763104351884",
953+
"eu-west-2": "763104351884",
954+
"eu-west-3": "763104351884",
955+
"eu-south-1": "692866216735",
956+
"eu-south-2": "503227376785",
957+
"me-south-1": "217643126080",
958+
"sa-east-1": "763104351884",
959+
"us-east-1": "763104351884",
960+
"us-east-2": "763104351884",
961+
"us-west-1": "763104351884",
962+
"us-west-2": "763104351884"
963+
},
964+
"repository": "pytorch-inference-graviton",
965+
"container_version": {"cpu": "ubuntu20.04"}
892966
}
893967
}
894968
},
@@ -912,7 +986,8 @@
912986
"1.10": "1.10.2",
913987
"1.11": "1.11.0",
914988
"1.12": "1.12.1",
915-
"1.13": "1.13.1"
989+
"1.13": "1.13.1",
990+
"2.0": "2.0.0"
916991
},
917992
"versions": {
918993
"0.4.0": {
@@ -1674,6 +1749,43 @@
16741749
"us-west-2": "763104351884"
16751750
},
16761751
"repository": "pytorch-training"
1752+
},
1753+
"2.0.0": {
1754+
"py_versions": [
1755+
"py310"
1756+
],
1757+
"registries": {
1758+
"af-south-1": "626614931356",
1759+
"ap-east-1": "871362719292",
1760+
"ap-northeast-1": "763104351884",
1761+
"ap-northeast-2": "763104351884",
1762+
"ap-northeast-3": "364406365360",
1763+
"ap-south-1": "763104351884",
1764+
"ap-southeast-1": "763104351884",
1765+
"ap-southeast-2": "763104351884",
1766+
"ap-southeast-3": "907027046896",
1767+
"ap-southeast-4": "457447274322",
1768+
"ca-central-1": "763104351884",
1769+
"cn-north-1": "727897471807",
1770+
"cn-northwest-1": "727897471807",
1771+
"eu-central-1": "763104351884",
1772+
"eu-north-1": "763104351884",
1773+
"eu-west-1": "763104351884",
1774+
"eu-west-2": "763104351884",
1775+
"eu-west-3": "763104351884",
1776+
"eu-south-1": "692866216735",
1777+
"me-south-1": "217643126080",
1778+
"sa-east-1": "763104351884",
1779+
"us-east-1": "763104351884",
1780+
"us-east-2": "763104351884",
1781+
"us-gov-east-1": "446045086412",
1782+
"us-gov-west-1": "442386744353",
1783+
"us-iso-east-1": "886529160074",
1784+
"us-isob-east-1": "094389454867",
1785+
"us-west-1": "763104351884",
1786+
"us-west-2": "763104351884"
1787+
},
1788+
"repository": "pytorch-training"
16771789
}
16781790
}
16791791
}

tests/conftest.py

+6 -2
Original file line number | Diff line number | Diff line change
@@ -243,7 +243,9 @@ def mxnet_eia_latest_py_version():
243243

244244
@pytest.fixture(scope="module", params=["py2", "py3"])
245245
def pytorch_training_py_version(pytorch_training_version, request):
246-
if Version(pytorch_training_version) >= Version("1.13"):
246+
if Version(pytorch_training_version) >= Version("2.0"):
247+
return "py310"
248+
elif Version(pytorch_training_version) >= Version("1.13"):
247249
return "py39"
248250
elif Version(pytorch_training_version) >= Version("1.9"):
249251
return "py38"
@@ -255,7 +257,9 @@ def pytorch_training_py_version(pytorch_training_version, request):
255257

256258
@pytest.fixture(scope="module", params=["py2", "py3"])
257259
def pytorch_inference_py_version(pytorch_inference_version, request):
258-
if Version(pytorch_inference_version) >= Version("1.13"):
260+
if Version(pytorch_inference_version) >= Version("2.0"):
261+
return "py310"
262+
elif Version(pytorch_inference_version) >= Version("1.13"):
259263
return "py39"
260264
elif Version(pytorch_inference_version) >= Version("1.9"):
261265
return "py38"

tests/unit/test_fw_utils.py

+3
Original file line number | Diff line number | Diff line change
@@ -913,6 +913,7 @@ def test_validate_smdataparallel_args_not_raises():
913913
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled),
914914
("ml.p3.16xlarge", "pytorch", "1.12", "py38", smdataparallel_enabled),
915915
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled),
916+
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled),
916917
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py3", smdataparallel_enabled_custom_mpi),
917918
("ml.p3.16xlarge", "tensorflow", "2.4.1", "py37", smdataparallel_enabled_custom_mpi),
918919
("ml.p3.16xlarge", "tensorflow", "2.4.3", "py3", smdataparallel_enabled_custom_mpi),
@@ -934,6 +935,7 @@ def test_validate_smdataparallel_args_not_raises():
934935
("ml.p3.16xlarge", "pytorch", "1.12.0", "py38", smdataparallel_enabled_custom_mpi),
935936
("ml.p3.16xlarge", "pytorch", "1.12.1", "py38", smdataparallel_enabled_custom_mpi),
936937
("ml.p3.16xlarge", "pytorch", "1.13.1", "py39", smdataparallel_enabled_custom_mpi),
938+
("ml.p3.16xlarge", "pytorch", "2.0.0", "py310", smdataparallel_enabled_custom_mpi),
937939
]
938940
for instance_type, framework_name, framework_version, py_version, distribution in good_args:
939941
fw_utils._validate_smdataparallel_args(
@@ -1034,6 +1036,7 @@ def test_validate_torch_distributed_not_raises():
10341036
torch_distributed_enabled = {"torch_distributed": {"enabled": True}}
10351037
torch_distributed_gpu_supported_fw_versions = [
10361038
"1.13.1",
1039+
"2.0.0",
10371040
]
10381041
for framework_version in torch_distributed_gpu_supported_fw_versions:
10391042
fw_utils.validate_torch_distributed_distribution(

tests/unit/test_utils.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -380,8 +380,8 @@ def test_set_nested_value():
380380

381381

382382
def test_get_short_version():
383-
assert sagemaker.utils.get_short_version("1.13.1") == "1.13"
384-
assert sagemaker.utils.get_short_version("1.13") == "1.13"
383+
assert sagemaker.utils.get_short_version("2.0.0") == "2.0"
384+
assert sagemaker.utils.get_short_version("2.0") == "2.0"
385385

386386

387387
def test_deferred_error():

0 commit comments

Comments
 (0)