Skip to content

Commit ab23cda

Browse files
committed
feature: Adding image_uri config for DJL containers
1 parent 885423c commit ab23cda

File tree

3 files changed

+102
-0
lines changed

3 files changed

+102
-0
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"scope": ["inference"],
3+
"versions": {
4+
"0.19.0": {
5+
"registries": {
6+
"af-south-1": "626614931356",
7+
"ap-east-1": "871362719292",
8+
"ap-northeast-1": "763104351884",
9+
"ap-northeast-2": "763104351884",
10+
"ap-northeast-3": "364406365360",
11+
"ap-south-1": "763104351884",
12+
"ap-southeast-1": "763104351884",
13+
"ap-southeast-2": "763104351884",
14+
"ap-southeast-3": "907027046896",
15+
"ca-central-1": "763104351884",
16+
"cn-north-1": "727897471807",
17+
"cn-northwest-1": "727897471807",
18+
"eu-central-1": "763104351884",
19+
"eu-north-1": "763104351884",
20+
"eu-west-1": "763104351884",
21+
"eu-west-2": "763104351884",
22+
"eu-west-3": "763104351884",
23+
"eu-south-1": "692866216735",
24+
"me-south-1": "217643126080",
25+
"sa-east-1": "763104351884",
26+
"us-east-1": "763104351884",
27+
"us-east-2": "763104351884",
28+
"us-west-1": "763104351884",
29+
"us-west-2": "763104351884"
30+
},
31+
"repository": "djl-inference",
32+
"tag_prefix": "0.19.0-deepspeed0.7.3-cu113"
33+
}
34+
}
35+
}

tests/unit/sagemaker/image_uris/expected_uris.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,9 @@ def graviton_framework_uri(
7171
tag = "-".join(x for x in (fw_version, processor, py_version, container_version) if x)
7272

7373
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
74+
75+
76+
def djl_framework_uri(repo, account, djl_version, primary_framework, region=REGION):
77+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
78+
tag = f"{djl_version}-{primary_framework}"
79+
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris
17+
18+
ACCOUNTS = {
19+
"af-south-1": "626614931356",
20+
"ap-east-1": "871362719292",
21+
"ap-northeast-1": "763104351884",
22+
"ap-northeast-2": "763104351884",
23+
"ap-northeast-3": "364406365360",
24+
"ap-south-1": "763104351884",
25+
"ap-southeast-1": "763104351884",
26+
"ap-southeast-2": "763104351884",
27+
"ap-southeast-3": "907027046896",
28+
"ca-central-1": "763104351884",
29+
"cn-north-1": "727897471807",
30+
"cn-northwest-1": "727897471807",
31+
"eu-central-1": "763104351884",
32+
"eu-north-1": "763104351884",
33+
"eu-west-1": "763104351884",
34+
"eu-west-2": "763104351884",
35+
"eu-west-3": "763104351884",
36+
"eu-south-1": "692866216735",
37+
"me-south-1": "217643126080",
38+
"sa-east-1": "763104351884",
39+
"us-east-1": "763104351884",
40+
"us-east-2": "763104351884",
41+
"us-west-1": "763104351884",
42+
"us-west-2": "763104351884",
43+
}
44+
VERSIONS = ["0.19.0"]
45+
DJL_FRAMEWORKS = ["djl-deepspeed"]
46+
DJL_VERSIONS_TO_FRAMEWORK = {"0.19.0": {"djl-deepspeed": "deepspeed0.7.3-cu113"}}
47+
48+
49+
@pytest.mark.parametrize("region", ACCOUNTS.keys())
50+
@pytest.mark.parametrize("version", VERSIONS)
51+
@pytest.mark.parametrize("djl_framework", DJL_FRAMEWORKS)
52+
def test_djl_uris(region, version, djl_framework):
53+
uri = image_uris.retrieve(framework=djl_framework, region=region, version=version)
54+
expected = expected_uris.djl_framework_uri(
55+
"djl-inference",
56+
ACCOUNTS[region],
57+
version,
58+
DJL_VERSIONS_TO_FRAMEWORK[version][djl_framework],
59+
region,
60+
)
61+
assert expected == uri

0 commit comments

Comments
 (0)