diff --git a/src/sagemaker/image_uri_config/djl-deepspeed.json b/src/sagemaker/image_uri_config/djl-deepspeed.json new file mode 100644 index 0000000000..504f7960ef --- /dev/null +++ b/src/sagemaker/image_uri_config/djl-deepspeed.json @@ -0,0 +1,65 @@ +{ + "scope": ["inference"], + "versions": { + "0.20.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "djl-inference", + "tag_prefix": "0.20.0-deepspeed0.7.5-cu116" + }, + "0.19.0": { + "registries": { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884" + }, + "repository": "djl-inference", + "tag_prefix": "0.19.0-deepspeed0.7.3-cu113" + } + } +} diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index a38ca2c8bf..2729d7db51 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -72,3 +72,9 @@ def graviton_framework_uri( tag = "-".join(x for x in (fw_version, processor, py_version, container_version) if x) return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + +def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + tag = f"{djl_version}-{primary_framework}" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) diff --git a/tests/unit/sagemaker/image_uris/test_djl.py b/tests/unit/sagemaker/image_uris/test_djl.py new file mode 100644 index 0000000000..27a680e752 --- /dev/null +++ b/tests/unit/sagemaker/image_uris/test_djl.py @@ -0,0 +1,64 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +ACCOUNTS = { + "af-south-1": "626614931356", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", +} +VERSIONS = ["0.20.0", "0.19.0"] +DJL_FRAMEWORKS = ["djl-deepspeed"] +DJL_VERSIONS_TO_FRAMEWORK = { + "0.19.0": {"djl-deepspeed": "deepspeed0.7.3-cu113"}, + "0.20.0": {"djl-deepspeed": "deepspeed0.7.5-cu116"}, +} + + +@pytest.mark.parametrize("region", ACCOUNTS.keys()) +@pytest.mark.parametrize("version", VERSIONS) +@pytest.mark.parametrize("djl_framework", DJL_FRAMEWORKS) +def test_djl_uris(region, version, djl_framework): + uri = image_uris.retrieve(framework=djl_framework, region=region, version=version) + expected = expected_uris.djl_framework_uri( + "djl-inference", + ACCOUNTS[region], + version, + DJL_VERSIONS_TO_FRAMEWORK[version][djl_framework], + region, + ) + assert expected == uri