Skip to content

Commit 99376b3

Browse files
authored
TYP: Add cast to ABC classes. (#37902)
1 parent a04a6f7 commit 99376b3

File tree

14 files changed

+74
-37
lines changed

14 files changed

+74
-37
lines changed

pandas/core/algorithms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from pandas.core.indexers import validate_indices
6161

6262
if TYPE_CHECKING:
63-
from pandas import Categorical, DataFrame, Series
63+
from pandas import Categorical, DataFrame, Index, Series
6464

6565
_shared_docs: Dict[str, str] = {}
6666

@@ -540,7 +540,7 @@ def factorize(
540540
sort: bool = False,
541541
na_sentinel: Optional[int] = -1,
542542
size_hint: Optional[int] = None,
543-
) -> Tuple[np.ndarray, Union[np.ndarray, ABCIndex]]:
543+
) -> Tuple[np.ndarray, Union[np.ndarray, "Index"]]:
544544
"""
545545
Encode the object as an enumerated type or categorical variable.
546546

pandas/core/base.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,14 @@ def __getitem__(self, key):
269269
return self._gotitem(list(key), ndim=2)
270270

271271
elif not getattr(self, "as_index", False):
272-
if key not in self.obj.columns:
272+
# error: "SelectionMixin" has no attribute "obj" [attr-defined]
273+
if key not in self.obj.columns: # type: ignore[attr-defined]
273274
raise KeyError(f"Column not found: {key}")
274275
return self._gotitem(key, ndim=2)
275276

276277
else:
277-
if key not in self.obj:
278+
# error: "SelectionMixin" has no attribute "obj" [attr-defined]
279+
if key not in self.obj: # type: ignore[attr-defined]
278280
raise KeyError(f"Column not found: {key}")
279281
return self._gotitem(key, ndim=1)
280282

@@ -919,10 +921,9 @@ def _map_values(self, mapper, na_action=None):
919921
# "astype" [attr-defined]
920922
values = self.astype(object)._values # type: ignore[attr-defined]
921923
if na_action == "ignore":
922-
923-
def map_f(values, f):
924-
return lib.map_infer_mask(values, f, isna(values).view(np.uint8))
925-
924+
map_f = lambda values, f: lib.map_infer_mask(
925+
values, f, isna(values).view(np.uint8)
926+
)
926927
elif na_action is None:
927928
map_f = lib.map_infer
928929
else:

pandas/core/common.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,8 @@ def convert_to_list_like(
472472
inputs are returned unmodified whereas others are converted to list.
473473
"""
474474
if isinstance(values, (list, np.ndarray, ABCIndex, ABCSeries, ABCExtensionArray)):
475-
return values
475+
# np.ndarray resolving as Any gives a false positive
476+
return values # type: ignore[return-value]
476477
elif isinstance(values, abc.Iterable) and not isinstance(values, str):
477478
return list(values)
478479

pandas/core/computation/align.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""
22
Core eval alignment algorithms.
33
"""
4+
from __future__ import annotations
45

56
from functools import partial, wraps
6-
from typing import Dict, Optional, Sequence, Tuple, Type, Union
7+
from typing import TYPE_CHECKING, Dict, Optional, Sequence, Tuple, Type, Union
78
import warnings
89

910
import numpy as np
@@ -17,13 +18,16 @@
1718
import pandas.core.common as com
1819
from pandas.core.computation.common import result_type_many
1920

21+
if TYPE_CHECKING:
22+
from pandas.core.indexes.api import Index
23+
2024

2125
def _align_core_single_unary_op(
2226
term,
23-
) -> Tuple[Union[partial, Type[FrameOrSeries]], Optional[Dict[str, int]]]:
27+
) -> Tuple[Union[partial, Type[FrameOrSeries]], Optional[Dict[str, Index]]]:
2428

2529
typ: Union[partial, Type[FrameOrSeries]]
26-
axes: Optional[Dict[str, int]] = None
30+
axes: Optional[Dict[str, Index]] = None
2731

2832
if isinstance(term.value, np.ndarray):
2933
typ = partial(np.asanyarray, dtype=term.value.dtype)
@@ -36,8 +40,8 @@ def _align_core_single_unary_op(
3640

3741

3842
def _zip_axes_from_type(
39-
typ: Type[FrameOrSeries], new_axes: Sequence[int]
40-
) -> Dict[str, int]:
43+
typ: Type[FrameOrSeries], new_axes: Sequence[Index]
44+
) -> Dict[str, Index]:
4145
return {name: new_axes[i] for i, name in enumerate(typ._AXIS_ORDERS)}
4246

4347

pandas/core/computation/parsing.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import tokenize
99
from typing import Iterator, Tuple
1010

11+
from pandas._typing import Label
12+
1113
# A token value Python's tokenizer probably will never use.
1214
BACKTICK_QUOTED_STRING = 100
1315

@@ -91,7 +93,7 @@ def clean_backtick_quoted_toks(tok: Tuple[int, str]) -> Tuple[int, str]:
9193
return toknum, tokval
9294

9395

94-
def clean_column_name(name: str) -> str:
96+
def clean_column_name(name: "Label") -> "Label":
9597
"""
9698
Function to emulate the cleaning of a backtick quoted name.
9799
@@ -102,12 +104,12 @@ def clean_column_name(name: str) -> str:
102104
103105
Parameters
104106
----------
105-
name : str
107+
name : hashable
106108
Name to be cleaned.
107109
108110
Returns
109111
-------
110-
name : str
112+
name : hashable
111113
Returns the name after tokenizing and cleaning.
112114
113115
Notes

