diff --git a/tests/unit/sagemaker/image_uris/test_graviton.py b/tests/unit/sagemaker/image_uris/test_graviton.py index ea4ef29919..a122be9291 100644 --- a/tests/unit/sagemaker/image_uris/test_graviton.py +++ b/tests/unit/sagemaker/image_uris/test_graviton.py @@ -30,11 +30,18 @@ ] -def _test_graviton_framework_uris(framework, version, py_version, account, region): +def _test_graviton_framework_uris( + framework, version, py_version, account, region, container_version="ubuntu20.04-sagemaker" +): for instance_type in GRAVITON_INSTANCE_TYPES: uri = image_uris.retrieve(framework, region, instance_type=instance_type, version=version) expected = _expected_graviton_framework_uri( - framework, version, py_version, account, region=region + framework, + version, + py_version, + account, + region=region, + container_version=container_version, ) assert expected == uri @@ -50,11 +57,21 @@ def test_graviton_framework_uris(load_config_and_file_name, scope): for version in VERSIONS: ACCOUNTS = config[scope]["versions"][version]["registries"] py_versions = config[scope]["versions"][version]["py_versions"] + container_version = ( + config[scope]["versions"][version].get("container_version", {}).get("cpu", None) + ) + if container_version: + container_version = container_version + "-sagemaker" for py_version in py_versions: for region in ACCOUNTS.keys(): - _test_graviton_framework_uris( - framework, version, py_version, ACCOUNTS[region], region - ) + if container_version: + _test_graviton_framework_uris( + framework, version, py_version, ACCOUNTS[region], region, container_version + ) + else: + _test_graviton_framework_uris( + framework, version, py_version, ACCOUNTS[region], region + ) def _test_graviton_unsupported_framework(framework, region, framework_version): @@ -183,11 +200,14 @@ def test_graviton_sklearn_image_scope_specified_x86_instance(graviton_sklearn_un assert "Unsupported instance type: m5." in str(error) -def _expected_graviton_framework_uri(framework, version, py_version, account, region): +def _expected_graviton_framework_uri( + framework, version, py_version, account, region, container_version +): return expected_uris.graviton_framework_uri( "{}-inference-graviton".format(framework), fw_version=version, py_version=py_version, account=account, region=region, + container_version=container_version, )