Skip to content

Commit c3d3357

Browse files
authored
TYP: type excel util module (#45014)
1 parent 3c19380 commit c3d3357

File tree

2 files changed

+67
-10
lines changed

2 files changed

+67
-10
lines changed

pandas/io/excel/_base.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -943,7 +943,7 @@ def supported_extensions(self):
943943

944944
@property
945945
@abc.abstractmethod
946-
def engine(self):
946+
def engine(self) -> str:
947947
"""Name of engine."""
948948
pass
949949

pandas/io/excel/_util.py

+66-9
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
from __future__ import annotations
22

33
from typing import (
4+
TYPE_CHECKING,
45
Any,
6+
Callable,
7+
Hashable,
8+
Iterable,
9+
Literal,
510
MutableMapping,
11+
Sequence,
12+
TypeVar,
13+
overload,
614
)
715

816
from pandas.compat._optional import import_optional_dependency
@@ -12,10 +20,16 @@
1220
is_list_like,
1321
)
1422

15-
_writers: MutableMapping[str, str] = {}
23+
if TYPE_CHECKING:
24+
from pandas.io.excel._base import ExcelWriter
1625

26+
ExcelWriter_t = type[ExcelWriter]
27+
usecols_func = TypeVar("usecols_func", bound=Callable[[Hashable], object])
1728

18-
def register_writer(klass):
29+
_writers: MutableMapping[str, ExcelWriter_t] = {}
30+
31+
32+
def register_writer(klass: ExcelWriter_t) -> None:
1933
"""
2034
Add engine to the excel writer registry.io.excel.
2135
@@ -28,10 +42,12 @@ def register_writer(klass):
2842
if not callable(klass):
2943
raise ValueError("Can only register callables as engines")
3044
engine_name = klass.engine
45+
# for mypy
46+
assert isinstance(engine_name, str)
3147
_writers[engine_name] = klass
3248

3349

34-
def get_default_engine(ext, mode="reader"):
50+
def get_default_engine(ext: str, mode: Literal["reader", "writer"] = "reader") -> str:
3551
"""
3652
Return the default reader/writer for the given extension.
3753
@@ -73,7 +89,7 @@ def get_default_engine(ext, mode="reader"):
7389
return _default_readers[ext]
7490

7591

76-
def get_writer(engine_name):
92+
def get_writer(engine_name: str) -> ExcelWriter_t:
7793
try:
7894
return _writers[engine_name]
7995
except KeyError as err:
@@ -145,7 +161,29 @@ def _range2cols(areas: str) -> list[int]:
145161
return cols
146162

147163

148-
def maybe_convert_usecols(usecols):
164+
@overload
165+
def maybe_convert_usecols(usecols: str | list[int]) -> list[int]:
166+
...
167+
168+
169+
@overload
170+
def maybe_convert_usecols(usecols: list[str]) -> list[str]:
171+
...
172+
173+
174+
@overload
175+
def maybe_convert_usecols(usecols: usecols_func) -> usecols_func:
176+
...
177+
178+
179+
@overload
180+
def maybe_convert_usecols(usecols: None) -> None:
181+
...
182+
183+
184+
def maybe_convert_usecols(
185+
usecols: str | list[int] | list[str] | usecols_func | None,
186+
) -> None | list[int] | list[str] | usecols_func:
149187
"""
150188
Convert `usecols` into a compatible format for parsing in `parsers.py`.
151189
@@ -174,7 +212,17 @@ def maybe_convert_usecols(usecols):
174212
return usecols
175213

176214

177-
def validate_freeze_panes(freeze_panes):
215+
@overload
216+
def validate_freeze_panes(freeze_panes: tuple[int, int]) -> Literal[True]:
217+
...
218+
219+
220+
@overload
221+
def validate_freeze_panes(freeze_panes: None) -> Literal[False]:
222+
...
223+
224+
225+
def validate_freeze_panes(freeze_panes: tuple[int, int] | None) -> bool:
178226
if freeze_panes is not None:
179227
if len(freeze_panes) == 2 and all(
180228
isinstance(item, int) for item in freeze_panes
@@ -191,7 +239,9 @@ def validate_freeze_panes(freeze_panes):
191239
return False
192240

193241

194-
def fill_mi_header(row, control_row):
242+
def fill_mi_header(
243+
row: list[Hashable], control_row: list[bool]
244+
) -> tuple[list[Hashable], list[bool]]:
195245
"""
196246
Forward fill blank entries in row but only inside the same parent index.
197247
@@ -224,7 +274,9 @@ def fill_mi_header(row, control_row):
224274
return row, control_row
225275

226276

227-
def pop_header_name(row, index_col):
277+
def pop_header_name(
278+
row: list[Hashable], index_col: int | Sequence[int]
279+
) -> tuple[Hashable | None, list[Hashable]]:
228280
"""
229281
Pop the header name for MultiIndex parsing.
230282
@@ -243,7 +295,12 @@ def pop_header_name(row, index_col):
243295
The original data row with the header name removed.
244296
"""
245297
# Pop out header name and fill w/blank.
246-
i = index_col if not is_list_like(index_col) else max(index_col)
298+
if is_list_like(index_col):
299+
assert isinstance(index_col, Iterable)
300+
i = max(index_col)
301+
else:
302+
assert not isinstance(index_col, Iterable)
303+
i = index_col
247304

248305
header_name = row[i]
249306
header_name = None if header_name == "" else header_name

0 commit comments

Comments
 (0)