diff --git a/pandas/io/common.py b/pandas/io/common.py index fdee1600c2a32..33f83d7c66433 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -1,6 +1,10 @@ """Common IO api utilities""" from __future__ import annotations +from abc import ( + ABC, + abstractmethod, +) import bz2 import codecs from collections import abc @@ -10,7 +14,6 @@ from io import ( BufferedIOBase, BytesIO, - FileIO, RawIOBase, StringIO, TextIOBase, @@ -712,8 +715,7 @@ def get_handle( # GZ Compression if compression == "gzip": - if is_path: - assert isinstance(handle, str) + if isinstance(handle, str): # error: Incompatible types in assignment (expression has type # "GzipFile", variable has type "Union[str, BaseBuffer]") handle = gzip.GzipFile( # type: ignore[assignment] @@ -742,18 +744,18 @@ def get_handle( # ZIP Compression elif compression == "zip": - # error: Argument 1 to "_BytesZipFile" has incompatible type "Union[str, - # BaseBuffer]"; expected "Union[Union[str, PathLike[str]], + # error: Argument 1 to "_BytesZipFile" has incompatible type + # "Union[str, BaseBuffer]"; expected "Union[Union[str, PathLike[str]], # ReadBuffer[bytes], WriteBuffer[bytes]]" handle = _BytesZipFile( handle, ioargs.mode, **compression_args # type: ignore[arg-type] ) - if handle.mode == "r": + if handle.buffer.mode == "r": handles.append(handle) - zip_names = handle.namelist() + zip_names = handle.buffer.namelist() if len(zip_names) == 1: - handle = handle.open(zip_names.pop()) - elif len(zip_names) == 0: + handle = handle.buffer.open(zip_names.pop()) + elif not zip_names: raise ValueError(f"Zero files found in ZIP file {path_or_buf}") else: raise ValueError( @@ -763,21 +765,25 @@ def get_handle( # TAR Encoding elif compression == "tar": - if "mode" not in compression_args: - compression_args["mode"] = ioargs.mode - if is_path: - handle = _BytesTarFile.open(name=handle, **compression_args) + compression_args.setdefault("mode", ioargs.mode) + if isinstance(handle, str): + handle = _BytesTarFile(name=handle, **compression_args) else: - handle = _BytesTarFile.open(fileobj=handle, **compression_args) + # error: Argument "fileobj" to "_BytesTarFile" has incompatible + # type "BaseBuffer"; expected "Union[ReadBuffer[bytes], + # WriteBuffer[bytes], None]" + handle = _BytesTarFile( + fileobj=handle, **compression_args # type: ignore[arg-type] + ) assert isinstance(handle, _BytesTarFile) - if handle.mode == "r": + if "r" in handle.buffer.mode: handles.append(handle) - files = handle.getnames() + files = handle.buffer.getnames() if len(files) == 1: - file = handle.extractfile(files[0]) + file = handle.buffer.extractfile(files[0]) assert file is not None handle = file - elif len(files) == 0: + elif not files: raise ValueError(f"Zero files found in TAR archive {path_or_buf}") else: raise ValueError( @@ -876,138 +882,90 @@ def get_handle( ) -# error: Definition of "__exit__" in base class "TarFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "__enter__" in base class "TarFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "__enter__" in base class "TarFile" is incompatible with -# definition in base class "BinaryIO" [misc] -# error: Definition of "__enter__" in base class "TarFile" is incompatible with -# definition in base class "IO" [misc] -# error: Definition of "read" in base class "TarFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "read" in base class "TarFile" is incompatible with -# definition in base class "IO" [misc] -class _BytesTarFile(tarfile.TarFile, BytesIO): # type: ignore[misc] +# error: Definition of "__enter__" in base class "IOBase" is incompatible +# with definition in base class "BinaryIO" +class _BufferedWriter(BytesIO, ABC): # type: ignore[misc] """ - Wrapper for standard library class TarFile and allow the returned file-like - handle to accept byte strings via `write` method. - - BytesIO provides attributes of file-like object and TarFile.addfile writes - bytes strings into a member of the archive. + Some objects do not support multiple .write() calls (TarFile and ZipFile). + This wrapper writes to the underlying buffer on close. """ - # GH 17778 + @abstractmethod + def write_to_buffer(self): + ... + + def close(self) -> None: + if self.closed: + # already closed + return + if self.getvalue(): + # write to buffer + self.seek(0) + # error: "_BufferedWriter" has no attribute "buffer" + with self.buffer: # type: ignore[attr-defined] + self.write_to_buffer() + else: + # error: "_BufferedWriter" has no attribute "buffer" + self.buffer.close() # type: ignore[attr-defined] + super().close() + + +class _BytesTarFile(_BufferedWriter): def __init__( self, - name: str | bytes | os.PathLike[str] | os.PathLike[bytes], - mode: Literal["r", "a", "w", "x"], - fileobj: FileIO, + name: str | None = None, + mode: Literal["r", "a", "w", "x"] = "r", + fileobj: ReadBuffer[bytes] | WriteBuffer[bytes] | None = None, archive_name: str | None = None, **kwargs, - ): + ) -> None: + super().__init__() self.archive_name = archive_name - self.multiple_write_buffer: BytesIO | None = None - self._closing = False - - super().__init__(name=name, mode=mode, fileobj=fileobj, **kwargs) + self.name = name + # error: Argument "fileobj" to "open" of "TarFile" has incompatible + # type "Union[ReadBuffer[bytes], WriteBuffer[bytes], None]"; expected + # "Optional[IO[bytes]]" + self.buffer = tarfile.TarFile.open( + name=name, + mode=self.extend_mode(mode), + fileobj=fileobj, # type: ignore[arg-type] + **kwargs, + ) - @classmethod - def open(cls, name=None, mode="r", **kwargs): + def extend_mode(self, mode: str) -> str: mode = mode.replace("b", "") - return super().open(name=name, mode=cls.extend_mode(name, mode), **kwargs) - - @classmethod - def extend_mode( - cls, name: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], mode: str - ) -> str: if mode != "w": return mode - if isinstance(name, (os.PathLike, str)): - filename = Path(name) - if filename.suffix == ".gz": - return mode + ":gz" - elif filename.suffix == ".xz": - return mode + ":xz" - elif filename.suffix == ".bz2": - return mode + ":bz2" + if self.name is not None: + suffix = Path(self.name).suffix + if suffix in (".gz", ".xz", ".bz2"): + mode = f"{mode}:{suffix[1:]}" return mode - def infer_filename(self): + def infer_filename(self) -> str | None: """ If an explicit archive_name is not given, we still want the file inside the zip file not to be named something.tar, because that causes confusion (GH39465). """ - if isinstance(self.name, (os.PathLike, str)): - # error: Argument 1 to "Path" has - # incompatible type "Union[str, PathLike[str], PathLike[bytes]]"; - # expected "Union[str, PathLike[str]]" [arg-type] - filename = Path(self.name) # type: ignore[arg-type] - if filename.suffix == ".tar": - return filename.with_suffix("").name - if filename.suffix in [".tar.gz", ".tar.bz2", ".tar.xz"]: - return filename.with_suffix("").with_suffix("").name - return filename.name - return None - - def write(self, data): - # buffer multiple write calls, write on flush - if self.multiple_write_buffer is None: - self.multiple_write_buffer = BytesIO() - self.multiple_write_buffer.write(data) + if self.name is None: + return None - def flush(self) -> None: - # write to actual handle and close write buffer - if self.multiple_write_buffer is None or self.multiple_write_buffer.closed: - return + filename = Path(self.name) + if filename.suffix == ".tar": + return filename.with_suffix("").name + elif filename.suffix in (".tar.gz", ".tar.bz2", ".tar.xz"): + return filename.with_suffix("").with_suffix("").name + return filename.name + def write_to_buffer(self) -> None: # TarFile needs a non-empty string archive_name = self.archive_name or self.infer_filename() or "tar" - with self.multiple_write_buffer: - value = self.multiple_write_buffer.getvalue() - tarinfo = tarfile.TarInfo(name=archive_name) - tarinfo.size = len(value) - self.addfile(tarinfo, BytesIO(value)) - - def close(self): - self.flush() - super().close() - - @property - def closed(self): - if self.multiple_write_buffer is None: - return False - return self.multiple_write_buffer.closed and super().closed - - @closed.setter - def closed(self, value): - if not self._closing and value: - self._closing = True - self.close() - - -# error: Definition of "__exit__" in base class "ZipFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "__enter__" in base class "ZipFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "__enter__" in base class "ZipFile" is incompatible with -# definition in base class "BinaryIO" [misc] -# error: Definition of "__enter__" in base class "ZipFile" is incompatible with -# definition in base class "IO" [misc] -# error: Definition of "read" in base class "ZipFile" is incompatible with -# definition in base class "BytesIO" [misc] -# error: Definition of "read" in base class "ZipFile" is incompatible with -# definition in base class "IO" [misc] -class _BytesZipFile(zipfile.ZipFile, BytesIO): # type: ignore[misc] - """ - Wrapper for standard library class ZipFile and allow the returned file-like - handle to accept byte strings via `write` method. + tarinfo = tarfile.TarInfo(name=archive_name) + tarinfo.size = len(self.getvalue()) + self.buffer.addfile(tarinfo, self) - BytesIO provides attributes of file-like object and ZipFile.writestr writes - bytes strings into a member of the archive. - """ - # GH 17778 +class _BytesZipFile(_BufferedWriter): def __init__( self, file: FilePath | ReadBuffer[bytes] | WriteBuffer[bytes], @@ -1015,56 +973,32 @@ def __init__( archive_name: str | None = None, **kwargs, ) -> None: + super().__init__() mode = mode.replace("b", "") self.archive_name = archive_name - self.multiple_write_buffer: StringIO | BytesIO | None = None - kwargs_zip: dict[str, Any] = {"compression": zipfile.ZIP_DEFLATED} - kwargs_zip.update(kwargs) + kwargs.setdefault("compression", zipfile.ZIP_DEFLATED) + # error: Argument 1 to "ZipFile" has incompatible type "Union[ + # Union[str, PathLike[str]], ReadBuffer[bytes], WriteBuffer[bytes]]"; + # expected "Union[Union[str, PathLike[str]], IO[bytes]]" + self.buffer = zipfile.ZipFile(file, mode, **kwargs) # type: ignore[arg-type] - # error: Argument 1 to "__init__" of "ZipFile" has incompatible type - # "Union[_PathLike[str], Union[str, Union[IO[Any], RawIOBase, BufferedIOBase, - # TextIOBase, TextIOWrapper, mmap]]]"; expected "Union[Union[str, - # _PathLike[str]], IO[bytes]]" - super().__init__(file, mode, **kwargs_zip) # type: ignore[arg-type] - - def infer_filename(self): + def infer_filename(self) -> str | None: """ If an explicit archive_name is not given, we still want the file inside the zip file not to be named something.zip, because that causes confusion (GH39465). """ - if isinstance(self.filename, (os.PathLike, str)): - filename = Path(self.filename) + if isinstance(self.buffer.filename, (os.PathLike, str)): + filename = Path(self.buffer.filename) if filename.suffix == ".zip": return filename.with_suffix("").name return filename.name return None - def write(self, data): - # buffer multiple write calls, write on flush - if self.multiple_write_buffer is None: - self.multiple_write_buffer = ( - BytesIO() if isinstance(data, bytes) else StringIO() - ) - self.multiple_write_buffer.write(data) - - def flush(self) -> None: - # write to actual handle and close write buffer - if self.multiple_write_buffer is None or self.multiple_write_buffer.closed: - return - + def write_to_buffer(self) -> None: # ZipFile needs a non-empty string archive_name = self.archive_name or self.infer_filename() or "zip" - with self.multiple_write_buffer: - super().writestr(archive_name, self.multiple_write_buffer.getvalue()) - - def close(self): - self.flush() - super().close() - - @property - def closed(self): - return self.fp is None + self.buffer.writestr(archive_name, self.getvalue()) class _CSVMMapWrapper(abc.Iterator): diff --git a/pandas/tests/io/test_compression.py b/pandas/tests/io/test_compression.py index 98e136a9c4ba6..125d078ff39b1 100644 --- a/pandas/tests/io/test_compression.py +++ b/pandas/tests/io/test_compression.py @@ -334,3 +334,9 @@ def test_tar_gz_to_different_filename(): expected = "foo,bar\n1,2\n" assert content == expected + + +def test_tar_no_error_on_close(): + with io.BytesIO() as buffer: + with icom._BytesTarFile(fileobj=buffer, mode="w"): + pass