# Wrapper that wraps a StringIO buffer and reads bytes from it
# Created for compat with pyarrow read_csv
class BytesIOWrapper(io.BytesIO):
    """
    Read-only binary view over a text buffer.

    Characters read from the wrapped text buffer are encoded with
    ``encoding`` and handed out as bytes, never more than the number of
    bytes requested per ``read`` call.

    Parameters
    ----------
    buffer : StringIO or TextIOBase
        Text buffer to read from. It is owned by the caller and is never
        closed by this wrapper; use ``detach`` to separate the two.
    encoding : str, default "utf-8"
        Encoding applied to the text that is read from ``buffer``.

    Notes
    -----
    Attributes that ``io.BytesIO`` does not define are forwarded to the
    wrapped buffer through ``__getattr__``.
    """

    buffer: StringIO | TextIOBase | None

    def __init__(self, buffer: StringIO | TextIOBase, encoding: str = "utf-8"):
        super().__init__()
        self.buffer = buffer
        self.encoding = encoding
        # Because a character can be represented by more than 1 byte,
        # it is possible that reading will produce more bytes than n
        # We store the extra bytes in this overflow variable, and append the
        # overflow to the front of the bytestring the next time reading is performed
        self.overflow = b""

    def __getattr__(self, attr: str):
        # __getattr__ only runs for names that normal lookup failed to find.
        # If "buffer" itself is missing (instance created without __init__),
        # delegating would recurse forever, so fail fast instead.
        if attr == "buffer":
            raise AttributeError(attr)
        return getattr(self.buffer, attr)

    def read(self, n: int | None = -1) -> bytes:
        """
        Read and return at most ``n`` encoded bytes.

        Parameters
        ----------
        n : int or None, default -1
            Maximum number of bytes to return. ``None`` or a negative value
            returns everything that is left.

        Returns
        -------
        bytes
            Fewer than ``n`` bytes are returned only once the wrapped buffer
            is exhausted.
        """
        assert self.buffer is not None
        if n is None or n < 0:
            # Read entire file/rest of file
            remainder = self.overflow + self.buffer.read(n).encode(self.encoding)
            self.overflow = b""
            return remainder

        combined_bytestring = self.overflow
        missing = n - len(combined_bytestring)
        if missing > 0:
            # Each character encodes to one or more bytes, so asking for
            # ``missing`` characters is enough to fill the request unless the
            # buffer runs out. Reading only what is needed keeps the leftover
            # bytes from piling up across successive calls.
            combined_bytestring += self.buffer.read(missing).encode(self.encoding)
        self.overflow = combined_bytestring[n:]
        return combined_bytestring[:n]

    def detach(self):
        """
        Separate the wrapped text buffer from this wrapper and return it.

        Raises
        ------
        ValueError
            If the buffer has already been detached.
        """
        # Slightly modified from Python's TextIOWrapper detach method
        if self.buffer is None:
            raise ValueError("buffer is already detached")
        self.flush()
        buffer = self.buffer
        self.buffer = None
        return buffer
# pyarrow's CSV reader must be able to consume a handle created by get_handle
@td.skip_if_no("pyarrow", min_version="0.15.0")
def test_get_handle_pyarrow_compat(self):
    from pyarrow import csv

    # Sample mixes latin1, ucs-2, and ucs-4 characters
    data = """a,b,c
1,2,3
©,®,®
Look,a snake,🐍"""
    text_buffer = StringIO(data)
    expected = pd.DataFrame(
        {
            "a": ["1", "©", "Look"],
            "b": ["2", "®", "a snake"],
            "c": ["3", "®", "🐍"],
        }
    )
    with icom.get_handle(text_buffer, "rb", is_text=False) as handles:
        table = csv.read_csv(handles.handle)
        df = table.to_pandas()
        tm.assert_frame_equal(df, expected)
        # The caller-provided text buffer must be left open
        assert not text_buffer.closed