Skip to content

Commit 530f844

Browse files
twoertwein authored and yehoshuadimarsky committed
CLN: Share code between _BytesTarFile and _BytesZipFile (pandas-dev#47153)
* CLN: do not suppress errors when closing file handles
* more cleanups
* move to 1.4.3
* only cleanup
* add test case
1 parent 107250a commit 530f844

File tree

2 files changed

+101
-161
lines changed

2 files changed

+101
-161
lines changed

pandas/io/common.py

+95-161
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
"""Common IO api utilities"""
22
from __future__ import annotations
33

4+
from abc import (
5+
ABC,
6+
abstractmethod,
7+
)
48
import bz2
59
import codecs
610
from collections import abc
@@ -10,7 +14,6 @@
1014
from io import (
1115
BufferedIOBase,
1216
BytesIO,
13-
FileIO,
1417
RawIOBase,
1518
StringIO,
1619
TextIOBase,
@@ -712,8 +715,7 @@ def get_handle(
712715

713716
# GZ Compression
714717
if compression == "gzip":
715-
if is_path:
716-
assert isinstance(handle, str)
718+
if isinstance(handle, str):
717719
# error: Incompatible types in assignment (expression has type
718720
# "GzipFile", variable has type "Union[str, BaseBuffer]")
719721
handle = gzip.GzipFile( # type: ignore[assignment]
@@ -742,18 +744,18 @@ def get_handle(
742744

743745
# ZIP Compression
744746
elif compression == "zip":
745-
# error: Argument 1 to "_BytesZipFile" has incompatible type "Union[str,
746-
# BaseBuffer]"; expected "Union[Union[str, PathLike[str]],
747+
# error: Argument 1 to "_BytesZipFile" has incompatible type
748+
# "Union[str, BaseBuffer]"; expected "Union[Union[str, PathLike[str]],
747749
# ReadBuffer[bytes], WriteBuffer[bytes]]"
748750
handle = _BytesZipFile(
749751
handle, ioargs.mode, **compression_args # type: ignore[arg-type]
750752
)
751-
if handle.mode == "r":
753+
if handle.buffer.mode == "r":
752754
handles.append(handle)
753-
zip_names = handle.namelist()
755+
zip_names = handle.buffer.namelist()
754756
if len(zip_names) == 1:
755-
handle = handle.open(zip_names.pop())
756-
elif len(zip_names) == 0:
757+
handle = handle.buffer.open(zip_names.pop())
758+
elif not zip_names:
757759
raise ValueError(f"Zero files found in ZIP file {path_or_buf}")
758760
else:
759761
raise ValueError(
@@ -763,21 +765,25 @@ def get_handle(
763765

764766
# TAR Encoding
765767
elif compression == "tar":
766-
if "mode" not in compression_args:
767-
compression_args["mode"] = ioargs.mode
768-
if is_path:
769-
handle = _BytesTarFile.open(name=handle, **compression_args)
768+
compression_args.setdefault("mode", ioargs.mode)
769+
if isinstance(handle, str):
770+
handle = _BytesTarFile(name=handle, **compression_args)
770771
else:
771-
handle = _BytesTarFile.open(fileobj=handle, **compression_args)
772+
# error: Argument "fileobj" to "_BytesTarFile" has incompatible
773+
# type "BaseBuffer"; expected "Union[ReadBuffer[bytes],
774+
# WriteBuffer[bytes], None]"
775+
handle = _BytesTarFile(
776+
fileobj=handle, **compression_args # type: ignore[arg-type]
777+
)
772778
assert isinstance(handle, _BytesTarFile)
773-
if handle.mode == "r":
779+
if "r" in handle.buffer.mode:
774780
handles.append(handle)
775-
files = handle.getnames()
781+
files = handle.buffer.getnames()
776782
if len(files) == 1:
777-
file = handle.extractfile(files[0])
783+
file = handle.buffer.extractfile(files[0])
778784
assert file is not None
779785
handle = file
780-
elif len(files) == 0:
786+
elif not files:
781787
raise ValueError(f"Zero files found in TAR archive {path_or_buf}")
782788
else:
783789
raise ValueError(
@@ -876,195 +882,123 @@ def get_handle(
876882
)
877883

878884

879-
# error: Definition of "__exit__" in base class "TarFile" is incompatible with
880-
# definition in base class "BytesIO" [misc]
881-
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
882-
# definition in base class "BytesIO" [misc]
883-
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
884-
# definition in base class "BinaryIO" [misc]
885-
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
886-
# definition in base class "IO" [misc]
887-
# error: Definition of "read" in base class "TarFile" is incompatible with
888-
# definition in base class "BytesIO" [misc]
889-
# error: Definition of "read" in base class "TarFile" is incompatible with
890-
# definition in base class "IO" [misc]
891-
class _BytesTarFile(tarfile.TarFile, BytesIO): # type: ignore[misc]
885+
# error: Definition of "__enter__" in base class "IOBase" is incompatible
886+
# with definition in base class "BinaryIO"
887+
class _BufferedWriter(BytesIO, ABC): # type: ignore[misc]
892888
"""
893-
Wrapper for standard library class TarFile and allow the returned file-like
894-
handle to accept byte strings via `write` method.
895-
896-
BytesIO provides attributes of file-like object and TarFile.addfile writes
897-
bytes strings into a member of the archive.
889+
Some objects do not support multiple .write() calls (TarFile and ZipFile).
890+
This wrapper writes to the underlying buffer on close.
898891
"""
899892

900-
# GH 17778
893+
@abstractmethod
894+
def write_to_buffer(self):
895+
...
896+
897+
def close(self) -> None:
898+
if self.closed:
899+
# already closed
900+
return
901+
if self.getvalue():
902+
# write to buffer
903+
self.seek(0)
904+
# error: "_BufferedWriter" has no attribute "buffer"
905+
with self.buffer: # type: ignore[attr-defined]
906+
self.write_to_buffer()
907+
else:
908+
# error: "_BufferedWriter" has no attribute "buffer"
909+
self.buffer.close() # type: ignore[attr-defined]
910+
super().close()
911+
912+
913+
class _BytesTarFile(_BufferedWriter):
901914
def __init__(
902915
self,
903-
name: str | bytes | os.PathLike[str] | os.PathLike[bytes],
904-
mode: Literal["r", "a", "w", "x"],
905-
fileobj: FileIO,
916+
name: str | None = None,
917+
mode: Literal["r", "a", "w", "x"] = "r",
918+
fileobj: ReadBuffer[bytes] | WriteBuffer[bytes] | None = None,
906919
archive_name: str | None = None,
907920
**kwargs,
908-
):
921+
) -> None:
922+
super().__init__()
909923
self.archive_name = archive_name
910-
self.multiple_write_buffer: BytesIO | None = None
911-
self._closing = False
912-
913-
super().__init__(name=name, mode=mode, fileobj=fileobj, **kwargs)
924+
self.name = name
925+
# error: Argument "fileobj" to "open" of "TarFile" has incompatible
926+
# type "Union[ReadBuffer[bytes], WriteBuffer[bytes], None]"; expected
927+
# "Optional[IO[bytes]]"
928+
self.buffer = tarfile.TarFile.open(
929+
name=name,
930+
mode=self.extend_mode(mode),
931+
fileobj=fileobj, # type: ignore[arg-type]
932+
**kwargs,
933+
)
914934

915-
@classmethod
916-
def open(cls, name=None, mode="r", **kwargs):
935+
def extend_mode(self, mode: str) -> str:
917936
mode = mode.replace("b", "")
918-
return super().open(name=name, mode=cls.extend_mode(name, mode), **kwargs)
919-
920-
@classmethod
921-
def extend_mode(
922-
cls, name: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], mode: str
923-
) -> str:
924937
if mode != "w":
925938
return mode
926-
if isinstance(name, (os.PathLike, str)):
927-
filename = Path(name)
928-
if filename.suffix == ".gz":
929-
return mode + ":gz"
930-
elif filename.suffix == ".xz":
931-
return mode + ":xz"
932-
elif filename.suffix == ".bz2":
933-
return mode + ":bz2"
939+
if self.name is not None:
940+
suffix = Path(self.name).suffix
941+
if suffix in (".gz", ".xz", ".bz2"):
942+
mode = f"{mode}:{suffix[1:]}"
934943
return mode
935944

936-
def infer_filename(self):
945+
def infer_filename(self) -> str | None:
937946
"""
938947
If an explicit archive_name is not given, we still want the file inside the tar
939948
file not to be named something.tar, because that causes confusion (GH39465).
940949
"""
941-
if isinstance(self.name, (os.PathLike, str)):
942-
# error: Argument 1 to "Path" has
943-
# incompatible type "Union[str, PathLike[str], PathLike[bytes]]";
944-
# expected "Union[str, PathLike[str]]" [arg-type]
945-
filename = Path(self.name) # type: ignore[arg-type]
946-
if filename.suffix == ".tar":
947-
return filename.with_suffix("").name
948-
if filename.suffix in [".tar.gz", ".tar.bz2", ".tar.xz"]:
949-
return filename.with_suffix("").with_suffix("").name
950-
return filename.name
951-
return None
952-
953-
def write(self, data):
954-
# buffer multiple write calls, write on flush
955-
if self.multiple_write_buffer is None:
956-
self.multiple_write_buffer = BytesIO()
957-
self.multiple_write_buffer.write(data)
950+
if self.name is None:
951+
return None
958952

959-
def flush(self) -> None:
960-
# write to actual handle and close write buffer
961-
if self.multiple_write_buffer is None or self.multiple_write_buffer.closed:
962-
return
953+
filename = Path(self.name)
954+
if filename.suffix == ".tar":
955+
return filename.with_suffix("").name
956+
elif filename.suffix in (".tar.gz", ".tar.bz2", ".tar.xz"):
957+
return filename.with_suffix("").with_suffix("").name
958+
return filename.name
963959

960+
def write_to_buffer(self) -> None:
964961
# TarFile needs a non-empty string
965962
archive_name = self.archive_name or self.infer_filename() or "tar"
966-
with self.multiple_write_buffer:
967-
value = self.multiple_write_buffer.getvalue()
968-
tarinfo = tarfile.TarInfo(name=archive_name)
969-
tarinfo.size = len(value)
970-
self.addfile(tarinfo, BytesIO(value))
971-
972-
def close(self):
973-
self.flush()
974-
super().close()
975-
976-
@property
977-
def closed(self):
978-
if self.multiple_write_buffer is None:
979-
return False
980-
return self.multiple_write_buffer.closed and super().closed
981-
982-
@closed.setter
983-
def closed(self, value):
984-
if not self._closing and value:
985-
self._closing = True
986-
self.close()
987-
988-
989-
# error: Definition of "__exit__" in base class "ZipFile" is incompatible with
990-
# definition in base class "BytesIO" [misc]
991-
# error: Definition of "__enter__" in base class "ZipFile" is incompatible with
992-
# definition in base class "BytesIO" [misc]
993-
# error: Definition of "__enter__" in base class "ZipFile" is incompatible with
994-
# definition in base class "BinaryIO" [misc]
995-
# error: Definition of "__enter__" in base class "ZipFile" is incompatible with
996-
# definition in base class "IO" [misc]
997-
# error: Definition of "read" in base class "ZipFile" is incompatible with
998-
# definition in base class "BytesIO" [misc]
999-
# error: Definition of "read" in base class "ZipFile" is incompatible with
1000-
# definition in base class "IO" [misc]
1001-
class _BytesZipFile(zipfile.ZipFile, BytesIO): # type: ignore[misc]
1002-
"""
1003-
Wrapper for standard library class ZipFile and allow the returned file-like
1004-
handle to accept byte strings via `write` method.
963+
tarinfo = tarfile.TarInfo(name=archive_name)
964+
tarinfo.size = len(self.getvalue())
965+
self.buffer.addfile(tarinfo, self)
1005966

1006-
BytesIO provides attributes of file-like object and ZipFile.writestr writes
1007-
bytes strings into a member of the archive.
1008-
"""
1009967

1010-
# GH 17778
968+
class _BytesZipFile(_BufferedWriter):
1011969
def __init__(
1012970
self,
1013971
file: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes],
1014972
mode: str,
1015973
archive_name: str | None = None,
1016974
**kwargs,
1017975
) -> None:
976+
super().__init__()
1018977
mode = mode.replace("b", "")
1019978
self.archive_name = archive_name
1020-
self.multiple_write_buffer: StringIO | BytesIO | None = None
1021979

1022-
kwargs_zip: dict[str, Any] = {"compression": zipfile.ZIP_DEFLATED}
1023-
kwargs_zip.update(kwargs)
980+
kwargs.setdefault("compression", zipfile.ZIP_DEFLATED)
981+
# error: Argument 1 to "ZipFile" has incompatible type "Union[
982+
# Union[str, PathLike[str]], ReadBuffer[bytes], WriteBuffer[bytes]]";
983+
# expected "Union[Union[str, PathLike[str]], IO[bytes]]"
984+
self.buffer = zipfile.ZipFile(file, mode, **kwargs) # type: ignore[arg-type]
1024985

1025-
# error: Argument 1 to "__init__" of "ZipFile" has incompatible type
1026-
# "Union[_PathLike[str], Union[str, Union[IO[Any], RawIOBase, BufferedIOBase,
1027-
# TextIOBase, TextIOWrapper, mmap]]]"; expected "Union[Union[str,
1028-
# _PathLike[str]], IO[bytes]]"
1029-
super().__init__(file, mode, **kwargs_zip) # type: ignore[arg-type]
1030-
1031-
def infer_filename(self):
986+
def infer_filename(self) -> str | None:
1032987
"""
1033988
If an explicit archive_name is not given, we still want the file inside the zip
1034989
file not to be named something.zip, because that causes confusion (GH39465).
1035990
"""
1036-
if isinstance(self.filename, (os.PathLike, str)):
1037-
filename = Path(self.filename)
991+
if isinstance(self.buffer.filename, (os.PathLike, str)):
992+
filename = Path(self.buffer.filename)
1038993
if filename.suffix == ".zip":
1039994
return filename.with_suffix("").name
1040995
return filename.name
1041996
return None
1042997

1043-
def write(self, data):
1044-
# buffer multiple write calls, write on flush
1045-
if self.multiple_write_buffer is None:
1046-
self.multiple_write_buffer = (
1047-
BytesIO() if isinstance(data, bytes) else StringIO()
1048-
)
1049-
self.multiple_write_buffer.write(data)
1050-
1051-
def flush(self) -> None:
1052-
# write to actual handle and close write buffer
1053-
if self.multiple_write_buffer is None or self.multiple_write_buffer.closed:
1054-
return
1055-
998+
def write_to_buffer(self) -> None:
1056999
# ZipFile needs a non-empty string
10571000
archive_name = self.archive_name or self.infer_filename() or "zip"
1058-
with self.multiple_write_buffer:
1059-
super().writestr(archive_name, self.multiple_write_buffer.getvalue())
1060-
1061-
def close(self):
1062-
self.flush()
1063-
super().close()
1064-
1065-
@property
1066-
def closed(self):
1067-
return self.fp is None
1001+
self.buffer.writestr(archive_name, self.getvalue())
10681002

10691003

10701004
class _CSVMMapWrapper(abc.Iterator):

pandas/tests/io/test_compression.py

+6
Original file line numberDiff line numberDiff line change
@@ -334,3 +334,9 @@ def test_tar_gz_to_different_filename():
334334
expected = "foo,bar\n1,2\n"
335335

336336
assert content == expected
337+
338+
339+
def test_tar_no_error_on_close():
340+
with io.BytesIO() as buffer:
341+
with icom._BytesTarFile(fileobj=buffer, mode="w"):
342+
pass

0 commit comments

Comments
 (0)