Skip to content

Commit caeaeeb

Browse files
bpronanWauplin
andauthored
Xet Upload with byte array (#3035)
* Using upload_bytes from hf_xet * Factoring out the file key * Using the new datastructures * Adding bytes based test * Adding tests for byte array based upload and testing xet backwards compatibility * Test cleanup * Lint * Quality check * Updating hf-xet version in setup.py * Update src/huggingface_hub/_commit_api.py Co-authored-by: Lucain <[email protected]> * Update src/huggingface_hub/_commit_api.py Co-authored-by: Lucain <[email protected]> * Update src/huggingface_hub/hf_api.py Co-authored-by: Lucain <[email protected]> * Update tests/test_xet_upload.py Co-authored-by: Lucain <[email protected]> * Update src/huggingface_hub/hf_api.py Co-authored-by: Lucain <[email protected]> * PR comments * Fixing compilation * Quality * Fixing spelling --------- Co-authored-by: Lucain <[email protected]>
1 parent 2bafd2a commit caeaeeb

File tree

6 files changed

+97
-20
lines changed

6 files changed

+97
-20
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def get_version() -> str:
1414
install_requires = [
1515
"filelock",
1616
"fsspec>=2023.5.0",
17-
"hf-xet>=1.0.2,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'",
17+
"hf-xet>=1.1.0,<2.0.0; platform_machine=='x86_64' or platform_machine=='amd64' or platform_machine=='arm64' or platform_machine=='aarch64'",
1818
"packaging>=20.9",
1919
"pyyaml>=5.1",
2020
"requests",
@@ -56,7 +56,7 @@ def get_version() -> str:
5656
"keras<3.0",
5757
]
5858

59-
extras["hf_xet"] = ["hf_xet>=1.0.2,<2.0.0"]
59+
extras["hf_xet"] = ["hf_xet>=1.1.0,<2.0.0"]
6060

6161
extras["testing"] = (
6262
extras["cli"]

src/huggingface_hub/_commit_api.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -530,7 +530,7 @@ def _upload_xet_files(
530530
if len(additions) == 0:
531531
return
532532
# at this point, we know that hf_xet is installed
533-
from hf_xet import upload_files
533+
from hf_xet import upload_bytes, upload_files
534534

535535
try:
536536
xet_connection_info = fetch_xet_connection_info_from_repo_info(
@@ -571,8 +571,10 @@ def token_refresher() -> Tuple[str, int]:
571571
num_chunks_num_digits = int(math.log10(num_chunks)) + 1
572572
for i, chunk in enumerate(chunk_iterable(additions, chunk_size=UPLOAD_BATCH_MAX_NUM_FILES)):
573573
_chunk = [op for op in chunk]
574-
paths = [str(op.path_or_fileobj) for op in _chunk]
575-
expected_size = sum([os.path.getsize(path) for path in paths])
574+
575+
bytes_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, bytes)]
576+
paths_ops = [op for op in _chunk if isinstance(op.path_or_fileobj, (str, Path))]
577+
expected_size = sum(op.upload_info.size for op in bytes_ops + paths_ops)
576578

577579
if num_chunks > 1:
578580
description = f"Uploading Batch [{str(i + 1).zfill(num_chunks_num_digits)}/{num_chunks}]..."
@@ -592,7 +594,24 @@ def token_refresher() -> Tuple[str, int]:
592594
def update_progress(increment: int):
593595
progress.update(increment)
594596

595-
upload_files(paths, xet_endpoint, access_token_info, token_refresher, update_progress, repo_type)
597+
if len(paths_ops) > 0:
598+
upload_files(
599+
[str(op.path_or_fileobj) for op in paths_ops],
600+
xet_endpoint,
601+
access_token_info,
602+
token_refresher,
603+
update_progress,
604+
repo_type,
605+
)
606+
if len(bytes_ops) > 0:
607+
upload_bytes(
608+
[op.path_or_fileobj for op in bytes_ops],
609+
xet_endpoint,
610+
access_token_info,
611+
token_refresher,
612+
update_progress,
613+
repo_type,
614+
)
596615
return
597616

598617

src/huggingface_hub/file_download.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,7 +582,7 @@ def xet_get(
582582
583583
"""
584584
try:
585-
from hf_xet import PyPointerFile, download_files # type: ignore[no-redef]
585+
from hf_xet import PyXetDownloadInfo, download_files # type: ignore[no-redef]
586586
except ImportError:
587587
raise ValueError(
588588
"To use optimized download using Xet storage, you need to install the hf_xet package. "
@@ -597,8 +597,10 @@ def token_refresher() -> Tuple[str, int]:
597597
raise ValueError("Failed to refresh token using xet metadata.")
598598
return connection_info.access_token, connection_info.expiration_unix_epoch
599599

600-
pointer_files = [
601-
PyPointerFile(path=str(incomplete_path.absolute()), hash=xet_file_data.file_hash, filesize=expected_size)
600+
xet_download_info = [
601+
PyXetDownloadInfo(
602+
destination_path=str(incomplete_path.absolute()), hash=xet_file_data.file_hash, file_size=expected_size
603+
)
602604
]
603605

604606
if not displayed_filename:
@@ -623,7 +625,7 @@ def progress_updater(progress_bytes: float):
623625
progress.update(progress_bytes)
624626

625627
download_files(
626-
pointer_files,
628+
xet_download_info,
627629
endpoint=connection_info.endpoint,
628630
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
629631
token_refresher=token_refresher,

src/huggingface_hub/hf_api.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4475,18 +4475,17 @@ def preupload_lfs_files(
44754475
expand="xetEnabled",
44764476
token=token,
44774477
).xet_enabled
4478-
has_binary_data = any(
4479-
isinstance(addition.path_or_fileobj, (bytes, io.BufferedIOBase))
4480-
for addition in new_lfs_additions_to_upload
4478+
has_buffered_io_data = any(
4479+
isinstance(addition.path_or_fileobj, io.BufferedIOBase) for addition in new_lfs_additions_to_upload
44814480
)
4482-
if xet_enabled and not has_binary_data and is_xet_available():
4481+
if xet_enabled and not has_buffered_io_data and is_xet_available():
44834482
logger.info("Uploading files using Xet Storage..")
44844483
_upload_xet_files(**upload_kwargs, create_pr=create_pr) # type: ignore [arg-type]
44854484
else:
44864485
if xet_enabled and is_xet_available():
4487-
if has_binary_data:
4486+
if has_buffered_io_data:
44884487
logger.warning(
4489-
"Uploading files as bytes or binary IO objects is not supported by Xet Storage. "
4488+
"Uploading files as a binary IO buffer is not supported by Xet Storage. "
44904489
"Falling back to HTTP upload."
44914490
)
44924491
_upload_lfs_files(**upload_kwargs, num_threads=num_threads) # type: ignore [arg-type]

tests/test_xet_download.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import os
22
from contextlib import contextmanager
33
from pathlib import Path
4+
from typing import Tuple
45
from unittest.mock import DEFAULT, Mock, patch
56

67
from huggingface_hub import snapshot_download
@@ -293,3 +294,42 @@ def test_snapshot_download_cache_reuse(self, tmp_path):
293294

294295
# Verify xet_get was not called (files were cached)
295296
mock_xet_get.assert_not_called()
297+
298+
def test_download_backward_compatibility(self, tmp_path):
299+
"""Test that xet download works with the old pointer file protocol.
300+
301+
Until the next major version of hf-xet is released, we need to support the old
302+
pointer file based download to support old huggingface_hub versions.
303+
"""
304+
305+
file_path = os.path.join(tmp_path, DUMMY_XET_FILE)
306+
307+
file_metadata = get_hf_file_metadata(
308+
hf_hub_url(
309+
repo_id=DUMMY_XET_MODEL_ID,
310+
filename=DUMMY_XET_FILE,
311+
)
312+
)
313+
314+
xet_file_data = file_metadata.xet_file_data
315+
316+
# Mock the response to not include xet metadata
317+
from hf_xet import PyPointerFile, download_files
318+
319+
connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={})
320+
321+
def token_refresher() -> Tuple[str, int]:
322+
connection_info = refresh_xet_connection_info(file_data=xet_file_data, headers={})
323+
return connection_info.access_token, connection_info.expiration_unix_epoch
324+
325+
pointer_files = [PyPointerFile(path=file_path, hash=xet_file_data.file_hash, filesize=file_metadata.size)]
326+
327+
download_files(
328+
pointer_files,
329+
endpoint=connection_info.endpoint,
330+
token_info=(connection_info.access_token, connection_info.expiration_unix_epoch),
331+
token_refresher=token_refresher,
332+
progress_updater=None,
333+
)
334+
335+
assert os.path.exists(file_path)

tests/test_xet_upload.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,21 @@ def test_upload_file_with_bytesio(self, api, tmp_path, repo_url):
136136
downloaded_content = f.read()
137137
assert downloaded_content == self.bin_content
138138

139+
def test_upload_file_with_byte_array(self, api, tmp_path, repo_url):
140+
repo_id = repo_url.repo_id
141+
content = self.bin_content
142+
with assert_upload_mode("xet"):
143+
api.upload_file(
144+
path_or_fileobj=content,
145+
path_in_repo="bytearray_file.bin",
146+
repo_id=repo_id,
147+
)
148+
# Download and verify content
149+
downloaded_file = hf_hub_download(repo_id=repo_id, filename="bytearray_file.bin", cache_dir=tmp_path)
150+
with open(downloaded_file, "rb") as f:
151+
downloaded_content = f.read()
152+
assert downloaded_content == self.bin_content
153+
139154
def test_fallback_to_lfs_when_xet_not_available(self, api, repo_url):
140155
repo_id = repo_url.repo_id
141156
with patch("huggingface_hub.hf_api.is_xet_available", return_value=False):
@@ -284,7 +299,7 @@ def test_hf_xet_with_token_refresher(self, api, tmp_path, repo_url):
284299
285300
This test ensures that the downloaded file is the same as the uploaded file.
286301
"""
287-
from hf_xet import PyPointerFile, download_files
302+
from hf_xet import PyXetDownloadInfo, download_files
288303

289304
filename_in_repo = "binary_file.bin"
290305
repo_id = repo_url.repo_id
@@ -327,13 +342,15 @@ def token_refresher() -> Tuple[str, int]:
327342
mock_token_refresher = MagicMock(side_effect=token_refresher)
328343

329344
incomplete_path = Path(tmp_path) / "file.bin.incomplete"
330-
py_file = [
331-
PyPointerFile(path=str(incomplete_path.absolute()), hash=xet_filedata.file_hash, filesize=expected_size)
345+
file_info = [
346+
PyXetDownloadInfo(
347+
destination_path=str(incomplete_path.absolute()), hash=xet_filedata.file_hash, file_size=expected_size
348+
)
332349
]
333350

334351
# Call the download_files function with the token refresher, set expiration to 0 forcing a refresh
335352
download_files(
336-
py_file,
353+
file_info,
337354
endpoint=xet_connection_info.endpoint,
338355
token_info=(xet_connection_info.access_token, 0),
339356
token_refresher=mock_token_refresher,

0 commit comments

Comments
 (0)