|
19 | 19 |
|
# LMI (Large Model Inference) container versions exercised by the LMI tests.
LMI_VERSIONS = ["0.24.0"]

# Device type ("gpu" for CUDA instances, "inf" for AWS Inferentia/Neuron)
# -> HuggingFace TGI / Optimum version -> full image-tag suffix of the
# huggingface-pytorch-tgi-inference ECR repository.
HF_VERSIONS_MAPPING = {
    "gpu": {
        "0.6.0": "2.0.0-tgi0.6.0-gpu-py39-cu118-ubuntu20.04",
        "0.8.2": "2.0.0-tgi0.8.2-gpu-py39-cu118-ubuntu20.04",
        "0.9.3": "2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04",
        "1.0.3": "2.0.1-tgi1.0.3-gpu-py39-cu118-ubuntu20.04",
        "1.1.0": "2.0.1-tgi1.1.0-gpu-py39-cu118-ubuntu20.04",
        "1.2.0": "2.1.1-tgi1.2.0-gpu-py310-cu121-ubuntu20.04",
        "1.3.1": "2.1.1-tgi1.3.1-gpu-py310-cu121-ubuntu20.04",
        "1.3.3": "2.1.1-tgi1.3.3-gpu-py310-cu121-ubuntu20.04",
    },
    "inf": {
        "0.0.16": "2.1.1-optimum0.0.16-neuronx-py310-ubuntu22.04",
    },
}
|
31 | 36 |
|
32 | 37 |
|
@pytest.mark.parametrize(
    "load_config", ["huggingface-llm.json", "huggingface-llm-neuronx.json"], indirect=True
)
def test_huggingface_uris(load_config):
    """Check get_huggingface_llm_image_uri for every version/region in the config.

    Runs once per parametrized config file: the GPU (TGI) config and the
    Neuron (optimum-neuronx) config. For each version listed in the config
    and each region registry, the URI produced by the SDK must match the
    expected URI built from HF_VERSIONS_MAPPING.
    """
    versions = load_config["inference"]["versions"]
    # The first listed processor ("gpu" or "inf") selects both the backend
    # name passed to the SDK and the sub-mapping of expected image tags.
    device = load_config["inference"]["processors"][0]
    backend = "huggingface-neuronx" if device == "inf" else "huggingface"
    for version, version_config in versions.items():
        # .items() avoids re-indexing load_config and the per-region
        # ACCOUNTS[region] lookup the original did inside the loop.
        for region, account in version_config["registries"].items():
            uri = get_huggingface_llm_image_uri(backend, region=region, version=version)
            expected = expected_uris.huggingface_llm_framework_uri(
                "huggingface-pytorch-tgi-inference",
                account,
                version,
                HF_VERSIONS_MAPPING[device][version],
                region=region,
            )
            assert expected == uri
|
|
0 commit comments