Skip to content

Commit 42c58eb

Browse files
Teng-xu authored and akrishna1995 committed
Add cuda version in uri
1 parent e8edeaa commit 42c58eb

File tree

3 files changed

+24
-20
lines changed

3 files changed

+24
-20
lines changed

src/sagemaker/image_uri_config/pytorch-smp.json

-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
{
22
"training": {
33
"processors": [
4-
"cpu",
54
"gpu"
65
],
76
"version_aliases": {

src/sagemaker/image_uris.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -666,7 +666,11 @@ def get_training_image_uri(
666666
if "modelparallel" in distribution["smdistributed"]:
667667
if distribution["smdistributed"]["modelparallel"].get("enabled", True):
668668
framework = "pytorch-smp"
669-
669+
if "p5" in instance_type:
670+
container_version = "cu12"
671+
else:
672+
container_version = "cu118"
673+
670674
return retrieve(
671675
framework,
672676
region,

tests/unit/sagemaker/image_uris/test_smp_v2.py

+19-18
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from sagemaker import image_uris
1717
from tests.unit.sagemaker.image_uris import expected_uris
1818

19-
COMMON_INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p4d.24xlarge"}
19+
CONTAINER_VERSIONS = {"ml.p4d.24xlarge": "cu118", "ml.p5d.24xlarge": "cu12"}
2020

2121

2222
@pytest.mark.parametrize("load_config", ["pytorch-smp.json"], indirect=True)
@@ -40,20 +40,21 @@ def test_smp_v2(load_config):
4040
PY_VERSIONS = load_config["training"]["versions"][version]["py_versions"]
4141
for py_version in PY_VERSIONS:
4242
for region in ACCOUNTS.keys():
43-
uri = image_uris.get_training_image_uri(
44-
region,
45-
framework="pytorch",
46-
framework_version=version,
47-
py_version=py_version,
48-
distribution=distribution,
49-
instance_type=COMMON_INSTANCE_TYPES[processor]
50-
)
51-
expected = expected_uris.framework_uri(
52-
repo="smdistributed-modelparallel",
53-
fw_version=version,
54-
py_version=py_version,
55-
processor=processor,
56-
region=region,
57-
account=ACCOUNTS[region],
58-
)
59-
assert expected == uri
43+
for instance_type in CONTAINER_VERSIONS.keys():
44+
uri = image_uris.get_training_image_uri(
45+
region,
46+
framework="pytorch",
47+
framework_version=version,
48+
py_version=py_version,
49+
distribution=distribution,
50+
instance_type=instance_type
51+
)
52+
expected = expected_uris.framework_uri(
53+
repo="smdistributed-modelparallel",
54+
fw_version=version,
55+
py_version=f"{py_version}-{CONTAINER_VERSIONS[instance_type]}",
56+
processor=processor,
57+
region=region,
58+
account=ACCOUNTS[region],
59+
)
60+
assert expected == uri

0 commit comments

Comments
 (0)