1
1
from __future__ import annotations
2
2
3
3
from typing import (
4
+ TYPE_CHECKING ,
4
5
Any ,
6
+ Callable ,
7
+ Hashable ,
8
+ Iterable ,
9
+ Literal ,
5
10
MutableMapping ,
11
+ Sequence ,
12
+ TypeVar ,
13
+ overload ,
6
14
)
7
15
8
16
from pandas .compat ._optional import import_optional_dependency
12
20
is_list_like ,
13
21
)
14
22
15
- _writers : MutableMapping [str , str ] = {}
23
+ if TYPE_CHECKING :
24
+ from pandas .io .excel ._base import ExcelWriter
16
25
26
+ ExcelWriter_t = type [ExcelWriter ]
27
+ usecols_func = TypeVar ("usecols_func" , bound = Callable [[Hashable ], object ])
17
28
18
- def register_writer (klass ):
29
+ _writers : MutableMapping [str , ExcelWriter_t ] = {}
30
+
31
+
32
+ def register_writer (klass : ExcelWriter_t ) -> None :
19
33
"""
20
34
Add engine to the excel writer registry.io.excel.
21
35
@@ -28,10 +42,12 @@ def register_writer(klass):
28
42
if not callable (klass ):
29
43
raise ValueError ("Can only register callables as engines" )
30
44
engine_name = klass .engine
45
+ # for mypy
46
+ assert isinstance (engine_name , str )
31
47
_writers [engine_name ] = klass
32
48
33
49
34
- def get_default_engine (ext , mode = "reader" ):
50
+ def get_default_engine (ext : str , mode : Literal [ "reader" , "writer" ] = "reader" ) -> str :
35
51
"""
36
52
Return the default reader/writer for the given extension.
37
53
@@ -73,7 +89,7 @@ def get_default_engine(ext, mode="reader"):
73
89
return _default_readers [ext ]
74
90
75
91
76
- def get_writer (engine_name ) :
92
+ def get_writer (engine_name : str ) -> ExcelWriter_t :
77
93
try :
78
94
return _writers [engine_name ]
79
95
except KeyError as err :
@@ -145,7 +161,29 @@ def _range2cols(areas: str) -> list[int]:
145
161
return cols
146
162
147
163
148
- def maybe_convert_usecols (usecols ):
164
+ @overload
165
+ def maybe_convert_usecols (usecols : str | list [int ]) -> list [int ]:
166
+ ...
167
+
168
+
169
+ @overload
170
+ def maybe_convert_usecols (usecols : list [str ]) -> list [str ]:
171
+ ...
172
+
173
+
174
+ @overload
175
+ def maybe_convert_usecols (usecols : usecols_func ) -> usecols_func :
176
+ ...
177
+
178
+
179
+ @overload
180
+ def maybe_convert_usecols (usecols : None ) -> None :
181
+ ...
182
+
183
+
184
+ def maybe_convert_usecols (
185
+ usecols : str | list [int ] | list [str ] | usecols_func | None ,
186
+ ) -> None | list [int ] | list [str ] | usecols_func :
149
187
"""
150
188
Convert `usecols` into a compatible format for parsing in `parsers.py`.
151
189
@@ -174,7 +212,17 @@ def maybe_convert_usecols(usecols):
174
212
return usecols
175
213
176
214
177
- def validate_freeze_panes (freeze_panes ):
215
+ @overload
216
+ def validate_freeze_panes (freeze_panes : tuple [int , int ]) -> Literal [True ]:
217
+ ...
218
+
219
+
220
+ @overload
221
+ def validate_freeze_panes (freeze_panes : None ) -> Literal [False ]:
222
+ ...
223
+
224
+
225
+ def validate_freeze_panes (freeze_panes : tuple [int , int ] | None ) -> bool :
178
226
if freeze_panes is not None :
179
227
if len (freeze_panes ) == 2 and all (
180
228
isinstance (item , int ) for item in freeze_panes
@@ -191,7 +239,9 @@ def validate_freeze_panes(freeze_panes):
191
239
return False
192
240
193
241
194
- def fill_mi_header (row , control_row ):
242
+ def fill_mi_header (
243
+ row : list [Hashable ], control_row : list [bool ]
244
+ ) -> tuple [list [Hashable ], list [bool ]]:
195
245
"""
196
246
Forward fill blank entries in row but only inside the same parent index.
197
247
@@ -224,7 +274,9 @@ def fill_mi_header(row, control_row):
224
274
return row , control_row
225
275
226
276
227
- def pop_header_name (row , index_col ):
277
+ def pop_header_name (
278
+ row : list [Hashable ], index_col : int | Sequence [int ]
279
+ ) -> tuple [Hashable | None , list [Hashable ]]:
228
280
"""
229
281
Pop the header name for MultiIndex parsing.
230
282
@@ -243,7 +295,12 @@ def pop_header_name(row, index_col):
243
295
The original data row with the header name removed.
244
296
"""
245
297
# Pop out header name and fill w/blank.
246
- i = index_col if not is_list_like (index_col ) else max (index_col )
298
+ if is_list_like (index_col ):
299
+ assert isinstance (index_col , Iterable )
300
+ i = max (index_col )
301
+ else :
302
+ assert not isinstance (index_col , Iterable )
303
+ i = index_col
247
304
248
305
header_name = row [i ]
249
306
header_name = None if header_name == "" else header_name
0 commit comments