Skip to content

Commit dfe1d6f

Browse files
committed
TYP: narrow type bounds on extract_array
1 parent f86644d commit dfe1d6f

File tree

7 files changed

+57
-41
lines changed

7 files changed

+57
-41
lines changed

pandas/core/algorithms.py

+19-15
Original file line numberDiff line numberDiff line change
@@ -453,30 +453,34 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> npt.NDArray[np.bool_]:
453453
else:
454454
values = extract_array(values, extract_numpy=True, extract_range=True)
455455

456-
comps = _ensure_arraylike(comps)
457-
comps = extract_array(comps, extract_numpy=True)
458-
if not isinstance(comps, np.ndarray):
456+
comps_array = _ensure_arraylike(comps)
457+
comps_array = extract_array(comps_array, extract_numpy=True)
458+
if not isinstance(comps_array, np.ndarray):
459459
# i.e. Extension Array
460-
return comps.isin(values)
460+
return comps_array.isin(values)
461461

462-
elif needs_i8_conversion(comps.dtype):
462+
elif needs_i8_conversion(comps_array.dtype):
463463
# Dispatch to DatetimeLikeArrayMixin.isin
464-
return pd_array(comps).isin(values)
465-
elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps.dtype):
466-
# e.g. comps are integers and values are datetime64s
467-
return np.zeros(comps.shape, dtype=bool)
464+
return pd_array(comps_array).isin(values)
465+
elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps_array.dtype):
466+
# e.g. comps_array holds integers and values are datetime64s
467+
return np.zeros(comps_array.shape, dtype=bool)
468468
# TODO: not quite right ... Sparse/Categorical
469469
elif needs_i8_conversion(values.dtype):
470-
return isin(comps, values.astype(object))
470+
return isin(comps_array, values.astype(object))
471471

472472
elif isinstance(values.dtype, ExtensionDtype):
473-
return isin(np.asarray(comps), np.asarray(values))
473+
return isin(np.asarray(comps_array), np.asarray(values))
474474

475475
# GH16012
476476
# Ensure np.in1d doesn't get object types or it *may* throw an exception
477477
# Although a hashmap has O(1) look-up (vs. O(log n) in a sorted array),
478478
# in1d is faster for small sizes
479-
if len(comps) > 1_000_000 and len(values) <= 26 and not is_object_dtype(comps):
479+
if (
480+
len(comps_array) > 1_000_000
481+
and len(values) <= 26
482+
and not is_object_dtype(comps_array)
483+
):
480484
# If the values include nan we need to check for nan explicitly
481485
# since np.nan is not equal to np.nan
482486
if isna(values).any():
@@ -488,12 +492,12 @@ def f(c, v):
488492
f = np.in1d
489493

490494
else:
491-
common = np.find_common_type([values.dtype, comps.dtype], [])
495+
common = np.find_common_type([values.dtype, comps_array.dtype], [])
492496
values = values.astype(common, copy=False)
493-
comps = comps.astype(common, copy=False)
497+
comps_array = comps_array.astype(common, copy=False)
494498
f = htable.ismember
495499

496-
return f(comps, values)
500+
return f(comps_array, values)
497501

498502

