TYP: make IOHandles generic #43855

Merged · 6 commits · Nov 14, 2021
3 changes: 1 addition & 2 deletions pandas/_typing.py
@@ -7,7 +7,6 @@
BufferedIOBase,
RawIOBase,
TextIOBase,
TextIOWrapper,
)
from mmap import mmap
from os import PathLike
@@ -170,7 +169,7 @@
PythonFuncType = Callable[[Any], Any]

# filenames and file-like-objects
Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, TextIOWrapper, mmap]
Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, mmap]
FileOrBuffer = Union[str, Buffer[AnyStr]]
FilePathOrBuffer = Union["PathLike[str]", FileOrBuffer[AnyStr]]

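The _typing.py change drops TextIOWrapper from the Buffer union. This does not narrow what callers may pass, because TextIOWrapper is a subclass of TextIOBase and therefore still matches the remaining members. A minimal standalone check (not part of the PR):

import io

# TextIOWrapper subclasses TextIOBase, so a wrapped text handle still
# satisfies the narrowed Buffer union.
wrapped = io.TextIOWrapper(io.BytesIO(b"a,b\n1,2\n"), encoding="utf-8")
print(issubclass(io.TextIOWrapper, io.TextIOBase))  # True
print(isinstance(wrapped, io.TextIOBase))            # True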
42 changes: 38 additions & 4 deletions pandas/io/common.py
@@ -23,8 +23,11 @@
IO,
Any,
AnyStr,
Generic,
Literal,
Mapping,
cast,
overload,
)
from urllib.parse import (
urljoin,
@@ -78,7 +81,7 @@ class IOArgs:


@dataclasses.dataclass
class IOHandles:
class IOHandles(Generic[AnyStr]):
"""
Return value of io/common.py:get_handle

@@ -92,7 +95,7 @@ class IOHandles:
is_wrapped: Whether a TextIOWrapper needs to be detached.
"""

handle: Buffer
handle: Buffer[AnyStr]
compression: CompressionDict
created_handles: list[Buffer] = dataclasses.field(default_factory=list)
is_wrapped: bool = False
@@ -118,7 +121,7 @@ def close(self) -> None:
self.created_handles = []
self.is_wrapped = False

def __enter__(self) -> IOHandles:
def __enter__(self) -> IOHandles[AnyStr]:
return self

def __exit__(self, *args: Any) -> None:
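
For context, the change above turns the IOHandles dataclass into a generic over AnyStr, so the str/bytes flavour of the wrapped handle is carried in the type parameter and returned unchanged from __enter__. The sketch below shows the same idiom on a toy class; the names are illustrative, not pandas code:

from __future__ import annotations

import io
from dataclasses import dataclass, field
from typing import IO, AnyStr, Generic


@dataclass
class Handles(Generic[AnyStr]):
    """Toy stand-in for IOHandles: generic over str/bytes content."""

    handle: IO[AnyStr]
    created_handles: list[IO[AnyStr]] = field(default_factory=list)

    def __enter__(self) -> Handles[AnyStr]:
        return self

    def __exit__(self, *args: object) -> None:
        for extra in self.created_handles:
            extra.close()


with Handles(io.StringIO("abc")) as h:  # checkers infer Handles[str]
    print(h.handle.read())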
@@ -533,16 +536,47 @@ def check_parent_directory(path: Path | str) -> None:
raise OSError(fr"Cannot save file into a non-existent directory: '{parent}'")


@overload
def get_handle(
path_or_buf: FilePathOrBuffer,
mode: str,
*,
encoding: str | None = ...,
compression: CompressionOptions = ...,
memory_map: bool = ...,
is_text: Literal[False],
errors: str | None = ...,
storage_options: StorageOptions = ...,
) -> IOHandles[bytes]:
...


@overload
def get_handle(
path_or_buf: FilePathOrBuffer,
mode: str,
*,
encoding: str | None = ...,
compression: CompressionOptions = ...,
memory_map: bool = ...,
is_text: Literal[True] = True,
errors: str | None = ...,
storage_options: StorageOptions = ...,
) -> IOHandles[str]:
...


def get_handle(
path_or_buf: FilePathOrBuffer,
mode: str,
*,
encoding: str | None = None,
compression: CompressionOptions = None,
memory_map: bool = False,
is_text: bool = True,
errors: str | None = None,
storage_options: StorageOptions = None,
) -> IOHandles:
) -> IOHandles[str] | IOHandles[bytes]:
"""
Get file handle for given path/buffer and mode.

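With the two overloads in place, the is_text argument selects the return type: the default (True) yields IOHandles[str], while is_text=False yields IOHandles[bytes]. A minimal usage sketch of what a type checker now sees (get_handle is a private pandas helper; the in-memory buffers are illustrative):

import io

from pandas.io.common import get_handle

# is_text defaults to True -> the Literal[True] overload: IOHandles[str].
with get_handle(io.StringIO("a,b\n1,2\n"), "r") as handles:
    print(handles.handle.read())

# is_text=False -> the Literal[False] overload: IOHandles[bytes].
with get_handle(io.BytesIO(b"a,b\n1,2\n"), "rb", is_text=False) as handles:
    print(handles.handle.read())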
4 changes: 3 additions & 1 deletion pandas/io/excel/_base.py
@@ -941,7 +941,9 @@ def __init__(
mode = mode.replace("a", "r+")

# cast ExcelWriter to avoid adding 'if self.handles is not None'
self.handles = IOHandles(cast(Buffer, path), compression={"copression": None})
self.handles = IOHandles(
cast(Buffer[bytes], path), compression={"copression": None}
)
if not isinstance(path, ExcelWriter):
self.handles = get_handle(
path, mode, storage_options=storage_options, is_text=False
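The cast above is the standard typing.cast idiom: it performs no conversion at runtime and only tells the checker to treat the ExcelWriter path as a Buffer[bytes] placeholder until get_handle assigns real handles. A tiny illustration of that behaviour (not pandas code):

import io
from typing import cast

raw = io.BytesIO(b"")
# cast() does no conversion or validation; it is purely a checker hint.
hinted = cast(io.BufferedIOBase, raw)
print(hinted is raw)  # True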
2 changes: 1 addition & 1 deletion pandas/io/json/_json.py
@@ -661,7 +661,7 @@ def __init__(
self.nrows_seen = 0
self.nrows = nrows
self.encoding_errors = encoding_errors
self.handles: IOHandles | None = None
self.handles: IOHandles[str] | None = None

if self.chunksize is not None:
self.chunksize = validate_integer("chunksize", self.chunksize, 1)
2 changes: 1 addition & 1 deletion pandas/io/parquet.py
@@ -74,7 +74,7 @@ def _get_path_or_handle(
storage_options: StorageOptions = None,
mode: str = "rb",
is_dir: bool = False,
) -> tuple[FilePathOrBuffer, IOHandles | None, Any]:
) -> tuple[FilePathOrBuffer, IOHandles[bytes] | None, Any]:
"""File handling for PyArrow."""
path_or_handle = stringify_path(path)
if is_fsspec_url(path_or_handle) and fs is None:
2 changes: 1 addition & 1 deletion pandas/io/parsers/base_parser.py
@@ -212,7 +212,7 @@ def __init__(self, kwds):

self.usecols, self.usecols_dtype = self._validate_usecols_arg(kwds["usecols"])

self.handles: IOHandles | None = None
self.handles: IOHandles[str] | None = None

# Fallback to error to pass a sketchy test(test_override_set_noconvert_columns)
# Normally, this arg would get pre-processed earlier on
2 changes: 1 addition & 1 deletion pandas/io/stata.py
@@ -2294,7 +2294,7 @@ def __init__(
self._value_labels: list[StataValueLabel] = []
self._has_value_labels = np.array([], dtype=bool)
self._compression = compression
self._output_file: Buffer | None = None
self._output_file: Buffer[bytes] | None = None
self._converted_names: dict[Hashable, str] = {}
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)
6 changes: 4 additions & 2 deletions pandas/tests/io/xml/test_to_xml.py
@@ -1287,7 +1287,8 @@ def test_compression_output(parser, comp):

output = equalize_decl(output)

assert geom_xml == output.strip()
# error: Item "None" of "Union[str, bytes, None]" has no attribute "strip"
assert geom_xml == output.strip() # type: ignore[union-attr]


@pytest.mark.parametrize("comp", ["bz2", "gzip", "xz", "zip"])
@@ -1305,7 +1306,8 @@ def test_filename_and_suffix_comp(parser, comp, compfile):

output = equalize_decl(output)

assert geom_xml == output.strip()
# error: Item "None" of "Union[str, bytes, None]" has no attribute "strip"
assert geom_xml == output.strip() # type: ignore[union-attr]


def test_unsuported_compression(datapath, parser):
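The # type: ignore[union-attr] comments are needed because the value being compared is typed as Union[str, bytes, None]. A possible alternative, sketched here with an illustrative helper rather than the real test code, is to narrow the value before calling .strip():

def strip_xml(text: "str | bytes | None") -> str:
    """Narrow an optional str/bytes value so no ignore comment is needed."""
    assert text is not None          # removes None from the union for mypy
    if isinstance(text, bytes):      # removes bytes
        text = text.decode("utf-8")
    return text.strip()


print(strip_xml(b"  <data/>  "))  # <data/>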