Skip to content

Commit 3d29aee

Browse files
authored
TYP: mostly io, plotting (#37059)
1 parent abd3acf commit 3d29aee

File tree

10 files changed

+73
-33
lines changed

10 files changed

+73
-33
lines changed

pandas/core/base.py

+22-3
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,22 @@
44

55
import builtins
66
import textwrap
7-
from typing import Any, Callable, Dict, FrozenSet, Optional, TypeVar, Union
7+
from typing import (
8+
TYPE_CHECKING,
9+
Any,
10+
Callable,
11+
Dict,
12+
FrozenSet,
13+
Optional,
14+
TypeVar,
15+
Union,
16+
cast,
17+
)
818

919
import numpy as np
1020

1121
import pandas._libs.lib as lib
12-
from pandas._typing import IndexLabel
22+
from pandas._typing import DtypeObj, IndexLabel
1323
from pandas.compat import PYPY
1424
from pandas.compat.numpy import function as nv
1525
from pandas.errors import AbstractMethodError
@@ -33,6 +43,9 @@
3343
from pandas.core.construction import create_series_with_explicit_dtype
3444
import pandas.core.nanops as nanops
3545

46+
if TYPE_CHECKING:
47+
from pandas import Categorical
48+
3649
_shared_docs: Dict[str, str] = dict()
3750
_indexops_doc_kwargs = dict(
3851
klass="IndexOpsMixin",
@@ -238,7 +251,7 @@ def _gotitem(self, key, ndim: int, subset=None):
238251
Parameters
239252
----------
240253
key : str / list of selections
241-
ndim : 1,2
254+
ndim : {1, 2}
242255
requested ndim of result
243256
subset : object, default None
244257
subset to act on
@@ -305,6 +318,11 @@ class IndexOpsMixin(OpsMixin):
305318
["tolist"] # tolist is not deprecated, just suppressed in the __dir__
306319
)
307320

321+
@property
322+
def dtype(self) -> DtypeObj:
323+
# must be defined here as a property for mypy
324+
raise AbstractMethodError(self)
325+
308326
@property
309327
def _values(self) -> Union[ExtensionArray, np.ndarray]:
310328
# must be defined here as a property for mypy
@@ -832,6 +850,7 @@ def _map_values(self, mapper, na_action=None):
832850
if is_categorical_dtype(self.dtype):
833851
# use the built in categorical series mapper which saves
834852
# time by mapping the categories instead of all values
853+
self = cast("Categorical", self)
835854
return self._values.map(mapper)
836855

837856
values = self._values

pandas/core/indexes/category.py

+1
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ class CategoricalIndex(ExtensionIndex, accessor.PandasDelegate):
164164
codes: np.ndarray
165165
categories: Index
166166
_data: Categorical
167+
_values: Categorical
167168

168169
@property
169170
def _engine_type(self):

pandas/io/excel/_base.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -785,14 +785,19 @@ def _value_with_fmt(self, val):
785785
return val, fmt
786786

787787
@classmethod
788-
def check_extension(cls, ext):
788+
def check_extension(cls, ext: str):
789789
"""
790790
checks that path's extension against the Writer's supported
791791
extensions. If it isn't supported, raises UnsupportedFiletypeError.
792792
"""
793793
if ext.startswith("."):
794794
ext = ext[1:]
795-
if not any(ext in extension for extension in cls.supported_extensions):
795+
# error: "Callable[[ExcelWriter], Any]" has no attribute "__iter__"
796+
# (not iterable) [attr-defined]
797+
if not any(
798+
ext in extension
799+
for extension in cls.supported_extensions # type: ignore[attr-defined]
800+
):
796801
raise ValueError(f"Invalid extension for engine '{cls.engine}': '{ext}'")
797802
else:
798803
return True

pandas/io/formats/format.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -1835,9 +1835,11 @@ def _make_fixed_width(
18351835
return strings
18361836

18371837
if adj is None:
1838-
adj = get_adjustment()
1838+
adjustment = get_adjustment()
1839+
else:
1840+
adjustment = adj
18391841

1840-
max_len = max(adj.len(x) for x in strings)
1842+
max_len = max(adjustment.len(x) for x in strings)
18411843

18421844
if minimum is not None:
18431845
max_len = max(minimum, max_len)
@@ -1846,14 +1848,14 @@ def _make_fixed_width(
18461848
if conf_max is not None and max_len > conf_max:
18471849
max_len = conf_max
18481850

1849-
def just(x):
1851+
def just(x: str) -> str:
18501852
if conf_max is not None:
1851-
if (conf_max > 3) & (adj.len(x) > max_len):
1853+
if (conf_max > 3) & (adjustment.len(x) > max_len):
18521854
x = x[: max_len - 3] + "..."
18531855
return x
18541856

18551857
strings = [just(x) for x in strings]
1856-
result = adj.justify(strings, max_len, mode=justify)
1858+
result = adjustment.justify(strings, max_len, mode=justify)
18571859
return result
18581860

18591861

pandas/io/parsers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1661,7 +1661,7 @@ def _get_name(icol):
16611661

16621662
return index
16631663

1664-
def _agg_index(self, index, try_parse_dates=True):
1664+
def _agg_index(self, index, try_parse_dates=True) -> Index:
16651665
arrays = []
16661666

16671667
for i, arr in enumerate(index):

pandas/io/pytables.py

+24-9
Original file line numberDiff line numberDiff line change
@@ -565,6 +565,7 @@ def __fspath__(self):
565565
def root(self):
566566
""" return the root node """
567567
self._check_if_open()
568+
assert self._handle is not None # for mypy
568569
return self._handle.root
569570

570571
@property
@@ -1393,6 +1394,8 @@ def groups(self):
13931394
"""
13941395
_tables()
13951396
self._check_if_open()
1397+
assert self._handle is not None # for mypy
1398+
assert _table_mod is not None # for mypy
13961399
return [
13971400
g
13981401
for g in self._handle.walk_groups()
@@ -1437,6 +1440,9 @@ def walk(self, where="/"):
14371440
"""
14381441
_tables()
14391442
self._check_if_open()
1443+
assert self._handle is not None # for mypy
1444+
assert _table_mod is not None # for mypy
1445+
14401446
for g in self._handle.walk_groups(where):
14411447
if getattr(g._v_attrs, "pandas_type", None) is not None:
14421448
continue
@@ -1862,6 +1868,8 @@ def __init__(
18621868
def __iter__(self):
18631869
# iterate
18641870
current = self.start
1871+
if self.coordinates is None:
1872+
raise ValueError("Cannot iterate until get_result is called.")
18651873
while current < self.stop:
18661874
stop = min(current + self.chunksize, self.stop)
18671875
value = self.func(None, None, self.coordinates[current:stop])
@@ -3196,7 +3204,7 @@ class Table(Fixed):
31963204
pandas_kind = "wide_table"
31973205
format_type: str = "table" # GH#30962 needed by dask
31983206
table_type: str
3199-
levels = 1
3207+
levels: Union[int, List[Label]] = 1
32003208
is_table = True
32013209

32023210
index_axes: List[IndexCol]
@@ -3292,7 +3300,9 @@ def is_multi_index(self) -> bool:
32923300
"""the levels attribute is 1 or a list in the case of a multi-index"""
32933301
return isinstance(self.levels, list)
32943302

3295-
def validate_multiindex(self, obj):
3303+
def validate_multiindex(
3304+
self, obj: FrameOrSeriesUnion
3305+
) -> Tuple[DataFrame, List[Label]]:
32963306
"""
32973307
validate that we can store the multi-index; reset and return the
32983308
new object
@@ -3301,11 +3311,13 @@ def validate_multiindex(self, obj):
33013311
l if l is not None else f"level_{i}" for i, l in enumerate(obj.index.names)
33023312
]
33033313
try:
3304-
return obj.reset_index(), levels
3314+
reset_obj = obj.reset_index()
33053315
except ValueError as err:
33063316
raise ValueError(
33073317
"duplicate names/columns in the multi-index when storing as a table"
33083318
) from err
3319+
assert isinstance(reset_obj, DataFrame) # for mypy
3320+
return reset_obj, levels
33093321

33103322
@property
33113323
def nrows_expected(self) -> int:
@@ -3433,7 +3445,7 @@ def get_attrs(self):
34333445
self.nan_rep = getattr(self.attrs, "nan_rep", None)
34343446
self.encoding = _ensure_encoding(getattr(self.attrs, "encoding", None))
34353447
self.errors = _ensure_decoded(getattr(self.attrs, "errors", "strict"))
3436-
self.levels = getattr(self.attrs, "levels", None) or []
3448+
self.levels: List[Label] = getattr(self.attrs, "levels", None) or []
34373449
self.index_axes = [a for a in self.indexables if a.is_an_indexable]
34383450
self.values_axes = [a for a in self.indexables if not a.is_an_indexable]
34393451

@@ -4562,11 +4574,12 @@ class AppendableMultiSeriesTable(AppendableSeriesTable):
45624574
def write(self, obj, **kwargs):
45634575
""" we are going to write this as a frame table """
45644576
name = obj.name or "values"
4565-
obj, self.levels = self.validate_multiindex(obj)
4577+
newobj, self.levels = self.validate_multiindex(obj)
4578+
assert isinstance(self.levels, list) # for mypy
45664579
cols = list(self.levels)
45674580
cols.append(name)
4568-
obj.columns = cols
4569-
return super().write(obj=obj, **kwargs)
4581+
newobj.columns = Index(cols)
4582+
return super().write(obj=newobj, **kwargs)
45704583

45714584

45724585
class GenericTable(AppendableFrameTable):
@@ -4576,6 +4589,7 @@ class GenericTable(AppendableFrameTable):
45764589
table_type = "generic_table"
45774590
ndim = 2
45784591
obj_type = DataFrame
4592+
levels: List[Label]
45794593

45804594
@property
45814595
def pandas_type(self) -> str:
@@ -4609,7 +4623,7 @@ def indexables(self):
46094623
name="index", axis=0, table=self.table, meta=meta, metadata=md
46104624
)
46114625

4612-
_indexables = [index_col]
4626+
_indexables: List[Union[GenericIndexCol, GenericDataIndexableCol]] = [index_col]
46134627

46144628
for i, n in enumerate(d._v_names):
46154629
assert isinstance(n, str)
@@ -4652,6 +4666,7 @@ def write(self, obj, data_columns=None, **kwargs):
46524666
elif data_columns is True:
46534667
data_columns = obj.columns.tolist()
46544668
obj, self.levels = self.validate_multiindex(obj)
4669+
assert isinstance(self.levels, list) # for mypy
46554670
for n in self.levels:
46564671
if n not in data_columns:
46574672
data_columns.insert(0, n)
@@ -5173,7 +5188,7 @@ def select_coords(self):
51735188
start = 0
51745189
elif start < 0:
51755190
start += nrows
5176-
if self.stop is None:
5191+
if stop is None:
51775192
stop = nrows
51785193
elif stop < 0:
51795194
stop += nrows

pandas/io/stata.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -378,8 +378,8 @@ def parse_dates_safe(dates, delta=False, year=False, days=False):
378378
d["delta"] = time_delta._values.astype(np.int64) // 1000 # microseconds
379379
if days or year:
380380
date_index = DatetimeIndex(dates)
381-
d["year"] = date_index.year
382-
d["month"] = date_index.month
381+
d["year"] = date_index._data.year
382+
d["month"] = date_index._data.month
383383
if days:
384384
days_in_ns = dates.astype(np.int64) - to_datetime(
385385
d["year"], format="%Y"
@@ -887,7 +887,9 @@ def __init__(self):
887887
(65530, np.int8),
888888
]
889889
)
890-
self.TYPE_MAP = list(range(251)) + list("bhlfd")
890+
# error: Argument 1 to "list" has incompatible type "str";
891+
# expected "Iterable[int]" [arg-type]
892+
self.TYPE_MAP = list(range(251)) + list("bhlfd") # type: ignore[arg-type]
891893
self.TYPE_MAP_XML = dict(
892894
[
893895
# Not really a Q, unclear how to handle byteswap

pandas/plotting/_matplotlib/core.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ def _kind(self):
8282
_default_rot = 0
8383
orientation: Optional[str] = None
8484

85+
axes: np.ndarray # of Axes objects
86+
8587
def __init__(
8688
self,
8789
data,
@@ -177,7 +179,7 @@ def __init__(
177179

178180
self.ax = ax
179181
self.fig = fig
180-
self.axes = None
182+
self.axes = np.array([], dtype=object) # "real" version get set in `generate`
181183

182184
# parse errorbar input if given
183185
xerr = kwds.pop("xerr", None)
@@ -697,7 +699,7 @@ def _get_ax_layer(cls, ax, primary=True):
697699
else:
698700
return getattr(ax, "right_ax", ax)
699701

700-
def _get_ax(self, i):
702+
def _get_ax(self, i: int):
701703
# get the twinx ax if appropriate
702704
if self.subplots:
703705
ax = self.axes[i]

pandas/plotting/_matplotlib/tools.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -401,11 +401,11 @@ def handle_shared_axes(
401401
_remove_labels_from_axis(ax.yaxis)
402402

403403

404-
def flatten_axes(axes: Union["Axes", Sequence["Axes"]]) -> Sequence["Axes"]:
404+
def flatten_axes(axes: Union["Axes", Sequence["Axes"]]) -> np.ndarray:
405405
if not is_list_like(axes):
406406
return np.array([axes])
407407
elif isinstance(axes, (np.ndarray, ABCIndexClass)):
408-
return axes.ravel()
408+
return np.asarray(axes).ravel()
409409
return np.array(axes)
410410

411411

setup.cfg

-6
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,6 @@ check_untyped_defs=False
223223
[mypy-pandas.io.parsers]
224224
check_untyped_defs=False
225225

226-
[mypy-pandas.io.pytables]
227-
check_untyped_defs=False
228-
229-
[mypy-pandas.io.stata]
230-
check_untyped_defs=False
231-
232226
[mypy-pandas.plotting._matplotlib.converter]
233227
check_untyped_defs=False
234228

0 commit comments

Comments
 (0)