|
1 | 1 | import functools
|
2 | 2 | import logging
|
3 | 3 | import os
|
| 4 | +import pathlib |
4 | 5 | from typing import Dict, Type
|
5 | 6 |
|
| 7 | +from api_inference_community import hub |
6 | 8 | from api_inference_community.routes import pipeline_route, status_ok
|
7 | 9 | from app.pipelines import Pipeline, TextClassificationPipeline
|
| 10 | +from huggingface_hub import constants |
8 | 11 | from starlette.applications import Starlette
|
9 | 12 | from starlette.middleware import Middleware
|
10 | 13 | from starlette.middleware.gzip import GZipMiddleware
|
11 | 14 | from starlette.routing import Route
|
12 | 15 |
|
13 | 16 |
|
14 | 17 | TASK = os.getenv("TASK")
|
15 |
| -MODEL_ID = os.getenv("MODEL_ID") |
16 | 18 |
|
17 | 19 |
|
def get_model_id():
    """Resolve the model identifier to load.

    Normally this is just the ``MODEL_ID`` environment variable. When the Hub
    client is configured offline (``HF_HUB_OFFLINE``), the repo id is mapped to
    the local cached snapshot path instead.

    Returns:
        The repo id from the environment, or (offline) the cached revision path.
    """
    model_id = os.getenv("MODEL_ID")
    # Workaround: sentence_transformers does not yet honor this env variable
    # itself, so in offline mode we hand it a local path directly. Remove once
    # that is handled upstream.
    if constants.HF_HUB_OFFLINE:
        model_id = hub.cached_revision_path(
            cache_dir=pathlib.Path(constants.HF_HUB_CACHE),
            repo_id=model_id,
            revision=os.getenv("REVISION"),
        )
    return model_id


# Resolved once at import time; see get_model_id() for offline handling.
MODEL_ID = get_model_id()
| 33 | + |
18 | 34 | logger = logging.getLogger(__name__)
|
19 | 35 |
|
20 | 36 |
|
|
@functools.lru_cache()
def get_pipeline() -> Pipeline:
    """Build (and memoize) the pipeline for the task configured in the env.

    Looks up the ``TASK`` environment variable in ``ALLOWED_TASKS`` and
    instantiates the matching pipeline class with the resolved ``MODEL_ID``.

    Raises:
        EnvironmentError: if ``TASK`` names a task this model does not support.
    """
    task_name = os.environ["TASK"]
    configured_model = MODEL_ID
    if task_name not in ALLOWED_TASKS:
        raise EnvironmentError(
            f"{task_name} is not a valid pipeline for model : {configured_model}"
        )
    pipeline_cls = ALLOWED_TASKS[task_name]
    return pipeline_cls(configured_model)
|
|
0 commit comments