Skip to content

Commit 8b5cce2

Browse files
committed
TGI NeuronX
1 parent 668e65d commit 8b5cce2

File tree

4 files changed

+71
-13
lines changed

4 files changed

+71
-13
lines changed

src/sagemaker/huggingface/llm_utils.py

+9-1
Original file line number | Diff line number | Diff line change
@@ -57,7 +57,15 @@ def get_huggingface_llm_image_uri(
5757
version=version,
5858
image_scope="inference",
5959
)
60-
if backend == "lmi":
60+
elif backend == "huggingface-neuronx":
61+
return image_uris.retrieve(
62+
"huggingface-llm-neuronx",
63+
region=region,
64+
version=version,
65+
image_scope="inference",
66+
inference_tool="neuronx",
67+
)
68+
elif backend == "lmi":
6169
version = version or "0.24.0"
6270
return image_uris.retrieve(framework="djl-deepspeed", region=region, version=version)
6371
raise ValueError("Unsupported backend: %s" % backend)
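
huggingface-llm-neuronx.json (new file; its header is missing from this capture, and the name is inferred from the test's load_config parametrization)

+41-0
Original file line number | Diff line number | Diff line change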
@@ -0,0 +1,41 @@
1+
{
2+
"inference": {
3+
"processors": [
4+
"inf"
5+
],
6+
"version_aliases": {
7+
"0.0": "0.0.16"
8+
},
9+
"versions": {
10+
"0.0.16": {
11+
"py_versions": [
12+
"py310"
13+
],
14+
"registries": {
15+
"ap-northeast-1": "763104351884",
16+
"ap-south-1": "763104351884",
17+
"ap-south-2": "772153158452",
18+
"ap-southeast-1": "763104351884",
19+
"ap-southeast-2": "763104351884",
20+
"ap-southeast-4": "457447274322",
21+
"eu-central-1": "763104351884",
22+
"eu-central-2": "380420809688",
23+
"eu-south-2": "503227376785",
24+
"eu-west-1": "763104351884",
25+
"eu-west-3": "763104351884",
26+
"il-central-1": "780543022126",
27+
"sa-east-1": "763104351884",
28+
"us-east-1": "763104351884",
29+
"us-east-2": "763104351884",
30+
"us-west-2": "763104351884",
31+
"ca-west-1": "204538143572"
32+
},
33+
"tag_prefix": "2.1.1-optimum0.0.16",
34+
"repository": "huggingface-pytorch-tgi-inference",
35+
"container_version": {
36+
"inf": "ubuntu22.04"
37+
}
38+
}
39+
}
40+
}
41+
}

src/sagemaker/image_uris.py

+3-1
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@
3737
ECR_URI_TEMPLATE = "{registry}.dkr.{hostname}/{repository}"
3838
HUGGING_FACE_FRAMEWORK = "huggingface"
3939
HUGGING_FACE_LLM_FRAMEWORK = "huggingface-llm"
40+
HUGGING_FACE_LLM_NEURONX_FRAMEWORK = "huggingface-llm-neuronx"
4041
XGBOOST_FRAMEWORK = "xgboost"
4142
SKLEARN_FRAMEWORK = "sklearn"
4243
TRAINIUM_ALLOWED_FRAMEWORKS = "pytorch"
@@ -167,7 +168,7 @@ def retrieve(
167168
)
168169
else:
169170
_framework = framework
170-
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
171+
if framework == HUGGING_FACE_FRAMEWORK or framework in TRAINIUM_ALLOWED_FRAMEWORKS:
171172
inference_tool = _get_inference_tool(inference_tool, instance_type)
172173
if inference_tool in ["neuron", "neuronx"]:
173174
_framework = f"{framework}-{inference_tool}"
@@ -470,6 +471,7 @@ def _validate_version_and_set_if_needed(version, config, framework):
470471
if version is None and framework in [
471472
DATA_WRANGLER_FRAMEWORK,
472473
HUGGING_FACE_LLM_FRAMEWORK,
474+
HUGGING_FACE_LLM_NEURONX_FRAMEWORK,
473475
STABILITYAI_FRAMEWORK,
474476
]:
475477
version = _get_latest_versions(available_versions)

tests/unit/sagemaker/image_uris/test_huggingface_llm.py

+18-11
Original file line number | Diff line number | Diff line change
@@ -19,29 +19,36 @@
1919

2020
LMI_VERSIONS = ["0.24.0"]
2121
HF_VERSIONS_MAPPING = {
22-
"0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04",
23-
"0.8.2": "2.0.0-tgi0.8.2-gpu-py39-cu118-ubuntu20.04",
24-
"0.9.3": "2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04",
25-
"1.0.3": "2.0.1-tgi1.0.3-gpu-py39-cu118-ubuntu20.04",
26-
"1.1.0": "2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04",
27-
"1.2.0": "2.1.1-tgi1.2.0-gpu-py310-cu121-ubuntu20.04",
28-
"1.3.1": "2.1.1-tgi1.3.1-gpu-py310-cu121-ubuntu20.04",
29-
"1.3.3": "2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04",
22+
"gpu": {
23+
"0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04",
24+
"0.8.2": "2.0.0-tgi0.8.2-gpu-py39-cu118-ubuntu20.04",
25+
"0.9.3": "2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04",
26+
"1.0.3": "2.0.1-tgi1.0.3-gpu-py39-cu118-ubuntu20.04",
27+
"1.1.0": "2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04",
28+
"1.2.0": "2.1.1-tgi1.2.0-gpu-py310-cu121-ubuntu20.04",
29+
"1.3.1": "2.1.1-tgi1.3.1-gpu-py310-cu121-ubuntu20.04",
30+
"1.3.3": "2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04",
31+
},
32+
"inf": {
33+
"0.0.16": "2.1.1-optimum0.0.16-neuronx-py310-ubuntu22.04",
34+
}
3035
}
3136

3237

33-
@pytest.mark.parametrize("load_config", ["huggingface-llm.json"], indirect=True)
38+
@pytest.mark.parametrize("load_config", ["huggingface-llm.json", "huggingface-llm-neuronx.json"], indirect=True)
3439
def test_huggingface_uris(load_config):
3540
VERSIONS = load_config["inference"]["versions"]
41+
device = load_config["inference"]["processors"][0]
42+
backend = "huggingface-neuronx" if device == "inf" else "huggingface"
3643
for version in VERSIONS:
3744
ACCOUNTS = load_config["inference"]["versions"][version]["registries"]
3845
for region in ACCOUNTS.keys():
39-
uri = get_huggingface_llm_image_uri("huggingface", region=region, version=version)
46+
uri = get_huggingface_llm_image_uri(backend, region=region, version=version)
4047
expected = expected_uris.huggingface_llm_framework_uri(
4148
"huggingface-pytorch-tgi-inference",
4249
ACCOUNTS[region],
4350
version,
44-
HF_VERSIONS_MAPPING[version],
51+
HF_VERSIONS_MAPPING[device][version],
4552
region=region,
4653
)
4754
assert expected == uri

0 commit comments

Comments
 (0)