diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index a1e91ac154..28e699bc95 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -98,6 +98,8 @@ def retrieve( "mxnet-1.8.0-gpu-py37": "cu110-ubuntu16.04", "pytorch-1.6-gpu-py36": "cu110-ubuntu18.04-v3", "pytorch-1.6.0-gpu-py36": "cu110-ubuntu18.04", + "pytorch-1.6-gpu-py3": "cu110-ubuntu18.04-v3", + "pytorch-1.6.0-gpu-py3": "cu110-ubuntu18.04", } key = "-".join([framework, tag]) if key in container_versions: diff --git a/tests/unit/sagemaker/image_uris/test_retrieve.py b/tests/unit/sagemaker/image_uris/test_retrieve.py index d85f6651df..5d9ecb4fcb 100644 --- a/tests/unit/sagemaker/image_uris/test_retrieve.py +++ b/tests/unit/sagemaker/image_uris/test_retrieve.py @@ -553,6 +553,21 @@ def test_retrieve_auto_selected_container_version(): ) +def test_retrieve_pytorch_container_version(): + uri = image_uris.retrieve( + framework="pytorch", + region="us-west-2", + version="1.6", + py_version="py3", + instance_type="ml.p4d.24xlarge", + image_scope="training", + ) + assert ( + "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.6-gpu-py3-cu110-ubuntu18.04-v3" + == uri + ) + + @patch("sagemaker.image_uris.config_for_framework", return_value=BASE_CONFIG) def test_retrieve_unsupported_processor_type(config_for_framework): with pytest.raises(ValueError) as e: