Skip to content

Commit 65abf6b

Browse files
TYP: change ArrayLike/AnyArrayLike alias to Union (#40379)
1 parent 63bfdf5 commit 65abf6b

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

53 files changed

+331
-606
lines changed

pandas/_typing.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
from pandas.core.dtypes.dtypes import ExtensionDtype
4848

4949
from pandas import Interval
50-
from pandas.core.arrays.base import ExtensionArray # noqa: F401
50+
from pandas.core.arrays.base import ExtensionArray
5151
from pandas.core.frame import DataFrame
5252
from pandas.core.generic import NDFrame # noqa: F401
5353
from pandas.core.groupby.generic import (
@@ -74,8 +74,8 @@
7474

7575
# array-like
7676

77-
AnyArrayLike = TypeVar("AnyArrayLike", "ExtensionArray", "Index", "Series", np.ndarray)
78-
ArrayLike = TypeVar("ArrayLike", "ExtensionArray", np.ndarray)
77+
ArrayLike = Union["ExtensionArray", np.ndarray]
78+
AnyArrayLike = Union[ArrayLike, "Index", "Series"]
7979

8080
# scalars
8181

pandas/core/algorithms.py

+14-59
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,7 @@ def _ensure_data(values: ArrayLike) -> Tuple[np.ndarray, DtypeObj]:
176176
elif is_timedelta64_dtype(values.dtype):
177177
from pandas import TimedeltaIndex
178178

179-
# error: Incompatible types in assignment (expression has type
180-
# "TimedeltaArray", variable has type "ndarray")
181-
values = TimedeltaIndex(values)._data # type: ignore[assignment]
179+
values = TimedeltaIndex(values)._data
182180
else:
183181
# Datetime
184182
if values.ndim > 1 and is_datetime64_ns_dtype(values.dtype):
@@ -194,22 +192,13 @@ def _ensure_data(values: ArrayLike) -> Tuple[np.ndarray, DtypeObj]:
194192

195193
from pandas import DatetimeIndex
196194

197-
# Incompatible types in assignment (expression has type "DatetimeArray",
198-
# variable has type "ndarray")
199-
values = DatetimeIndex(values)._data # type: ignore[assignment]
195+
values = DatetimeIndex(values)._data
200196
dtype = values.dtype
201-
# error: Item "ndarray" of "Union[PeriodArray, Any, ndarray]" has no attribute
202-
# "asi8"
203-
return values.asi8, dtype # type: ignore[union-attr]
197+
return values.asi8, dtype
204198

205199
elif is_categorical_dtype(values.dtype):
206-
# error: Incompatible types in assignment (expression has type "Categorical",
207-
# variable has type "ndarray")
208-
values = cast("Categorical", values) # type: ignore[assignment]
209-
# error: Incompatible types in assignment (expression has type "ndarray",
210-
# variable has type "ExtensionArray")
211-
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute "codes"
212-
values = values.codes # type: ignore[assignment,union-attr]
200+
values = cast("Categorical", values)
201+
values = values.codes
213202
dtype = pandas_dtype("category")
214203

215204
# we are actually coercing to int64
@@ -222,10 +211,7 @@ def _ensure_data(values: ArrayLike) -> Tuple[np.ndarray, DtypeObj]:
222211
return values, dtype # type: ignore[return-value]
223212

224213
# we have failed, return object
225-
226-
# error: Incompatible types in assignment (expression has type "ndarray", variable
227-
# has type "ExtensionArray")
228-
values = np.asarray(values, dtype=object) # type: ignore[assignment]
214+
values = np.asarray(values, dtype=object)
229215
return ensure_object(values), np.dtype("object")
230216

231217

@@ -335,9 +321,7 @@ def _get_values_for_rank(values: ArrayLike):
335321
if is_categorical_dtype(values):
336322
values = cast("Categorical", values)._values_for_rank()
337323

338-
# error: Incompatible types in assignment (expression has type "ndarray", variable
339-
# has type "ExtensionArray")
340-
values, _ = _ensure_data(values) # type: ignore[assignment]
324+
values, _ = _ensure_data(values)
341325
return values
342326

343327

@@ -503,42 +487,15 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> np.ndarray:
503487
)
504488

505489
if not isinstance(values, (ABCIndex, ABCSeries, ABCExtensionArray, np.ndarray)):
506-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
507-
# variable has type "Index")
508-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
509-
# variable has type "Series")
510-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
511-
# variable has type "ndarray")
512-
values = _ensure_arraylike(list(values)) # type: ignore[assignment]
490+
values = _ensure_arraylike(list(values))
513491
elif isinstance(values, ABCMultiIndex):
514492
# Avoid raising in extract_array
515-
516-
# error: Incompatible types in assignment (expression has type "ndarray",
517-
# variable has type "ExtensionArray")
518-
# error: Incompatible types in assignment (expression has type "ndarray",
519-
# variable has type "Index")
520-
# error: Incompatible types in assignment (expression has type "ndarray",
521-
# variable has type "Series")
522-
values = np.array(values) # type: ignore[assignment]
493+
values = np.array(values)
523494
else:
524-
# error: Incompatible types in assignment (expression has type "Union[Any,
525-
# ExtensionArray]", variable has type "Index")
526-
# error: Incompatible types in assignment (expression has type "Union[Any,
527-
# ExtensionArray]", variable has type "Series")
528-
values = extract_array(values, extract_numpy=True) # type: ignore[assignment]
529-
530-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
531-
# variable has type "Index")
532-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
533-
# variable has type "Series")
534-
# error: Incompatible types in assignment (expression has type "ExtensionArray",
535-
# variable has type "ndarray")
536-
comps = _ensure_arraylike(comps) # type: ignore[assignment]
537-
# error: Incompatible types in assignment (expression has type "Union[Any,
538-
# ExtensionArray]", variable has type "Index")
539-
# error: Incompatible types in assignment (expression has type "Union[Any,
540-
# ExtensionArray]", variable has type "Series")
541-
comps = extract_array(comps, extract_numpy=True) # type: ignore[assignment]
495+
values = extract_array(values, extract_numpy=True)
496+
497+
comps = _ensure_arraylike(comps)
498+
comps = extract_array(comps, extract_numpy=True)
542499
if is_extension_array_dtype(comps.dtype):
543500
# error: Incompatible return value type (got "Series", expected "ndarray")
544501
# error: Item "ndarray" of "Union[Any, ndarray]" has no attribute "isin"
@@ -1000,9 +957,7 @@ def duplicated(values: ArrayLike, keep: Union[str, bool] = "first") -> np.ndarra
1000957
-------
1001958
duplicated : ndarray
1002959
"""
1003-
# error: Incompatible types in assignment (expression has type "ndarray", variable
1004-
# has type "ExtensionArray")
1005-
values, _ = _ensure_data(values) # type: ignore[assignment]
960+
values, _ = _ensure_data(values)
1006961
ndtype = values.dtype.name
1007962
f = getattr(htable, f"duplicated_{ndtype}")
1008963
return f(values, keep=keep)

pandas/core/array_algos/putmask.py

+3-9
Original file line numberDiff line numberDiff line change
@@ -191,16 +191,10 @@ def extract_bool_array(mask: ArrayLike) -> np.ndarray:
191191
# We could have BooleanArray, Sparse[bool], ...
192192
# Except for BooleanArray, this is equivalent to just
193193
# np.asarray(mask, dtype=bool)
194+
mask = mask.to_numpy(dtype=bool, na_value=False)
194195

195-
# error: Incompatible types in assignment (expression has type "ndarray",
196-
# variable has type "ExtensionArray")
197-
mask = mask.to_numpy(dtype=bool, na_value=False) # type: ignore[assignment]
198-
199-
# error: Incompatible types in assignment (expression has type "ndarray", variable
200-
# has type "ExtensionArray")
201-
mask = np.asarray(mask, dtype=bool) # type: ignore[assignment]
202-
# error: Incompatible return value type (got "ExtensionArray", expected "ndarray")
203-
return mask # type: ignore[return-value]
196+
mask = np.asarray(mask, dtype=bool)
197+
return mask
204198

205199

206200
def setitem_datetimelike_compat(values: np.ndarray, num_set: int, other):

pandas/core/array_algos/quantile.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,9 @@ def quantile_compat(values: ArrayLike, qs, interpolation: str, axis: int) -> Arr
4040
if isinstance(values, np.ndarray):
4141
fill_value = na_value_for_dtype(values.dtype, compat=False)
4242
mask = isna(values)
43-
result = quantile_with_mask(values, mask, fill_value, qs, interpolation, axis)
43+
return quantile_with_mask(values, mask, fill_value, qs, interpolation, axis)
4444
else:
45-
result = quantile_ea_compat(values, qs, interpolation, axis)
46-
return result
45+
return quantile_ea_compat(values, qs, interpolation, axis)
4746

4847

4948
def quantile_with_mask(

pandas/core/array_algos/replace.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,7 @@ def _check_comparison_types(
9595

9696
if is_numeric_v_string_like(a, b):
9797
# GH#29553 avoid deprecation warnings from numpy
98-
# error: Incompatible return value type (got "ndarray", expected
99-
# "Union[ExtensionArray, bool]")
100-
return np.zeros(a.shape, dtype=bool) # type: ignore[return-value]
98+
return np.zeros(a.shape, dtype=bool)
10199

102100
elif is_datetimelike_v_numeric(a, b):
103101
# GH#29553 avoid deprecation warnings from numpy

pandas/core/array_algos/take.py

+32-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
from __future__ import annotations
22

33
import functools
4-
from typing import Optional
4+
from typing import (
5+
TYPE_CHECKING,
6+
Optional,
7+
overload,
8+
)
59

610
import numpy as np
711

@@ -20,6 +24,33 @@
2024

2125
from pandas.core.construction import ensure_wrapped_if_datetimelike
2226

27+
if TYPE_CHECKING:
28+
from pandas.core.arrays.base import ExtensionArray
29+
30+
31+
@overload
32+
def take_nd(
33+
arr: np.ndarray,
34+
indexer,
35+
axis: int = ...,
36+
out: Optional[np.ndarray] = ...,
37+
fill_value=...,
38+
allow_fill: bool = ...,
39+
) -> np.ndarray:
40+
...
41+
42+
43+
@overload
44+
def take_nd(
45+
arr: ExtensionArray,
46+
indexer,
47+
axis: int = ...,
48+
out: Optional[np.ndarray] = ...,
49+
fill_value=...,
50+
allow_fill: bool = ...,
51+
) -> ArrayLike:
52+
...
53+
2354

2455
def take_nd(
2556
arr: ArrayLike,

pandas/core/arrays/_mixins.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -291,8 +291,7 @@ def fillna(
291291
value, mask, len(self) # type: ignore[arg-type]
292292
)
293293

294-
# error: "ExtensionArray" has no attribute "any"
295-
if mask.any(): # type: ignore[attr-defined]
294+
if mask.any():
296295
if method is not None:
297296
# TODO: check value is None
298297
# (for now) when self.ndim == 2, we assume axis=0

pandas/core/arrays/base.py

+18-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import operator
1212
from typing import (
13+
TYPE_CHECKING,
1314
Any,
1415
Callable,
1516
Dict,
@@ -71,6 +72,16 @@
7172
nargsort,
7273
)
7374

75+
if TYPE_CHECKING:
76+
77+
class ExtensionArraySupportsAnyAll("ExtensionArray"):
78+
def any(self, *, skipna: bool = True) -> bool:
79+
pass
80+
81+
def all(self, *, skipna: bool = True) -> bool:
82+
pass
83+
84+
7485
_extension_array_shared_docs: Dict[str, str] = {}
7586

7687
ExtensionArrayT = TypeVar("ExtensionArrayT", bound="ExtensionArray")
@@ -380,7 +391,7 @@ def __iter__(self):
380391
for i in range(len(self)):
381392
yield self[i]
382393

383-
def __contains__(self, item) -> bool:
394+
def __contains__(self, item) -> Union[bool, np.bool_]:
384395
"""
385396
Return for `item in self`.
386397
"""
@@ -391,8 +402,7 @@ def __contains__(self, item) -> bool:
391402
if not self._can_hold_na:
392403
return False
393404
elif item is self.dtype.na_value or isinstance(item, self.dtype.type):
394-
# error: "ExtensionArray" has no attribute "any"
395-
return self.isna().any() # type: ignore[attr-defined]
405+
return self.isna().any()
396406
else:
397407
return False
398408
else:
@@ -543,7 +553,7 @@ def astype(self, dtype, copy=True):
543553

544554
return np.array(self, dtype=dtype, copy=copy)
545555

546-
def isna(self) -> ArrayLike:
556+
def isna(self) -> Union[np.ndarray, ExtensionArraySupportsAnyAll]:
547557
"""
548558
A 1-D array indicating if each value is missing.
549559
@@ -648,8 +658,7 @@ def argmin(self, skipna: bool = True) -> int:
648658
ExtensionArray.argmax
649659
"""
650660
validate_bool_kwarg(skipna, "skipna")
651-
# error: "ExtensionArray" has no attribute "any"
652-
if not skipna and self.isna().any(): # type: ignore[attr-defined]
661+
if not skipna and self.isna().any():
653662
raise NotImplementedError
654663
return nargminmax(self, "argmin")
655664

@@ -673,8 +682,7 @@ def argmax(self, skipna: bool = True) -> int:
673682
ExtensionArray.argmin
674683
"""
675684
validate_bool_kwarg(skipna, "skipna")
676-
# error: "ExtensionArray" has no attribute "any"
677-
if not skipna and self.isna().any(): # type: ignore[attr-defined]
685+
if not skipna and self.isna().any():
678686
raise NotImplementedError
679687
return nargminmax(self, "argmax")
680688

@@ -714,8 +722,7 @@ def fillna(self, value=None, method=None, limit=None):
714722
value, mask, len(self) # type: ignore[arg-type]
715723
)
716724

717-
# error: "ExtensionArray" has no attribute "any"
718-
if mask.any(): # type: ignore[attr-defined]
725+
if mask.any():
719726
if method is not None:
720727
func = missing.get_fill_func(method)
721728
new_values, _ = func(self.astype(object), limit=limit, mask=mask)
@@ -1156,9 +1163,7 @@ def view(self, dtype: Optional[Dtype] = None) -> ArrayLike:
11561163
# giving a view with the same dtype as self.
11571164
if dtype is not None:
11581165
raise NotImplementedError(dtype)
1159-
# error: Incompatible return value type (got "Union[ExtensionArray, Any]",
1160-
# expected "ndarray")
1161-
return self[:] # type: ignore[return-value]
1166+
return self[:]
11621167

11631168
# ------------------------------------------------------------------------
11641169
# Printing

pandas/core/arrays/boolean.py

+3-12
Original file line numberDiff line numberDiff line change
@@ -406,18 +406,14 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
406406
dtype = pandas_dtype(dtype)
407407

408408
if isinstance(dtype, ExtensionDtype):
409-
# error: Incompatible return value type (got "ExtensionArray", expected
410-
# "ndarray")
411-
return super().astype(dtype, copy) # type: ignore[return-value]
409+
return super().astype(dtype, copy)
412410

413411
if is_bool_dtype(dtype):
414412
# astype_nansafe converts np.nan to True
415413
if self._hasna:
416414
raise ValueError("cannot convert float NaN to bool")
417415
else:
418-
# error: Incompatible return value type (got "ndarray", expected
419-
# "ExtensionArray")
420-
return self._data.astype(dtype, copy=copy) # type: ignore[return-value]
416+
return self._data.astype(dtype, copy=copy)
421417

422418
# for integer, error if there are missing values
423419
if is_integer_dtype(dtype) and self._hasna:
@@ -429,12 +425,7 @@ def astype(self, dtype, copy: bool = True) -> ArrayLike:
429425
if is_float_dtype(dtype):
430426
na_value = np.nan
431427
# coerce
432-
433-
# error: Incompatible return value type (got "ndarray", expected
434-
# "ExtensionArray")
435-
return self.to_numpy( # type: ignore[return-value]
436-
dtype=dtype, na_value=na_value, copy=False
437-
)
428+
return self.to_numpy(dtype=dtype, na_value=na_value, copy=False)
438429

439430
def _values_for_argsort(self) -> np.ndarray:
440431
"""

pandas/core/arrays/categorical.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -550,8 +550,7 @@ def astype(self, dtype: Dtype, copy: bool = True) -> ArrayLike:
550550
new_cats, libalgos.ensure_platform_int(self._codes)
551551
)
552552

553-
# error: Incompatible return value type (got "Categorical", expected "ndarray")
554-
return result # type: ignore[return-value]
553+
return result
555554

556555
@cache_readonly
557556
def itemsize(self) -> int:
@@ -2659,8 +2658,9 @@ def _get_codes_for_values(values, categories: Index) -> np.ndarray:
26592658
# Only hit here when we've already coerced to object dtypee.
26602659

26612660
hash_klass, vals = get_data_algo(values)
2662-
# error: Value of type variable "ArrayLike" of "get_data_algo" cannot be "Index"
2663-
_, cats = get_data_algo(categories) # type: ignore[type-var]
2661+
# pandas/core/arrays/categorical.py:2661: error: Argument 1 to "get_data_algo" has
2662+
# incompatible type "Index"; expected "Union[ExtensionArray, ndarray]" [arg-type]
2663+
_, cats = get_data_algo(categories) # type: ignore[arg-type]
26642664
t = hash_klass(len(cats))
26652665
t.map_locations(cats)
26662666
return coerce_indexer_dtype(t.lookup(vals), cats)

0 commit comments

Comments
 (0)