
Commit f06a71e

model info that works offline (#371)
* offline model info + hub local file layout helpers can be used to fix #372
1 parent 7deb0e3 commit f06a71e
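
The entry point added here is hub.hub_model_info(), which proxies to hf_api.model_info() when online and rebuilds a partial ModelInfo from the local hub cache when huggingface_hub runs in offline mode. A minimal usage sketch, assuming the package is importable and the repo (borrowed from the tests below) is already cached locally:

from api_inference_community import hub

# Online: forwarded straight to hf_api.model_info().
# Offline (HF_HUB_OFFLINE=1 exported before the process starts): reloaded from a
# prefetched hub_model_info.json if present, otherwise rebuilt from the cached
# README.md and the snapshot's file listing.
info = hub.hub_model_info("google/t5-efficient-tiny", revision="main")
print(info.id, [s.rfilename for s in info.siblings])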

File tree

3 files changed: +276 -1 lines changed (api_inference_community/hub.py, requirements.txt, tests/test_hub.py)


api_inference_community/hub.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
import json
import logging
import os
import pathlib
import re
from typing import List, Optional

from huggingface_hub import ModelCard, constants, hf_api, try_to_load_from_cache
from huggingface_hub.file_download import repo_folder_name


logger = logging.getLogger(__name__)


def _cached_repo_root_path(cache_dir: pathlib.Path, repo_id: str) -> pathlib.Path:
    folder = pathlib.Path(repo_folder_name(repo_id=repo_id, repo_type="model"))
    return cache_dir / folder


def cached_revision_path(cache_dir, repo_id, revision) -> pathlib.Path:

    error_msg = f"No revision path found for {repo_id}, revision {revision}"

    if revision is None:
        revision = "main"

    repo_cache = _cached_repo_root_path(cache_dir, repo_id)

    if not repo_cache.is_dir():
        msg = f"Local repo {repo_cache} does not exist"
        logger.error(msg)
        raise Exception(msg)

    refs_dir = repo_cache / "refs"
    snapshots_dir = repo_cache / "snapshots"

    # Resolve refs (for instance to convert main to the associated commit sha)
    if refs_dir.is_dir():
        revision_file = refs_dir / revision
        if revision_file.exists():
            with revision_file.open() as f:
                revision = f.read()

    # Check if revision folder exists
    if not snapshots_dir.exists():
        msg = f"No local revision path {snapshots_dir} found for {repo_id}, revision {revision}"
        logger.error(msg)
        raise Exception(msg)

    cached_shas = os.listdir(snapshots_dir)
    if revision not in cached_shas:
        # No cache for this revision and we won't try to return a random revision
        logger.error(error_msg)
        raise Exception(error_msg)

    return snapshots_dir / revision


def _build_offline_model_info(
    repo_id: str, cache_dir: pathlib.Path, revision: str
) -> hf_api.ModelInfo:

    logger.info("Rebuilding offline model info for repo %s", repo_id)

    # Let's rebuild some partial model info from what we see in cache, info extracted should be enough
    # for most use cases
    card_path = try_to_load_from_cache(
        repo_id=repo_id,
        filename="README.md",
        cache_dir=cache_dir,
        revision=revision,
    )
    if not isinstance(card_path, str):
        raise Exception(
            "Unable to rebuild offline model info, no README could be found"
        )

    card_path = pathlib.Path(card_path)
    logger.debug("Loading model card from model readme %s", card_path)
    model_card = ModelCard.load(card_path)
    card_data = model_card.data.to_dict()

    repo = card_path.parent
    logger.debug("Repo path %s", repo)
    siblings = _build_offline_siblings(repo)
    model_info = hf_api.ModelInfo(
        private=False,
        downloads=0,
        likes=0,
        id=repo_id,
        card_data=card_data,
        siblings=siblings,
        **card_data,
    )
    logger.info("Offline model info for repo %s: %s", repo, model_info)
    return model_info


def _build_offline_siblings(repo: pathlib.Path) -> List[dict]:
    siblings = []
    prefix_pattern = re.compile(r"^" + re.escape(str(repo)) + r"(.*)$")
    for root, dirs, files in os.walk(repo):
        for file in files:
            filepath = os.path.join(root, file)
            size = os.stat(filepath).st_size
            m = prefix_pattern.match(filepath)
            if not m:
                msg = (
                    f"File {filepath} does not match expected pattern {prefix_pattern}"
                )
                logger.error(msg)
                raise Exception(msg)
            filepath = m.group(1)
            filepath = filepath.strip(os.sep)
            sibling = dict(rfilename=filepath, size=size)
            siblings.append(sibling)
    return siblings


def _cached_model_info(
    repo_id: str, revision: str, cache_dir: pathlib.Path
) -> hf_api.ModelInfo:
    """
    Looks for a json file containing prefetched model info in the revision path.
    If none found we just rebuild model info with the local directory files.
    Note that this file is not automatically created by hub_download/snapshot_download.
    It is just a convenience we add here, just in case the offline info we rebuild from
    the local directories would not cover all use cases.
    """
    revision_path = cached_revision_path(cache_dir, repo_id, revision)
    model_info_basename = "hub_model_info.json"
    model_info_path = revision_path / model_info_basename
    logger.info("Checking if there are some cached model info at %s", model_info_path)
    if os.path.exists(model_info_path):
        with open(model_info_path, "r") as f:
            o = json.load(f)
        r = hf_api.ModelInfo(**o)
        logger.debug("Cached model info from file: %s", r)
    else:
        logger.debug(
            "No cached model info file %s found, "
            "rebuilding partial model info from cached model files",
            model_info_path,
        )
        # Let's rebuild some partial model info from what we see in cache, info extracted should be enough
        # for most use cases
        r = _build_offline_model_info(repo_id, cache_dir, revision)

    return r


def hub_model_info(
    repo_id: str,
    revision: Optional[str] = None,
    cache_dir: Optional[pathlib.Path] = None,
    **kwargs,
) -> hf_api.ModelInfo:
    """
    Get Hub model info with offline support
    """
    if revision is None:
        revision = "main"

    if not constants.HF_HUB_OFFLINE:
        return hf_api.model_info(repo_id=repo_id, revision=revision, **kwargs)

    logger.info("Model info for offline mode")

    if cache_dir is None:
        cache_dir = pathlib.Path(constants.HF_HUB_CACHE)

    return _cached_model_info(repo_id, revision, cache_dir)

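As the _cached_model_info docstring notes, hub_model_info.json is not written by hf_hub_download/snapshot_download; it is an optional convenience that takes precedence over the rebuilt info when present in the cached revision folder. Below is a sketch of how such a file could be prefetched while still online; prefetch_model_info is a hypothetical helper, not part of this commit, it assumes the snapshot is already cached, and it uses requests (already a dependency of huggingface_hub):

import json
import pathlib

import requests
from huggingface_hub import constants

from api_inference_community import hub


def prefetch_model_info(repo_id: str, revision: str = "main") -> pathlib.Path:
    # The raw /api/models payload has the same shape that hf_api.ModelInfo(**o)
    # expects, so _cached_model_info() can rebuild the full info from it offline.
    url = f"https://huggingface.co/api/models/{repo_id}/revision/{revision}"
    data = requests.get(url, timeout=10).json()
    # Resolve the snapshot folder for this revision (raises if it is not cached yet).
    revision_path = hub.cached_revision_path(
        pathlib.Path(constants.HF_HUB_CACHE), repo_id, revision
    )
    target = revision_path / "hub_model_info.json"
    with target.open("w") as f:
        json.dump(data, f)
    return target
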
requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ numpy>=1.18.0
pydantic>=2
parameterized>=0.8.1
pillow>=8.2.0
-huggingface_hub>=0.5.1
+huggingface_hub>=0.20.2
datasets>=2.2
pytest
httpx

tests/test_hub.py

Lines changed: 103 additions & 0 deletions
@@ -0,0 +1,103 @@
import logging
import sys
from unittest import TestCase

from api_inference_community import hub
from huggingface_hub import constants, hf_api, snapshot_download


logger = logging.getLogger(__name__)
logger.level = logging.DEBUG
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class HubTestCase(TestCase):
    def test_offline_model_info1(self):
        repo_id = "google/t5-efficient-tiny"
        revision = "3441d7e8bf3f89841f366d39452b95200416e4a9"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            # with tempfile.TemporaryDirectory() as cache_dir:
            # logger.info("Cache directory %s", cache_dir)
            dirpath = snapshot_download(repo_id=repo_id, revision=revision)
            logger.info("Snapshot downloaded at %s", dirpath)
            constants.HF_HUB_OFFLINE = True
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 12)
        self.assertIn("pytorch_model.bin", [s.rfilename for s in model_info.siblings])
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["deep-narrow"])
        self.assertIsNone(model_info.library_name)

        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["deep-narrow"])

    def test_offline_model_info2(self):
        repo_id = "dfurman/Mixtral-8x7B-peft-v0.1"
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            # with tempfile.TemporaryDirectory() as cache_dir:
            # logger.info("Cache directory %s", cache_dir)
            dirpath = snapshot_download(repo_id=repo_id, revision=revision)
            logger.info("Snapshot downloaded at %s", dirpath)
            constants.HF_HUB_OFFLINE = True
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertEqual(model_info.downloads, 0)
        self.assertEqual(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        self.assertEqual(model_info.license, "apache-2.0")  # noqa
        self.assertEqual(model_info.tags, ["mistral"])
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])

    def test_online_model_info(self):
        repo_id = "dfurman/Mixtral-8x7B-Instruct-v0.1"
        revision = "8908d586219993ec79949acaef566363a7c7864c"
        bak_value = constants.HF_HUB_OFFLINE
        try:
            constants.HF_HUB_OFFLINE = False
            model_info = hub.hub_model_info(repo_id=repo_id, revision=revision)
        finally:
            constants.HF_HUB_OFFLINE = bak_value

        logger.info("Model info %s", model_info)
        self.assertIsInstance(model_info, hf_api.ModelInfo)
        self.assertEqual(model_info.id, repo_id)
        self.assertGreater(model_info.downloads, 0)
        self.assertGreater(model_info.likes, 0)
        self.assertEqual(len(model_info.siblings), 9)
        self.assertFalse(model_info.private)
        # The tags returned online should at least include these (superset check).
        self.assertGreaterEqual(set(model_info.tags), {"peft", "safetensors", "mistral"})
        self.assertEqual(model_info.library_name, "peft")
        self.assertEqual(model_info.pipeline_tag, "text-generation")
        self.assertIn(".gitattributes", [s.rfilename for s in model_info.siblings])
        logger.info("Model card data %s", model_info.card_data)
        self.assertEqual(model_info.card_data, model_info.cardData)
        self.assertEqual(model_info.card_data.license, "apache-2.0")
        self.assertEqual(model_info.card_data.tags, ["mistral"])
        self.assertIsNone(model_info.safetensors)

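Note on the tests above: they flip constants.HF_HUB_OFFLINE in-process because huggingface_hub resolves the HF_HUB_OFFLINE environment variable once, at import time, so setting os.environ inside a test would come too late. In a deployment the usual flow is to warm the cache while online and export the variable before the service starts; a rough sketch, with the repo id and entry point as examples only:

from huggingface_hub import snapshot_download

# 1) While online: populate the local hub cache for the pinned revision
#    (and optionally write hub_model_info.json next to it, see above).
snapshot_download(repo_id="google/t5-efficient-tiny", revision="main")

# 2) Then restart in offline mode so hub_model_info() never hits the Hub API,
#    e.g. `HF_HUB_OFFLINE=1 python app.py`; the variable must be set before
#    huggingface_hub is imported.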