Skip to content

Commit 6f09ac7

Browse files
committed
Fix snapshot_download on very large repo (>50k files) (#3122)
* Fix snapshot_download on very large repo (>50k files)
* Use iterators
* Fix typing issues
1 parent b70c474 commit 6f09ac7

File tree

2 files changed

+39
-17
lines changed

2 files changed

+39
-17
lines changed

src/huggingface_hub/_snapshot_download.py

Lines changed: 38 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import os
22
from pathlib import Path
3-
from typing import Dict, List, Literal, Optional, Union
3+
from typing import Dict, Iterable, List, Literal, Optional, Union
44

55
import requests
66
from tqdm.auto import tqdm as base_tqdm
@@ -15,13 +15,15 @@
1515
RevisionNotFoundError,
1616
)
1717
from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
18-
from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
18+
from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
1919
from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
2020
from .utils import tqdm as hf_tqdm
2121

2222

2323
logger = logging.get_logger(__name__)
2424

25+
VERY_LARGE_REPO_THRESHOLD = 50000 # After this limit, we don't consider `repo_info.siblings` to be reliable enough
26+
2527

2628
@validate_hf_hub_args
2729
def snapshot_download(
@@ -145,20 +147,22 @@ def snapshot_download(
145147

146148
storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))
147149

150+
api = HfApi(
151+
library_name=library_name,
152+
library_version=library_version,
153+
user_agent=user_agent,
154+
endpoint=endpoint,
155+
headers=headers,
156+
token=token,
157+
)
158+
148159
repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
149160
api_call_error: Optional[Exception] = None
150161
if not local_files_only:
151162
# try/except logic to handle different errors => taken from `hf_hub_download`
152163
try:
153164
# if we have internet connection we want to list files to download
154-
api = HfApi(
155-
library_name=library_name,
156-
library_version=library_version,
157-
user_agent=user_agent,
158-
endpoint=endpoint,
159-
headers=headers,
160-
)
161-
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
165+
repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
162166
except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
163167
# Actually raise for those subclasses of ConnectionError
164168
raise
@@ -251,13 +255,31 @@ def snapshot_download(
251255
# => let's download the files!
252256
assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
253257
assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
254-
filtered_repo_files = list(
255-
filter_repo_objects(
256-
items=[f.rfilename for f in repo_info.siblings],
257-
allow_patterns=allow_patterns,
258-
ignore_patterns=ignore_patterns,
258+
259+
# Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
260+
# In that case, we need to use the `list_repo_tree` method to prevent caching issues.
261+
repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings]
262+
has_many_files = len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
263+
if has_many_files:
264+
logger.info("The repo has more than 50,000 files. Using `list_repo_tree` to ensure all files are listed.")
265+
repo_files = (
266+
f.rfilename
267+
for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
268+
if isinstance(f, RepoFile)
259269
)
270+
271+
filtered_repo_files: Iterable[str] = filter_repo_objects(
272+
items=repo_files,
273+
allow_patterns=allow_patterns,
274+
ignore_patterns=ignore_patterns,
260275
)
276+
277+
if not has_many_files:
278+
filtered_repo_files = list(filtered_repo_files)
279+
tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
280+
else:
281+
tqdm_desc = "Fetching ... files"
282+
261283
commit_hash = repo_info.sha
262284
snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
263285
# if passed revision is not identical to commit_hash
@@ -305,7 +327,7 @@ def _inner_hf_hub_download(repo_file: str):
305327
thread_map(
306328
_inner_hf_hub_download,
307329
filtered_repo_files,
308-
desc=f"Fetching {len(filtered_repo_files)} files",
330+
desc=tqdm_desc,
309331
max_workers=max_workers,
310332
# User can use its own tqdm class or the default one from `huggingface_hub.utils`
311333
tqdm_class=tqdm_class or hf_tqdm,

src/huggingface_hub/serialization/_torch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,7 @@ def save_torch_state_dict(
246246
shared_tensors_to_discard=shared_tensors_to_discard,
247247
)
248248
else:
249-
from torch import save as save_file_fn # type: ignore[assignment]
249+
from torch import save as save_file_fn # type: ignore[assignment, no-redef]
250250

251251
logger.warning(
252252
"You are using unsafe serialization. Due to security reasons, it is recommended not to load "

0 commit comments

Comments (0)