@@ -1,6 +1,6 @@
 import os
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Union
+from typing import Dict, Iterable, List, Literal, Optional, Union

 import requests
 from tqdm.auto import tqdm as base_tqdm
@@ -15,13 +15,15 @@
     RevisionNotFoundError,
 )
 from .file_download import REGEX_COMMIT_HASH, hf_hub_download, repo_folder_name
-from .hf_api import DatasetInfo, HfApi, ModelInfo, SpaceInfo
+from .hf_api import DatasetInfo, HfApi, ModelInfo, RepoFile, SpaceInfo
 from .utils import OfflineModeIsEnabled, filter_repo_objects, logging, validate_hf_hub_args
 from .utils import tqdm as hf_tqdm


 logger = logging.get_logger(__name__)

+VERY_LARGE_REPO_THRESHOLD = 50000  # After this limit, we don't consider `repo_info.siblings` to be reliable enough
+

 @validate_hf_hub_args
 def snapshot_download(
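Note on the header changes: `Iterable` and `RepoFile` are imported for the lazy file-listing fallback introduced further down, and the new module-level `VERY_LARGE_REPO_THRESHOLD` is the switch point past which `repo_info.siblings` is treated as potentially incomplete (see the sketches after the later hunks).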
@@ -145,20 +147,22 @@ def snapshot_download(

     storage_folder = os.path.join(cache_dir, repo_folder_name(repo_id=repo_id, repo_type=repo_type))

+    api = HfApi(
+        library_name=library_name,
+        library_version=library_version,
+        user_agent=user_agent,
+        endpoint=endpoint,
+        headers=headers,
+        token=token,
+    )
+
     repo_info: Union[ModelInfo, DatasetInfo, SpaceInfo, None] = None
     api_call_error: Optional[Exception] = None
     if not local_files_only:
         # try/except logic to handle different errors => taken from `hf_hub_download`
         try:
             # if we have internet connection we want to list files to download
-            api = HfApi(
-                library_name=library_name,
-                library_version=library_version,
-                user_agent=user_agent,
-                endpoint=endpoint,
-                headers=headers,
-            )
-            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision, token=token)
+            repo_info = api.repo_info(repo_id=repo_id, repo_type=repo_type, revision=revision)
         except (requests.exceptions.SSLError, requests.exceptions.ProxyError):
             # Actually raise for those subclasses of ConnectionError
             raise
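Moving the `HfApi` construction out of the `try` block means a single configured client now serves both the initial `repo_info` call and the `list_repo_tree` fallback added below; `token` moves into the constructor, so it no longer has to be threaded through each call. A minimal sketch of the same pattern outside the library (the `library_name` value and repo id are placeholders, not from the diff):

```python
from huggingface_hub import HfApi

# One client, configured once; every metadata call made through it reuses
# its endpoint, headers, and token instead of receiving them per call.
api = HfApi(library_name="my-lib", token=None)  # token=None -> fall back to cached login
info = api.repo_info(repo_id="gpt2", repo_type="model")
print(info.sha, len(info.siblings or []))
```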
@@ -251,13 +255,31 @@ def snapshot_download(
     # => let's download the files!
     assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
     assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
-    filtered_repo_files = list(
-        filter_repo_objects(
-            items=[f.rfilename for f in repo_info.siblings],
-            allow_patterns=allow_patterns,
-            ignore_patterns=ignore_patterns,
+
+    # Corner case: on very large repos, the siblings list in `repo_info` might not contain all files.
+    # In that case, we need to use the `list_repo_tree` method to prevent caching issues.
+    repo_files: Iterable[str] = [f.rfilename for f in repo_info.siblings]
+    has_many_files = len(repo_info.siblings) > VERY_LARGE_REPO_THRESHOLD
+    if has_many_files:
+        logger.info("The repo has more than 50,000 files. Using `list_repo_tree` to ensure all files are listed.")
+        repo_files = (
+            f.rfilename
+            for f in api.list_repo_tree(repo_id=repo_id, recursive=True, revision=revision, repo_type=repo_type)
+            if isinstance(f, RepoFile)
         )
+
+    filtered_repo_files: Iterable[str] = filter_repo_objects(
+        items=repo_files,
+        allow_patterns=allow_patterns,
+        ignore_patterns=ignore_patterns,
     )
+
+    if not has_many_files:
+        filtered_repo_files = list(filtered_repo_files)
+        tqdm_desc = f"Fetching {len(filtered_repo_files)} files"
+    else:
+        tqdm_desc = "Fetching ... files"
+
     commit_hash = repo_info.sha
     snapshot_folder = os.path.join(storage_folder, "snapshots", commit_hash)
     # if passed revision is not identical to commit_hash
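This is the core behavioral change: `repo_info.siblings` stays the fast path, but past `VERY_LARGE_REPO_THRESHOLD` the file list is re-fetched lazily through the paginated `list_repo_tree` endpoint, keeping only file entries. Because `filter_repo_objects` accepts any iterable, the generator is never materialized on the large-repo path. A standalone sketch of that fallback, assuming a hypothetical large dataset repo (`RepoFile` documents the filename as `.path`; the diff reads it through the backward-compatible `.rfilename` alias):

```python
from huggingface_hub import HfApi
from huggingface_hub.hf_api import RepoFile

api = HfApi()

# Stream file paths instead of trusting the (possibly truncated) siblings
# list. `list_repo_tree` yields RepoFile and RepoFolder entries; only files
# are kept, and nothing is held in memory until a consumer iterates.
repo_files = (
    entry.path
    for entry in api.list_repo_tree("user/huge-dataset", repo_type="dataset", recursive=True)
    if isinstance(entry, RepoFile)
)
```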
@@ -305,7 +327,7 @@ def _inner_hf_hub_download(repo_file: str):
     thread_map(
         _inner_hf_hub_download,
         filtered_repo_files,
-        desc=f"Fetching {len(filtered_repo_files)} files",
+        desc=tqdm_desc,
         max_workers=max_workers,
         # User can use its own tqdm class or the default one from `huggingface_hub.utils`
         tqdm_class=tqdm_class or hf_tqdm,
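Since the large-repo path keeps `filtered_repo_files` as a generator, its length is unknown when the progress bar is created, hence the precomputed `tqdm_desc` and the literal "Fetching ... files" fallback. The public API is unchanged; a call like the following (repo id and patterns are illustrative) takes the fast `siblings` path on small repos and the streaming path past the 50,000-file threshold:

```python
from huggingface_hub import snapshot_download

# Same call either way; the siblings-vs-list_repo_tree choice is internal.
local_path = snapshot_download(
    repo_id="gpt2",
    allow_patterns=["*.json", "*.txt"],  # optional client-side glob filtering
)
print(local_path)
```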