pandas/core/construction.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ def array(
351351
return result
352352

353353

354-
def extract_array(obj: AnyArrayLike, extract_numpy: bool = False) -> ArrayLike:
354+
def extract_array(obj: object, extract_numpy: bool = False) -> Union[Any, ArrayLike]:
355355
"""
356356
Extract the ndarray or ExtensionArray from a Series or Index.
357357
@@ -399,9 +399,7 @@ def extract_array(obj: AnyArrayLike, extract_numpy: bool = False) -> ArrayLike:
399399
if extract_numpy and isinstance(obj, ABCPandasArray):
400400
obj = obj.to_numpy()
401401

402-
# error: Incompatible return value type (got "Index", expected "ExtensionArray")
403-
# error: Incompatible return value type (got "Series", expected "ExtensionArray")
404-
return obj # type: ignore[return-value]
402+
return obj
405403

406404

407405
def sanitize_array(

pandas/core/dtypes/generic.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
""" define generic base classes for pandas objects """
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING, Type, cast
5+
6+
if TYPE_CHECKING:
7+
from pandas import DataFrame, Series
8+
from pandas.core.generic import NDFrame
29

310

411
# define abstract base classes to enable isinstance type checking on our
@@ -53,9 +60,17 @@ def _check(cls, inst) -> bool:
5360
},
5461
)
5562

56-
ABCNDFrame = create_pandas_abc_type("ABCNDFrame", "_typ", ("series", "dataframe"))
57-
ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series",))
58-
ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",))
63+
ABCNDFrame = cast(
64+
"Type[NDFrame]",
65+
create_pandas_abc_type("ABCNDFrame", "_typ", ("series", "dataframe")),
66+
)
67+
ABCSeries = cast(
68+
"Type[Series]",
69+
create_pandas_abc_type("ABCSeries", "_typ", ("series",)),
70+
)
71+
ABCDataFrame = cast(
72+
"Type[DataFrame]", create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",))
73+
)
5974

6075
ABCCategorical = create_pandas_abc_type("ABCCategorical", "_typ", ("categorical"))
6176
ABCDatetimeArray = create_pandas_abc_type("ABCDatetimeArray", "_typ", ("datetimearray"))

pandas/core/generic.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -512,7 +512,7 @@ def _get_axis_resolvers(self, axis: str) -> Dict[str, Union[Series, MultiIndex]]
512512
return d
513513

514514
@final
515-
def _get_index_resolvers(self) -> Dict[str, Union[Series, MultiIndex]]:
515+
def _get_index_resolvers(self) -> Dict[Label, Union[Series, MultiIndex]]:
516516
from pandas.core.computation.parsing import clean_column_name
517517

518518
d: Dict[str, Union[Series, MultiIndex]] = {}
@@ -522,7 +522,7 @@ def _get_index_resolvers(self) -> Dict[str, Union[Series, MultiIndex]]:
522522
return {clean_column_name(k): v for k, v in d.items() if not isinstance(k, int)}
523523

524524
@final
525-
def _get_cleaned_column_resolvers(self) -> Dict[str, ABCSeries]:
525+
def _get_cleaned_column_resolvers(self) -> Dict[Label, Series]:
526526
"""
527527
Return the special character free column resolvers of a dataframe.
528528
@@ -533,7 +533,6 @@ def _get_cleaned_column_resolvers(self) -> Dict[str, ABCSeries]:
533533
from pandas.core.computation.parsing import clean_column_name
534534

535535
if isinstance(self, ABCSeries):
536-
self = cast("Series", self)
537536
return {clean_column_name(self.name): self}
538537

539538
return {

pandas/core/indexing.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2007,7 +2007,7 @@ def ravel(i):
20072007

20082008
raise ValueError("Incompatible indexer with Series")
20092009

2010-
def _align_frame(self, indexer, df: ABCDataFrame):
2010+
def _align_frame(self, indexer, df: "DataFrame"):
20112011
is_frame = self.ndim == 2
20122012

20132013
if isinstance(indexer, tuple):

pandas/core/internals/construction.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def extract_index(data) -> Index:
370370
index = Index([])
371371
elif len(data) > 0:
372372
raw_lengths = []
373-
indexes = []
373+
indexes: List[Union[List[Label], Index]] = []
374374

375375
have_raw_arrays = False
376376
have_series = False

pandas/core/reshape/concat.py

+16-4
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,21 @@
33
"""
44

55
from collections import abc
6-
from typing import TYPE_CHECKING, Iterable, List, Mapping, Type, Union, cast, overload
6+
from typing import (
7+
TYPE_CHECKING,
8+
Iterable,
9+
List,
10+
Mapping,
11+
Optional,
12+
Type,
13+
Union,
14+
cast,
15+
overload,
16+
)
717

818
import numpy as np
919

10-
from pandas._typing import FrameOrSeries, FrameOrSeriesUnion, Label
20+
from pandas._typing import FrameOrSeriesUnion, Label
1121

1222
from pandas.core.dtypes.concat import concat_compat
1323
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
@@ -296,7 +306,7 @@ class _Concatenator:
296306

297307
def __init__(
298308
self,
299-
objs: Union[Iterable[FrameOrSeries], Mapping[Label, FrameOrSeries]],
309+
objs: Union[Iterable["NDFrame"], Mapping[Label, "NDFrame"]],
300310
axis=0,
301311
join: str = "outer",
302312
keys=None,
@@ -367,7 +377,7 @@ def __init__(
367377
# get the sample
368378
# want the highest ndim that we have, and must be non-empty
369379
# unless all objs are empty
370-
sample = None
380+
sample: Optional["NDFrame"] = None
371381
if len(ndims) > 1:
372382
max_ndim = max(ndims)
373383
for obj in objs:
@@ -437,6 +447,8 @@ def __init__(
437447
# to line up
438448
if self._is_frame and axis == 1:
439449
name = 0
450+
# mypy needs to know sample is not an NDFrame
451+
sample = cast("FrameOrSeriesUnion", sample)
440452
obj = sample._constructor({name: obj})
441453

442454
self.objs.append(obj)

pandas/core/strings/accessor.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -157,11 +157,10 @@ def __init__(self, data):
157157
array = data.array
158158
self._array = array
159159

160+
self._index = self._name = None
160161
if isinstance(data, ABCSeries):
161162
self._index = data.index
162163
self._name = data.name
163-
else:
164-
self._index = self._name = None
165164

166165
# ._values.categories works for both Series/Index
167166
self._parent = data._values.categories if self._is_categorical else data

pandas/core/window/common.py

+4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Common utility functions for rolling operations"""
22
from collections import defaultdict
3+
from typing import cast
34
import warnings
45

56
import numpy as np
@@ -109,6 +110,9 @@ def dataframe_from_int_dict(data, frame_template):
109110

110111
# set the index and reorder
111112
if arg2.columns.nlevels > 1:
113+
# mypy needs to know columns is a MultiIndex, Index doesn't
114+
# have levels attribute
115+
arg2.columns = cast(MultiIndex, arg2.columns)
112116
result.index = MultiIndex.from_product(
113117
arg2.columns.levels + [result_index]
114118
)

pandas/plotting/_matplotlib/tools.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import matplotlib.ticker as ticker
88
import numpy as np
99

10-
from pandas._typing import FrameOrSeries
10+
from pandas._typing import FrameOrSeriesUnion
1111

1212
from pandas.core.dtypes.common import is_list_like
1313
from pandas.core.dtypes.generic import ABCDataFrame, ABCIndexClass, ABCSeries
@@ -30,7 +30,9 @@ def format_date_labels(ax: "Axes", rot):
3030
fig.subplots_adjust(bottom=0.2)
3131

3232

33-
def table(ax, data: FrameOrSeries, rowLabels=None, colLabels=None, **kwargs) -> "Table":
33+
def table(
34+
ax, data: FrameOrSeriesUnion, rowLabels=None, colLabels=None, **kwargs
35+
) -> "Table":
3436
if isinstance(data, ABCSeries):
3537
data = data.to_frame()
3638
elif isinstance(data, ABCDataFrame):

0 commit comments

Comments
 (0)