diff --git a/src/sagemaker/image_uri_config/sagemaker-tritonserver.json b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json new file mode 100644 index 0000000000..82397d913e --- /dev/null +++ b/src/sagemaker/image_uri_config/sagemaker-tritonserver.json @@ -0,0 +1,75 @@ +{ + "processors": [ + "cpu", + "gpu" + ], + "scope": [ + "inference" + ], + "versions": { + "23.12": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": "692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "23.12-py3" + }, + "24.01": { + "registries": { + "af-south-1": "626614931356", + "il-central-1": "780543022126", + "ap-east-1": "871362719292", + "ap-northeast-1": "763104351884", + "ap-northeast-2": "763104351884", + "ap-northeast-3": "364406365360", + "ap-south-1": "763104351884", + "ap-southeast-1": "763104351884", + "ap-southeast-2": "763104351884", + "ap-southeast-3": "907027046896", + "ca-central-1": "763104351884", + "cn-north-1": "727897471807", + "cn-northwest-1": "727897471807", + "eu-central-1": "763104351884", + "eu-north-1": "763104351884", + "eu-west-1": "763104351884", + "eu-west-2": "763104351884", + "eu-west-3": "763104351884", + "eu-south-1": 
"692866216735", + "me-south-1": "217643126080", + "sa-east-1": "763104351884", + "us-east-1": "763104351884", + "us-east-2": "763104351884", + "us-west-1": "763104351884", + "us-west-2": "763104351884", + "ca-west-1": "204538143572" + }, + "repository": "sagemaker-tritonserver", + "tag_prefix": "24.01-py3" + } + } +} \ No newline at end of file diff --git a/src/sagemaker/image_uris.py b/src/sagemaker/image_uris.py index 252bf3c504..8498027079 100644 --- a/src/sagemaker/image_uris.py +++ b/src/sagemaker/image_uris.py @@ -44,6 +44,7 @@ INFERENCE_GRAVITON = "inference_graviton" DATA_WRANGLER_FRAMEWORK = "data-wrangler" STABILITYAI_FRAMEWORK = "stabilityai" +SAGEMAKER_TRITONSERVER_FRAMEWORK = "sagemaker-tritonserver" @override_pipeline_parameter_var @@ -335,6 +336,11 @@ def _get_image_tag( if key in container_versions: tag = "-".join([tag, container_versions[key]]) + # Triton images don't have a trailing -gpu tag. Only -cpu images do. + if framework == SAGEMAKER_TRITONSERVER_FRAMEWORK: + if processor == "gpu" and tag.endswith("-gpu"): + tag = tag[: -len("-gpu")]  # drop the exact suffix; str.rstrip("-gpu") strips a character set + return tag diff --git a/tests/unit/sagemaker/image_uris/expected_uris.py b/tests/unit/sagemaker/image_uris/expected_uris.py index 438a00a038..094323ef0b 100644 --- a/tests/unit/sagemaker/image_uris/expected_uris.py +++ b/tests/unit/sagemaker/image_uris/expected_uris.py @@ -84,6 +84,13 @@ def djl_framework_uri(repo, account, tag, region=REGION): return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) +def sagemaker_triton_framework_uri(repo, account, tag, processor="gpu", region=REGION): + domain = ALTERNATE_DOMAINS.get(region, DOMAIN) + if processor == "cpu": + tag = f"{tag}-cpu" + return IMAGE_URI_FORMAT.format(account, region, domain, repo, tag) + + def huggingface_llm_framework_uri( repo, account, diff --git a/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py b/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py new file mode 100644 index 0000000000..7dd75fc3b8 --- /dev/null +++ 
b/tests/unit/sagemaker/image_uris/test_sagemaker_tritonserver.py @@ -0,0 +1,55 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +from __future__ import absolute_import +import pytest +from sagemaker import image_uris +from tests.unit.sagemaker.image_uris import expected_uris + +INSTANCE_TYPES = {"cpu": "ml.c4.xlarge", "gpu": "ml.p2.xlarge"}  # instance type used for each processor listed in the config + + +@pytest.mark.parametrize( + "load_config_and_file_name", + ["sagemaker-tritonserver.json"], + indirect=True, +) +def test_sagemaker_tritonserver_uris(load_config_and_file_name): + config, file_name = load_config_and_file_name + framework = file_name.split(".json")[0]  # framework name is the config file name without ".json" + VERSIONS = config["versions"] + processors = config["processors"] + for version in VERSIONS: + ACCOUNTS = config["versions"][version]["registries"] + tag = config["versions"][version]["tag_prefix"] + for processor in processors: + instance_type = INSTANCE_TYPES[processor] + for region in ACCOUNTS.keys():  # checks every (version, processor, region) combination + _test_sagemaker_tritonserver_uris( + ACCOUNTS[region], region, version, tag, framework, instance_type, processor + ) + + +def _test_sagemaker_tritonserver_uris( + account, region, version, tag, triton_framework, instance_type, processor +): + uri = image_uris.retrieve(  # processor is not passed; only instance_type is supplied + framework=triton_framework, region=region, version=version, instance_type=instance_type + ) + expected = expected_uris.sagemaker_triton_framework_uri( + "sagemaker-tritonserver", + account, + tag, + processor, + region, + ) + assert expected == uri