499503
def factorize_array(

pandas/core/arrays/categorical.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -447,9 +447,7 @@ def __init__(
447447
dtype = CategoricalDtype(categories, dtype.ordered)
448448

449449
elif is_categorical_dtype(values.dtype):
450-
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no
451-
# attribute "_codes"
452-
old_codes = extract_array(values)._codes # type: ignore[union-attr]
450+
old_codes = extract_array(values)._codes
453451
codes = recode_for_categories(
454452
old_codes, values.dtype.categories, dtype.categories, copy=copy
455453
)

pandas/core/arrays/datetimelike.py

+1-7
Original file line numberDiff line numberDiff line change
@@ -1235,13 +1235,7 @@ def _addsub_object_array(self, other: np.ndarray, op):
12351235
res_values = op(self.astype("O"), np.asarray(other))
12361236

12371237
result = pd_array(res_values.ravel())
1238-
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no attribute
1239-
# "reshape"
1240-
result = extract_array(
1241-
result, extract_numpy=True
1242-
).reshape( # type: ignore[union-attr]
1243-
self.shape
1244-
)
1238+
result = extract_array(result, extract_numpy=True).reshape(self.shape)
12451239
return result
12461240

12471241
def _time_shift(

pandas/core/construction.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,20 @@
99
from typing import (
1010
TYPE_CHECKING,
1111
Any,
12+
Optional,
1213
Sequence,
14+
TypeVar,
15+
Union,
1316
cast,
17+
overload,
1418
)
1519
import warnings
1620

1721
import numpy as np
1822
import numpy.ma as ma
1923

2024
from pandas._libs import lib
25+
from pandas._libs.tslibs.period import Period
2126
from pandas._typing import (
2227
AnyArrayLike,
2328
ArrayLike,
@@ -73,6 +78,9 @@
7378
)
7479

7580

81+
T = TypeVar("T")
82+
83+
7684
def array(
7785
data: Sequence[object] | AnyArrayLike,
7886
dtype: Dtype | None = None,
@@ -329,7 +337,8 @@ def array(
329337
if dtype is None:
330338
inferred_dtype = lib.infer_dtype(data, skipna=True)
331339
if inferred_dtype == "period":
332-
return PeriodArray._from_sequence(data, copy=copy)
340+
period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data)
341+
return PeriodArray._from_sequence(period_data, copy=copy)
333342

334343
elif inferred_dtype == "interval":
335344
return IntervalArray(data, copy=copy)
@@ -376,9 +385,23 @@ def array(
376385
return PandasArray._from_sequence(data, dtype=dtype, copy=copy)
377386

378387

388+
@overload
389+
def extract_array(
390+
obj: Series | Index, extract_numpy: bool = False, extract_range: bool = False
391+
) -> ArrayLike:
392+
...
393+
394+
395+
@overload
396+
def extract_array(
397+
obj: T, extract_numpy: bool = False, extract_range: bool = False
398+
) -> T | ArrayLike:
399+
...
400+
401+
379402
def extract_array(
380-
obj: object, extract_numpy: bool = False, extract_range: bool = False
381-
) -> Any | ArrayLike:
403+
obj: T, extract_numpy: bool = False, extract_range: bool = False
404+
) -> T | ArrayLike:
382405
"""
383406
Extract the ndarray or ExtensionArray from a Series or Index.
384407
@@ -425,12 +448,15 @@ def extract_array(
425448
if isinstance(obj, ABCRangeIndex):
426449
if extract_range:
427450
return obj._values
428-
return obj
451+
# https://github.com/python/mypy/issues/1081
452+
# error: Incompatible return value type (got "RangeIndex", expected
453+
# "Union[T, Union[ExtensionArray, ndarray[Any, Any]]]")
454+
return obj # type: ignore[return-value]
429455

430-
obj = obj._values
456+
return obj._values
431457

432458
elif extract_numpy and isinstance(obj, ABCPandasArray):
433-
obj = obj.to_numpy()
459+
return obj.to_numpy()
434460

435461
return obj
436462

pandas/core/internals/construction.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -913,12 +913,7 @@ def _list_of_series_to_arrays(
913913
values = extract_array(s, extract_numpy=True)
914914
aligned_values.append(algorithms.take_nd(values, indexer))
915915

916-
# error: Argument 1 to "vstack" has incompatible type "List[ExtensionArray]";
917-
# expected "Sequence[Union[Union[int, float, complex, str, bytes, generic],
918-
# Sequence[Union[int, float, complex, str, bytes, generic]],
919-
# Sequence[Sequence[Any]], _SupportsArray]]"
920-
content = np.vstack(aligned_values) # type: ignore[arg-type]
921-
916+
content = np.vstack(aligned_values)
922917
return content, columns
923918

924919

pandas/io/formats/format.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -1641,9 +1641,7 @@ def _format_strings(self) -> list[str]:
16411641

16421642
formatter = self.formatter
16431643
if formatter is None:
1644-
# error: Item "ndarray" of "Union[Any, Union[ExtensionArray, ndarray]]" has
1645-
# no attribute "_formatter"
1646-
formatter = values._formatter(boxed=True) # type: ignore[union-attr]
1644+
formatter = values._formatter(boxed=True)
16471645

16481646
if isinstance(values, Categorical):
16491647
# Categorical is special for now, so that we can preserve tzinfo

pandas/io/pytables.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from pandas._libs.tslibs import timezones
4040
from pandas._typing import (
41+
AnyArrayLike,
4142
ArrayLike,
4243
DtypeArg,
4344
Shape,
@@ -3036,7 +3037,7 @@ def write_array_empty(self, key: str, value: ArrayLike):
30363037
node._v_attrs.shape = value.shape
30373038

30383039
def write_array(
3039-
self, key: str, obj: DataFrame | Series, items: Index | None = None
3040+
self, key: str, obj: AnyArrayLike, items: Index | None = None
30403041
) -> None:
30413042
# TODO: we only have a few tests that get here, the only EA
30423043
# that gets passed is DatetimeArray, and we never have

0 commit comments

Comments
 (0)