Skip to content

CLN: Share code between _BytesTarFile and _BytesZipFile #47153

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 30, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 95 additions & 161 deletions pandas/io/common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
"""Common IO api utilities"""
from __future__ import annotations

from abc import (
ABC,
abstractmethod,
)
import bz2
import codecs
from collections import abc
Expand All @@ -10,7 +14,6 @@
from io import (
BufferedIOBase,
BytesIO,
FileIO,
RawIOBase,
StringIO,
TextIOBase,
Expand Down Expand Up @@ -712,8 +715,7 @@ def get_handle(

# GZ Compression
if compression == "gzip":
if is_path:
assert isinstance(handle, str)
if isinstance(handle, str):
# error: Incompatible types in assignment (expression has type
# "GzipFile", variable has type "Union[str, BaseBuffer]")
handle = gzip.GzipFile( # type: ignore[assignment]
Expand Down Expand Up @@ -742,18 +744,18 @@ def get_handle(

# ZIP Compression
elif compression == "zip":
# error: Argument 1 to "_BytesZipFile" has incompatible type "Union[str,
# BaseBuffer]"; expected "Union[Union[str, PathLike[str]],
# error: Argument 1 to "_BytesZipFile" has incompatible type
# "Union[str, BaseBuffer]"; expected "Union[Union[str, PathLike[str]],
# ReadBuffer[bytes], WriteBuffer[bytes]]"
handle = _BytesZipFile(
handle, ioargs.mode, **compression_args # type: ignore[arg-type]
)
if handle.mode == "r":
if handle.buffer.mode == "r":
handles.append(handle)
zip_names = handle.namelist()
zip_names = handle.buffer.namelist()
if len(zip_names) == 1:
handle = handle.open(zip_names.pop())
elif len(zip_names) == 0:
handle = handle.buffer.open(zip_names.pop())
elif not zip_names:
raise ValueError(f"Zero files found in ZIP file {path_or_buf}")
else:
raise ValueError(
Expand All @@ -763,21 +765,25 @@ def get_handle(

# TAR Encoding
elif compression == "tar":
if "mode" not in compression_args:
compression_args["mode"] = ioargs.mode
if is_path:
handle = _BytesTarFile.open(name=handle, **compression_args)
compression_args.setdefault("mode", ioargs.mode)
if isinstance(handle, str):
handle = _BytesTarFile(name=handle, **compression_args)
else:
handle = _BytesTarFile.open(fileobj=handle, **compression_args)
# error: Argument "fileobj" to "_BytesTarFile" has incompatible
# type "BaseBuffer"; expected "Union[ReadBuffer[bytes],
# WriteBuffer[bytes], None]"
handle = _BytesTarFile(
fileobj=handle, **compression_args # type: ignore[arg-type]
)
assert isinstance(handle, _BytesTarFile)
if handle.mode == "r":
if "r" in handle.buffer.mode:
handles.append(handle)
files = handle.getnames()
files = handle.buffer.getnames()
if len(files) == 1:
file = handle.extractfile(files[0])
file = handle.buffer.extractfile(files[0])
assert file is not None
handle = file
elif len(files) == 0:
elif not files:
raise ValueError(f"Zero files found in TAR archive {path_or_buf}")
else:
raise ValueError(
Expand Down Expand Up @@ -876,195 +882,123 @@ def get_handle(
)


# error: Definition of "__exit__" in base class "TarFile" is incompatible with
# definition in base class "BytesIO" [misc]
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
# definition in base class "BytesIO" [misc]
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
# definition in base class "BinaryIO" [misc]
# error: Definition of "__enter__" in base class "TarFile" is incompatible with
# definition in base class "IO" [misc]
# error: Definition of "read" in base class "TarFile" is incompatible with
# definition in base class "BytesIO" [misc]
# error: Definition of "read" in base class "TarFile" is incompatible with
# definition in base class "IO" [misc]
class _BytesTarFile(tarfile.TarFile, BytesIO): # type: ignore[misc]
# error: Definition of "__enter__" in base class "IOBase" is incompatible
# with definition in base class "BinaryIO"
class _BufferedWriter(BytesIO, ABC):  # type: ignore[misc]
    """
    Some objects do not support multiple .write() calls (TarFile and ZipFile).
    This wrapper writes to the underlying buffer on close.

    Everything written to this object accumulates in the in-memory
    ``BytesIO``. Subclasses must assign the wrapped archive object to
    ``self.buffer`` and implement ``write_to_buffer`` to copy the accumulated
    bytes into it.
    """

    @abstractmethod
    def write_to_buffer(self) -> None:
        """Copy the bytes accumulated in ``self`` into ``self.buffer``."""

    def close(self) -> None:
        """
        Flush the accumulated bytes to ``self.buffer`` and close both objects.

        Calling ``close`` on an already closed writer is a no-op. If flushing
        raises, the exception propagates but this writer is still closed.
        """
        if self.closed:
            # already closed
            return
        try:
            if self.getvalue():
                # write to buffer; rewind first so implementations that read
                # from ``self`` start at the first byte
                self.seek(0)
                # error: "_BufferedWriter" has no attribute "buffer"
                with self.buffer:  # type: ignore[attr-defined]
                    self.write_to_buffer()
            else:
                # nothing was written: only release the underlying buffer
                # error: "_BufferedWriter" has no attribute "buffer"
                self.buffer.close()  # type: ignore[attr-defined]
        finally:
            # always close the in-memory side, even when flushing failed, so
            # a later close() does not try to flush into a closed buffer
            super().close()


class _BytesTarFile(_BufferedWriter):
    """
    Write-buffering wrapper around ``tarfile.TarFile``.

    Bytes written to this object are collected in memory and stored as a
    single archive member when the object is closed. The opened ``TarFile``
    is available as ``self.buffer``.
    """

    def __init__(
        self,
        name: str | None = None,
        mode: Literal["r", "a", "w", "x"] = "r",
        fileobj: ReadBuffer[bytes] | WriteBuffer[bytes] | None = None,
        archive_name: str | None = None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        name : str, optional
            Path of the archive, forwarded to ``tarfile.TarFile.open``.
        mode : {"r", "a", "w", "x"}, default "r"
            Mode used to open the archive; see ``extend_mode``.
        fileobj : file-like, optional
            Already opened binary buffer, forwarded to
            ``tarfile.TarFile.open``.
        archive_name : str, optional
            Name of the member created inside the archive on close.
        **kwargs
            Forwarded to ``tarfile.TarFile.open``.
        """
        super().__init__()
        self.archive_name = archive_name
        # must be set before extend_mode() is called, which reads it
        self.name = name
        # error: Argument "fileobj" to "open" of "TarFile" has incompatible
        # type "Union[ReadBuffer[bytes], WriteBuffer[bytes], None]"; expected
        # "Optional[IO[bytes]]"
        self.buffer = tarfile.TarFile.open(
            name=name,
            mode=self.extend_mode(mode),
            fileobj=fileobj,  # type: ignore[arg-type]
            **kwargs,
        )

    def extend_mode(self, mode: str) -> str:
        """
        Return the mode string to pass to ``tarfile.TarFile.open``.

        Any "b" is removed. For plain "w" with a known file name ending in
        ".gz", ".xz" or ".bz2", the matching compression is appended
        (e.g. "w:gz"); every other mode is returned unchanged.
        """
        mode = mode.replace("b", "")
        if mode != "w":
            return mode
        if self.name is not None:
            suffix = Path(self.name).suffix
            if suffix in (".gz", ".xz", ".bz2"):
                mode = f"{mode}:{suffix[1:]}"
        return mode

    def infer_filename(self) -> str | None:
        """
        If an explicit archive_name is not given, we still want the file inside the
        tar archive not to be named something.tar, because that causes confusion
        (GH39465).

        Returns
        -------
        str or None
            File name with a trailing ".tar", ".tar.gz", ".tar.bz2" or
            ".tar.xz" removed, or None when no file name is known.
        """
        if self.name is None:
            return None

        filename = Path(self.name)
        # ``Path.suffix`` only holds the last extension (".gz" for
        # "x.tar.gz"), so compressed archives are detected via ``suffixes``
        if filename.suffixes[-2:] in (
            [".tar", ".gz"],
            [".tar", ".bz2"],
            [".tar", ".xz"],
        ):
            return filename.with_suffix("").with_suffix("").name
        if filename.suffix == ".tar":
            return filename.with_suffix("").name
        return filename.name

    def write_to_buffer(self) -> None:
        """Store the accumulated bytes as one member of the tar archive."""
        # TarFile needs a non-empty string
        archive_name = self.archive_name or self.infer_filename() or "tar"
        tarinfo = tarfile.TarInfo(name=archive_name)
        tarinfo.size = len(self.getvalue())
        # addfile reads ``tarinfo.size`` bytes from ``self``
        self.buffer.addfile(tarinfo, self)

BytesIO provides attributes of file-like object and ZipFile.writestr writes
bytes strings into a member of the archive.
"""

# GH 17778
class _BytesZipFile(_BufferedWriter):
    """
    Write-buffering wrapper around ``zipfile.ZipFile``.

    Bytes written to this object are collected in memory and stored as a
    single archive member when the object is closed. The opened ``ZipFile``
    is available as ``self.buffer``.
    """

    def __init__(
        self,
        file: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes],
        mode: str,
        archive_name: str | None = None,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        file : path or file-like
            Forwarded to ``zipfile.ZipFile``.
        mode : str
            Mode used to open the archive; any "b" is removed first.
        archive_name : str, optional
            Name of the member created inside the archive on close.
        **kwargs
            Forwarded to ``zipfile.ZipFile``; ``compression`` defaults to
            ``zipfile.ZIP_DEFLATED`` when not given.
        """
        super().__init__()
        mode = mode.replace("b", "")
        self.archive_name = archive_name

        kwargs.setdefault("compression", zipfile.ZIP_DEFLATED)
        # error: Argument 1 to "ZipFile" has incompatible type "Union[
        # Union[str, PathLike[str]], ReadBuffer[bytes], WriteBuffer[bytes]]";
        # expected "Union[Union[str, PathLike[str]], IO[bytes]]"
        self.buffer = zipfile.ZipFile(file, mode, **kwargs)  # type: ignore[arg-type]

    def infer_filename(self) -> str | None:
        """
        If an explicit archive_name is not given, we still want the file inside the zip
        file not to be named something.zip, because that causes confusion (GH39465).

        Returns
        -------
        str or None
            File name of the archive with a trailing ".zip" removed, or None
            when the archive has no path-like file name.
        """
        if isinstance(self.buffer.filename, (os.PathLike, str)):
            filename = Path(self.buffer.filename)
            if filename.suffix == ".zip":
                return filename.with_suffix("").name
            return filename.name
        return None

    def write_to_buffer(self) -> None:
        """Store the accumulated bytes as one member of the zip archive."""
        # ZipFile needs a non-empty string
        archive_name = self.archive_name or self.infer_filename() or "zip"
        self.buffer.writestr(archive_name, self.getvalue())


class _CSVMMapWrapper(abc.Iterator):
Expand Down
6 changes: 6 additions & 0 deletions pandas/tests/io/test_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,9 @@ def test_tar_gz_to_different_filename():
expected = "foo,bar\n1,2\n"

assert content == expected


def test_tar_no_error_on_close():
    """Closing a tar writer that never received any data must not raise."""
    with io.BytesIO() as raw, icom._BytesTarFile(fileobj=raw, mode="w"):
        pass