File tree 2 files changed +17
-0
lines changed
tests/unit/sagemaker/image_uris
2 files changed +17
-0
lines changed Original file line number Diff line number Diff line change @@ -98,6 +98,8 @@ def retrieve(
98
98
"mxnet-1.8.0-gpu-py37" : "cu110-ubuntu16.04" ,
99
99
"pytorch-1.6-gpu-py36" : "cu110-ubuntu18.04-v3" ,
100
100
"pytorch-1.6.0-gpu-py36" : "cu110-ubuntu18.04" ,
101
+ "pytorch-1.6-gpu-py3" : "cu110-ubuntu18.04-v3" ,
102
+ "pytorch-1.6.0-gpu-py3" : "cu110-ubuntu18.04" ,
101
103
}
102
104
key = "-" .join ([framework , tag ])
103
105
if key in container_versions :
Original file line number Diff line number Diff line change @@ -553,6 +553,21 @@ def test_retrieve_auto_selected_container_version():
553
553
)
554
554
555
555
556
+ def test_retrieve_pytorch_container_version ():
557
+ uri = image_uris .retrieve (
558
+ framework = "pytorch" ,
559
+ region = "us-west-2" ,
560
+ version = "1.6" ,
561
+ py_version = "py3" ,
562
+ instance_type = "ml.p4d.24xlarge" ,
563
+ image_scope = "training" ,
564
+ )
565
+ assert (
566
+ "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:1.6-gpu-py3-cu110-ubuntu18.04-v3"
567
+ == uri
568
+ )
569
+
570
+
556
571
@patch ("sagemaker.image_uris.config_for_framework" , return_value = BASE_CONFIG )
557
572
def test_retrieve_unsupported_processor_type (config_for_framework ):
558
573
with pytest .raises (ValueError ) as e :
You can’t perform that action at this time.
0 commit comments