Skip to content

Commit ed7622c

Browse files
authored
TYP: make IOHandles generic (#43855)
1 parent ce12646 commit ed7622c

File tree

8 files changed

+50
-13
lines changed

8 files changed

+50
-13
lines changed

pandas/_typing.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
BufferedIOBase,
88
RawIOBase,
99
TextIOBase,
10-
TextIOWrapper,
1110
)
1211
from mmap import mmap
1312
from os import PathLike
@@ -170,7 +169,7 @@
170169
PythonFuncType = Callable[[Any], Any]
171170

172171
# filenames and file-like-objects
173-
Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, TextIOWrapper, mmap]
172+
Buffer = Union[IO[AnyStr], RawIOBase, BufferedIOBase, TextIOBase, mmap]
174173
FileOrBuffer = Union[str, Buffer[AnyStr]]
175174
FilePathOrBuffer = Union["PathLike[str]", FileOrBuffer[AnyStr]]
176175

pandas/io/common.py

+38-4
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,11 @@
2323
IO,
2424
Any,
2525
AnyStr,
26+
Generic,
27+
Literal,
2628
Mapping,
2729
cast,
30+
overload,
2831
)
2932
from urllib.parse import (
3033
urljoin,
@@ -78,7 +81,7 @@ class IOArgs:
7881

7982

8083
@dataclasses.dataclass
81-
class IOHandles:
84+
class IOHandles(Generic[AnyStr]):
8285
"""
8386
Return value of io/common.py:get_handle
8487
@@ -92,7 +95,7 @@ class IOHandles:
9295
is_wrapped: Whether a TextIOWrapper needs to be detached.
9396
"""
9497

95-
handle: Buffer
98+
handle: Buffer[AnyStr]
9699
compression: CompressionDict
97100
created_handles: list[Buffer] = dataclasses.field(default_factory=list)
98101
is_wrapped: bool = False
@@ -118,7 +121,7 @@ def close(self) -> None:
118121
self.created_handles = []
119122
self.is_wrapped = False
120123

121-
def __enter__(self) -> IOHandles:
124+
def __enter__(self) -> IOHandles[AnyStr]:
122125
return self
123126

124127
def __exit__(self, *args: Any) -> None:
@@ -533,16 +536,47 @@ def check_parent_directory(path: Path | str) -> None:
533536
raise OSError(fr"Cannot save file into a non-existent directory: '{parent}'")
534537

535538

539+
@overload
536540
def get_handle(
537541
path_or_buf: FilePathOrBuffer,
538542
mode: str,
543+
*,
544+
encoding: str | None = ...,
545+
compression: CompressionOptions = ...,
546+
memory_map: bool = ...,
547+
is_text: Literal[False],
548+
errors: str | None = ...,
549+
storage_options: StorageOptions = ...,
550+
) -> IOHandles[bytes]:
551+
...
552+
553+
554+
@overload
555+
def get_handle(
556+
path_or_buf: FilePathOrBuffer,
557+
mode: str,
558+
*,
559+
encoding: str | None = ...,
560+
compression: CompressionOptions = ...,
561+
memory_map: bool = ...,
562+
is_text: Literal[True] = True,
563+
errors: str | None = ...,
564+
storage_options: StorageOptions = ...,
565+
) -> IOHandles[str]:
566+
...
567+
568+
569+
def get_handle(
570+
path_or_buf: FilePathOrBuffer,
571+
mode: str,
572+
*,
539573
encoding: str | None = None,
540574
compression: CompressionOptions = None,
541575
memory_map: bool = False,
542576
is_text: bool = True,
543577
errors: str | None = None,
544578
storage_options: StorageOptions = None,
545-
) -> IOHandles:
579+
) -> IOHandles[str] | IOHandles[bytes]:
546580
"""
547581
Get file handle for given path/buffer and mode.
548582

pandas/io/excel/_base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -941,7 +941,9 @@ def __init__(
941941
mode = mode.replace("a", "r+")
942942

943943
# cast ExcelWriter to avoid adding 'if self.handles is not None'
944-
self.handles = IOHandles(cast(Buffer, path), compression={"copression": None})
944+
self.handles = IOHandles(
945+
cast(Buffer[bytes], path), compression={"copression": None}
946+
)
945947
if not isinstance(path, ExcelWriter):
946948
self.handles = get_handle(
947949
path, mode, storage_options=storage_options, is_text=False

pandas/io/json/_json.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -661,7 +661,7 @@ def __init__(
661661
self.nrows_seen = 0
662662
self.nrows = nrows
663663
self.encoding_errors = encoding_errors
664-
self.handles: IOHandles | None = None
664+
self.handles: IOHandles[str] | None = None
665665

666666
if self.chunksize is not None:
667667
self.chunksize = validate_integer("chunksize", self.chunksize, 1)

pandas/io/parquet.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def _get_path_or_handle(
7474
storage_options: StorageOptions = None,
7575
mode: str = "rb",
7676
is_dir: bool = False,
77-
) -> tuple[FilePathOrBuffer, IOHandles | None, Any]:
77+
) -> tuple[FilePathOrBuffer, IOHandles[bytes] | None, Any]:
7878
"""File handling for PyArrow."""
7979
path_or_handle = stringify_path(path)
8080
if is_fsspec_url(path_or_handle) and fs is None:

pandas/io/parsers/base_parser.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def __init__(self, kwds):
212212

213213
self.usecols, self.usecols_dtype = self._validate_usecols_arg(kwds["usecols"])
214214

215-
self.handles: IOHandles | None = None
215+
self.handles: IOHandles[str] | None = None
216216

217217
# Fallback to error to pass a sketchy test(test_override_set_noconvert_columns)
218218
# Normally, this arg would get pre-processed earlier on

pandas/io/stata.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2294,7 +2294,7 @@ def __init__(
22942294
self._value_labels: list[StataValueLabel] = []
22952295
self._has_value_labels = np.array([], dtype=bool)
22962296
self._compression = compression
2297-
self._output_file: Buffer | None = None
2297+
self._output_file: Buffer[bytes] | None = None
22982298
self._converted_names: dict[Hashable, str] = {}
22992299
# attach nobs, nvars, data, varlist, typlist
23002300
self._prepare_pandas(data)

pandas/tests/io/xml/test_to_xml.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1287,7 +1287,8 @@ def test_compression_output(parser, comp):
12871287

12881288
output = equalize_decl(output)
12891289

1290-
assert geom_xml == output.strip()
1290+
# error: Item "None" of "Union[str, bytes, None]" has no attribute "strip"
1291+
assert geom_xml == output.strip() # type: ignore[union-attr]
12911292

12921293

12931294
@pytest.mark.parametrize("comp", ["bz2", "gzip", "xz", "zip"])
@@ -1305,7 +1306,8 @@ def test_filename_and_suffix_comp(parser, comp, compfile):
13051306

13061307
output = equalize_decl(output)
13071308

1308-
assert geom_xml == output.strip()
1309+
# error: Item "None" of "Union[str, bytes, None]" has no attribute "strip"
1310+
assert geom_xml == output.strip() # type: ignore[union-attr]
13091311

13101312

13111313
def test_unsuported_compression(datapath, parser):

0 commit comments

Comments
 (0)