|
6 | 6 | from collections import abc
|
7 | 7 | import dataclasses
|
8 | 8 | import gzip
|
| 9 | +import io |
9 | 10 | from io import (
|
10 | 11 | BufferedIOBase,
|
11 | 12 | BytesIO,
|
12 | 13 | RawIOBase,
|
13 | 14 | StringIO,
|
| 15 | + TextIOBase, |
14 | 16 | TextIOWrapper,
|
15 | 17 | )
|
16 | 18 | import mmap
|
|
50 | 52 |
|
51 | 53 | lzma = import_lzma()
|
52 | 54 |
|
53 |
| - |
54 | 55 | _VALID_URLS = set(uses_relative + uses_netloc + uses_params)
|
55 | 56 | _VALID_URLS.discard("")
|
56 | 57 |
|
@@ -102,7 +103,7 @@ def close(self) -> None:
|
102 | 103 | avoid closing the potentially user-created buffer.
|
103 | 104 | """
|
104 | 105 | if self.is_wrapped:
|
105 |
| - assert isinstance(self.handle, TextIOWrapper) |
| 106 | + assert isinstance(self.handle, (TextIOWrapper, BytesIOWrapper)) |
106 | 107 | self.handle.flush()
|
107 | 108 | self.handle.detach()
|
108 | 109 | self.created_handles.remove(self.handle)
|
@@ -712,7 +713,16 @@ def get_handle(
|
712 | 713 |
|
713 | 714 | # Convert BytesIO or file objects passed with an encoding
|
714 | 715 | is_wrapped = False
|
715 |
| - if is_text and (compression or _is_binary_mode(handle, ioargs.mode)): |
| 716 | + if not is_text and ioargs.mode == "rb" and isinstance(handle, TextIOBase): |
| 717 | + handle = BytesIOWrapper( |
| 718 | + handle, |
| 719 | + encoding=ioargs.encoding, |
| 720 | + ) |
| 721 | + handles.append(handle) |
| 722 | + # the (text) handle is always provided by the caller |
| 723 | + # since get_handle would have opened it in binary mode |
| 724 | + is_wrapped = True |
| 725 | + elif is_text and (compression or _is_binary_mode(handle, ioargs.mode)): |
716 | 726 | handle = TextIOWrapper(
|
717 | 727 | # error: Argument 1 to "TextIOWrapper" has incompatible type
|
718 | 728 | # "Union[IO[bytes], IO[Any], RawIOBase, BufferedIOBase, TextIOBase, mmap]";
|
@@ -878,6 +888,46 @@ def __next__(self) -> str:
|
878 | 888 | return newline.lstrip("\n")
|
879 | 889 |
|
880 | 890 |
|
| 891 | +# Wrapper that wraps a StringIO buffer and reads bytes from it |
| 892 | +# Created for compat with pyarrow read_csv |
| 893 | +class BytesIOWrapper(io.BytesIO): |
| 894 | + buffer: StringIO | TextIOBase | None |
| 895 | + |
| 896 | + def __init__(self, buffer: StringIO | TextIOBase, encoding: str = "utf-8"): |
| 897 | + self.buffer = buffer |
| 898 | + self.encoding = encoding |
| 899 | + # Because a character can be represented by more than 1 byte, |
| 900 | + # it is possible that reading will produce more bytes than n |
| 901 | + # We store the extra bytes in this overflow variable, and append the |
| 902 | + # overflow to the front of the bytestring the next time reading is performed |
| 903 | + self.overflow = b"" |
| 904 | + |
| 905 | + def __getattr__(self, attr: str): |
| 906 | + return getattr(self.buffer, attr) |
| 907 | + |
| 908 | + def read(self, n: int | None = -1) -> bytes: |
| 909 | + assert self.buffer is not None |
| 910 | + bytestring = self.buffer.read(n).encode(self.encoding) |
| 911 | + # When n=-1/n greater than remaining bytes: Read entire file/rest of file |
| 912 | + combined_bytestring = self.overflow + bytestring |
| 913 | + if n is None or n < 0 or n >= len(combined_bytestring): |
| 914 | + self.overflow = b"" |
| 915 | + return combined_bytestring |
| 916 | + else: |
| 917 | + to_return = combined_bytestring[:n] |
| 918 | + self.overflow = combined_bytestring[n:] |
| 919 | + return to_return |
| 920 | + |
| 921 | + def detach(self): |
| 922 | + # Slightly modified from Python's TextIOWrapper detach method |
| 923 | + if self.buffer is None: |
| 924 | + raise ValueError("buffer is already detached") |
| 925 | + self.flush() |
| 926 | + buffer = self.buffer |
| 927 | + self.buffer = None |
| 928 | + return buffer |
| 929 | + |
| 930 | + |
881 | 931 | def _maybe_memory_map(
|
882 | 932 | handle: FileOrBuffer,
|
883 | 933 | memory_map: bool,
|
|
0 commit comments