Skip to content

Commit 963a8cd

Browse files
beniericryansteakley
authored andcommitted
Fix: image_uris graviton image uri (aws#4909)
1 parent 208b164 commit 963a8cd

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

tests/unit/sagemaker/image_uris/test_graviton.py

+26-6
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,18 @@
3030
]
3131

3232

33-
def _test_graviton_framework_uris(framework, version, py_version, account, region):
33+
def _test_graviton_framework_uris(
34+
framework, version, py_version, account, region, container_version="ubuntu20.04-sagemaker"
35+
):
3436
for instance_type in GRAVITON_INSTANCE_TYPES:
3537
uri = image_uris.retrieve(framework, region, instance_type=instance_type, version=version)
3638
expected = _expected_graviton_framework_uri(
37-
framework, version, py_version, account, region=region
39+
framework,
40+
version,
41+
py_version,
42+
account,
43+
region=region,
44+
container_version=container_version,
3845
)
3946
assert expected == uri
4047

@@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope):
5057
for version in VERSIONS:
5158
ACCOUNTS = config[scope]["versions"][version]["registries"]
5259
py_versions = config[scope]["versions"][version]["py_versions"]
60+
container_version = (
61+
config[scope]["versions"][version].get("container_version", {}).get("cpu", None)
62+
)
63+
if container_version:
64+
container_version = container_version + "-sagemaker"
5365
for py_version in py_versions:
5466
for region in ACCOUNTS.keys():
55-
_test_graviton_framework_uris(
56-
framework, version, py_version, ACCOUNTS[region], region
57-
)
67+
if container_version:
68+
_test_graviton_framework_uris(
69+
framework, version, py_version, ACCOUNTS[region], region, container_version
70+
)
71+
else:
72+
_test_graviton_framework_uris(
73+
framework, version, py_version, ACCOUNTS[region], region
74+
)
5875

5976

6077
def _test_graviton_unsupported_framework(framework, region, framework_version):
@@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un
183200
assert "Unsupported instance type: m5." in str(error)
184201

185202

186-
def _expected_graviton_framework_uri(framework, version, py_version, account, region):
203+
def _expected_graviton_framework_uri(
204+
framework, version, py_version, account, region, container_version
205+
):
187206
return expected_uris.graviton_framework_uri(
188207
"{}-inference-graviton".format(framework),
189208
fw_version=version,
190209
py_version=py_version,
191210
account=account,
192211
region=region,
212+
container_version=container_version,
193213
)

0 commit comments

Comments
 (0)