Skip to content

Commit 5158749

Browse files
authored
feat: Add new Triton DLC URIs (#4432)
* Add new Triton DLC URIs * Update according to black and pylint
1 parent 5782eb5 commit 5158749

File tree

4 files changed

+143
-0
lines changed

4 files changed

+143
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
{
2+
"processors": [
3+
"cpu",
4+
"gpu"
5+
],
6+
"scope": [
7+
"inference"
8+
],
9+
"versions": {
10+
"23.12": {
11+
"registries": {
12+
"af-south-1": "626614931356",
13+
"il-central-1": "780543022126",
14+
"ap-east-1": "871362719292",
15+
"ap-northeast-1": "763104351884",
16+
"ap-northeast-2": "763104351884",
17+
"ap-northeast-3": "364406365360",
18+
"ap-south-1": "763104351884",
19+
"ap-southeast-1": "763104351884",
20+
"ap-southeast-2": "763104351884",
21+
"ap-southeast-3": "907027046896",
22+
"ca-central-1": "763104351884",
23+
"cn-north-1": "727897471807",
24+
"cn-northwest-1": "727897471807",
25+
"eu-central-1": "763104351884",
26+
"eu-north-1": "763104351884",
27+
"eu-west-1": "763104351884",
28+
"eu-west-2": "763104351884",
29+
"eu-west-3": "763104351884",
30+
"eu-south-1": "692866216735",
31+
"me-south-1": "217643126080",
32+
"sa-east-1": "763104351884",
33+
"us-east-1": "763104351884",
34+
"us-east-2": "763104351884",
35+
"us-west-1": "763104351884",
36+
"us-west-2": "763104351884",
37+
"ca-west-1": "204538143572"
38+
},
39+
"repository": "sagemaker-tritonserver",
40+
"tag_prefix": "23.12-py3"
41+
},
42+
"24.01": {
43+
"registries": {
44+
"af-south-1": "626614931356",
45+
"il-central-1": "780543022126",
46+
"ap-east-1": "871362719292",
47+
"ap-northeast-1": "763104351884",
48+
"ap-northeast-2": "763104351884",
49+
"ap-northeast-3": "364406365360",
50+
"ap-south-1": "763104351884",
51+
"ap-southeast-1": "763104351884",
52+
"ap-southeast-2": "763104351884",
53+
"ap-southeast-3": "907027046896",
54+
"ca-central-1": "763104351884",
55+
"cn-north-1": "727897471807",
56+
"cn-northwest-1": "727897471807",
57+
"eu-central-1": "763104351884",
58+
"eu-north-1": "763104351884",
59+
"eu-west-1": "763104351884",
60+
"eu-west-2": "763104351884",
61+
"eu-west-3": "763104351884",
62+
"eu-south-1": "692866216735",
63+
"me-south-1": "217643126080",
64+
"sa-east-1": "763104351884",
65+
"us-east-1": "763104351884",
66+
"us-east-2": "763104351884",
67+
"us-west-1": "763104351884",
68+
"us-west-2": "763104351884",
69+
"ca-west-1": "204538143572"
70+
},
71+
"repository": "sagemaker-tritonserver",
72+
"tag_prefix": "24.01-py3"
73+
}
74+
}
75+
}

src/sagemaker/image_uris.py

+6
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
INFERENCE_GRAVITON = "inference_graviton"
4545
DATA_WRANGLER_FRAMEWORK = "data-wrangler"
4646
STABILITYAI_FRAMEWORK = "stabilityai"
47+
SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver"
4748

4849

4950
@override_pipeline_parameter_var
@@ -335,6 +336,11 @@ def _get_image_tag(
335336
if key in container_versions:
336337
tag = "-".join([tag, container_versions[key]])
337338

339+
# Triton images don't have a trailing -gpu tag. Only -cpu images do.
340+
if framework == SAGEMAKER_TRITONSERVER_FRAMEWORK:
341+
if processor == "gpu":
342+
tag = tag.rstrip("-gpu")
343+
338344
return tag
339345

340346

tests/unit/sagemaker/image_uris/expected_uris.py

+7
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def djl_framework_uri(repo, account, tag, region=REGION):
8484
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
8585

8686

87+
def sagemaker_triton_framework_uri(repo, account, tag, processor="gpu", region=REGION):
88+
domain = ALTERNATE_DOMAINS.get(region, DOMAIN)
89+
if processor == "cpu":
90+
tag = f"{tag}-cpu"
91+
return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag)
92+
93+
8794
def huggingface_llm_framework_uri(
8895
repo,
8996
account,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
from __future__ import absolute_import
14+
import pytest
15+
from sagemaker import image_uris
16+
from tests.unit.sagemaker.image_uris import expected_uris
17+
18+
INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p2.xlarge"}
19+
20+
21+
@pytest.mark.parametrize(
22+
"load_config_and_file_name",
23+
["sagemaker-tritonserver.json"],
24+
indirect=True,
25+
)
26+
def test_sagemaker_tritonserver_uris(load_config_and_file_name):
27+
config, file_name = load_config_and_file_name
28+
framework = file_name.split(".json")[0]
29+
VERSIONS = config["versions"]
30+
processors = config["processors"]
31+
for version in VERSIONS:
32+
ACCOUNTS = config["versions"][version]["registries"]
33+
tag = config["versions"][version]["tag_prefix"]
34+
for processor in processors:
35+
instance_type = INSTANCE_TYPES[processor]
36+
for region in ACCOUNTS.keys():
37+
_test_sagemaker_tritonserver_uris(
38+
ACCOUNTS[region], region, version, tag, framework, instance_type, processor
39+
)
40+
41+
42+
def _test_sagemaker_tritonserver_uris(
43+
account, region, version, tag, triton_framework, instance_type, processor
44+
):
45+
uri = image_uris.retrieve(
46+
framework=triton_framework, region=region, version=version, instance_type=instance_type
47+
)
48+
expected = expected_uris.sagemaker_triton_framework_uri(
49+
"sagemaker-tritonserver",
50+
account,
51+
tag,
52+
processor,
53+
region,
54+
)
55+
assert expected == uri

0 commit comments

Comments
 (0)