Skip to content

Commit e74c506

Browse files
authored
Fix setfit in offline mode (#378)
1 parent f06a71e commit e74c506

File tree

2 files changed

+21
-5
lines changed

2 files changed

+21
-5
lines changed

docker_images/setfit/app/main.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,36 @@
11
import functools
22
import logging
33
import os
4+
import pathlib
45
from typing import Dict, Type
56

7+
from api_inference_community import hub
68
from api_inference_community.routes import pipeline_route, status_ok
79
from app.pipelines import Pipeline, TextClassificationPipeline
10+
from huggingface_hub import constants
811
from starlette.applications import Starlette
912
from starlette.middleware import Middleware
1013
from starlette.middleware.gzip import GZipMiddleware
1114
from starlette.routing import Route
1215

1316

1417
TASK = os.getenv("TASK")
15-
MODEL_ID = os.getenv("MODEL_ID")
1618

1719

20+
def get_model_id():
21+
m_id = os.getenv("MODEL_ID")
22+
# Workaround, when sentence_transformers handles properly this env variable
23+
# this should not be needed anymore
24+
if constants.HF_HUB_OFFLINE:
25+
cache_dir = pathlib.Path(constants.HF_HUB_CACHE)
26+
m_id = hub.cached_revision_path(
27+
cache_dir=cache_dir, repo_id=m_id, revision=os.getenv("REVISION")
28+
)
29+
return m_id
30+
31+
32+
MODEL_ID = get_model_id()
33+
1834
logger = logging.getLogger(__name__)
1935

2036

@@ -40,7 +56,7 @@
4056
@functools.lru_cache()
4157
def get_pipeline() -> Pipeline:
4258
task = os.environ["TASK"]
43-
model_id = os.environ["MODEL_ID"]
59+
model_id = MODEL_ID
4460
if task not in ALLOWED_TASKS:
4561
raise EnvironmentError(f"{task} is not a valid pipeline for model : {model_id}")
4662
return ALLOWED_TASKS[task](model_id)

docker_images/setfit/requirements.txt

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
starlette==0.27.0
2-
api-inference-community==0.0.32
3-
huggingface_hub==0.19.4
4-
setfit==1.0.1
2+
git+https://github.com/huggingface/api-inference-community.git@f06a71e72e92caeebabaeced979eacb3542bf2ca
3+
huggingface_hub==0.20.2
4+
setfit==1.0.1

0 commit comments

Comments
 (0)