16
16
from sagemaker import image_uris
17
17
from tests .unit .sagemaker .image_uris import expected_uris
18
18
19
- COMMON_INSTANCE_TYPES = {"cpu" : " ml.c4.xlarge" , "gpu" : "ml.p4d .24xlarge" }
19
+ CONTAINER_VERSIONS = {"ml.p4d.24xlarge" : "cu118" , "ml.p5d .24xlarge" : "cu12 " }
20
20
21
21
22
22
@pytest .mark .parametrize ("load_config" , ["pytorch-smp.json" ], indirect = True )
@@ -40,20 +40,21 @@ def test_smp_v2(load_config):
40
40
PY_VERSIONS = load_config ["training" ]["versions" ][version ]["py_versions" ]
41
41
for py_version in PY_VERSIONS :
42
42
for region in ACCOUNTS .keys ():
43
- uri = image_uris .get_training_image_uri (
44
- region ,
45
- framework = "pytorch" ,
46
- framework_version = version ,
47
- py_version = py_version ,
48
- distribution = distribution ,
49
- instance_type = COMMON_INSTANCE_TYPES [processor ]
50
- )
51
- expected = expected_uris .framework_uri (
52
- repo = "smdistributed-modelparallel" ,
53
- fw_version = version ,
54
- py_version = py_version ,
55
- processor = processor ,
56
- region = region ,
57
- account = ACCOUNTS [region ],
58
- )
59
- assert expected == uri
43
+ for instance_type in CONTAINER_VERSIONS .keys ():
44
+ uri = image_uris .get_training_image_uri (
45
+ region ,
46
+ framework = "pytorch" ,
47
+ framework_version = version ,
48
+ py_version = py_version ,
49
+ distribution = distribution ,
50
+ instance_type = instance_type
51
+ )
52
+ expected = expected_uris .framework_uri (
53
+ repo = "smdistributed-modelparallel" ,
54
+ fw_version = version ,
55
+ py_version = f"{ py_version } -{ CONTAINER_VERSIONS [instance_type ]} " ,
56
+ processor = processor ,
57
+ region = region ,
58
+ account = ACCOUNTS [region ],
59
+ )
60
+ assert expected == uri
0 commit comments