30
30
]
31
31
32
32
33
- def _test_graviton_framework_uris (framework , version , py_version , account , region ):
33
+ def _test_graviton_framework_uris (
34
+ framework , version , py_version , account , region , container_version = "ubuntu20.04-sagemaker"
35
+ ):
34
36
for instance_type in GRAVITON_INSTANCE_TYPES :
35
37
uri = image_uris .retrieve (framework , region , instance_type = instance_type , version = version )
36
38
expected = _expected_graviton_framework_uri (
37
- framework , version , py_version , account , region = region
39
+ framework ,
40
+ version ,
41
+ py_version ,
42
+ account ,
43
+ region = region ,
44
+ container_version = container_version ,
38
45
)
39
46
assert expected == uri
40
47
@@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope):
50
57
for version in VERSIONS :
51
58
ACCOUNTS = config [scope ]["versions" ][version ]["registries" ]
52
59
py_versions = config [scope ]["versions" ][version ]["py_versions" ]
60
+ container_version = (
61
+ config [scope ]["versions" ][version ].get ("container_version" , {}).get ("cpu" , None )
62
+ )
63
+ if container_version :
64
+ container_version = container_version + "-sagemaker"
53
65
for py_version in py_versions :
54
66
for region in ACCOUNTS .keys ():
55
- _test_graviton_framework_uris (
56
- framework , version , py_version , ACCOUNTS [region ], region
57
- )
67
+ if container_version :
68
+ _test_graviton_framework_uris (
69
+ framework , version , py_version , ACCOUNTS [region ], region , container_version
70
+ )
71
+ else :
72
+ _test_graviton_framework_uris (
73
+ framework , version , py_version , ACCOUNTS [region ], region
74
+ )
58
75
59
76
60
77
def _test_graviton_unsupported_framework (framework , region , framework_version ):
@@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un
183
200
assert "Unsupported instance type: m5." in str (error )
184
201
185
202
186
- def _expected_graviton_framework_uri (framework , version , py_version , account , region ):
203
+ def _expected_graviton_framework_uri (
204
+ framework , version , py_version , account , region , container_version
205
+ ):
187
206
return expected_uris .graviton_framework_uri (
188
207
"{}-inference-graviton" .format (framework ),
189
208
fw_version = version ,
190
209
py_version = py_version ,
191
210
account = account ,
192
211
region = region ,
212
+ container_version = container_version ,
193
213
)
0 commit comments