Skip to content

Commit 56352f3

Browse files
authored
fix: add the mapping from py3 to cuda11 images (#2154)
1 parent e7ad54c commit 56352f3

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/sagemaker/image_uris.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def retrieve(
9898
"mxnet-1.8.0-gpu-py37": "cu110-ubuntu16.04",
9999
"pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3",
100100
"pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04",
101+
"pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3",
102+
"pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04",
101103
}
102104
key = "-".join([framework, tag])
103105
if key in container_versions:

tests/unit/sagemaker/image_uris/test_retrieve.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,6 +553,21 @@ def test_retrieve_auto_selected_container_version():
553553
)
554554

555555

556+
def test_retrieve_pytorch_container_version():
557+
uri = image_uris.retrieve(
558+
framework="pytorch",
559+
region="us-west-2",
560+
version="1.6",
561+
py_version="py3",
562+
instance_type="ml.p4d.24xlarge",
563+
image_scope="training",
564+
)
565+
assert (
566+
"763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.6-gpu-py3-cu110-ubuntu18.04-v3"
567+
== uri
568+
)
569+
570+
556571
@patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG)
557572
def test_retrieve_unsupported_processor_type(config_for_framework):
558573
with pytest.raises(ValueError) as e:

0 commit comments

Comments
 (0)