Skip to content

Commit 43a4bad

Browse files
authored
TYP: narrow type bounds on extract_array (#46942)
1 parent 1c82694 commit 43a4bad

File tree

7 files changed

+54
-41
lines changed

7 files changed

+54
-41
lines changed

pandas/core/algorithms.py

+19-15
Original file line number | Diff line number | Diff line change
@@ -453,30 +453,34 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> npt.NDArray[np.bool_]:
453453
else:
454454
values = extract_array(values, extract_numpy=True, extract_range=True)
455455

456-
comps = _ensure_arraylike(comps)
457-
comps = extract_array(comps, extract_numpy=True)
458-
if not isinstance(comps, np.ndarray):
456+
comps_array = _ensure_arraylike(comps)
457+
comps_array = extract_array(comps_array, extract_numpy=True)
458+
if not isinstance(comps_array, np.ndarray):
459459
# i.e. Extension Array
460-
return comps.isin(values)
460+
return comps_array.isin(values)
461461

462-
elif needs_i8_conversion(comps.dtype):
462+
elif needs_i8_conversion(comps_array.dtype):
463463
# Dispatch to DatetimeLikeArrayMixin.isin
464-
return pd_array(comps).isin(values)
465-
elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps.dtype):
466-
# e.g. comps are integers and values are datetime64s
467-
return np.zeros(comps.shape, dtype=bool)
464+
return pd_array(comps_array).isin(values)
465+
elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps_array.dtype):
466+
# e.g. comps_array are integers and values are datetime64s
467+
return np.zeros(comps_array.shape, dtype=bool)
468468
# TODO: not quite right ... Sparse/Categorical
469469
elif needs_i8_conversion(values.dtype):
470-
return isin(comps, values.astype(object))
470+
return isin(comps_array, values.astype(object))
471471

472472
elif isinstance(values.dtype, ExtensionDtype):
473-
return isin(np.asarray(comps), np.asarray(values))
473+
return isin(np.asarray(comps_array), np.asarray(values))
474474

475475
# GH16012
476476
# Ensure np.in1d doesn't get object types or it *may* throw an exception
477477
# Albeit hashmap has O(1) look-up (vs. O(logn) in sorted array),
478478
# in1d is faster for small sizes
479-
if len(comps) > 1_000_000 and len(values) <= 26 and not is_object_dtype(comps):
479+
if (
480+
len(comps_array) > 1_000_000
481+
and len(values) <= 26
482+
and not is_object_dtype(comps_array)
483+
):
480484
# If the values include nan we need to check for nan explicitly
481485
# since np.nan is not equal to np.nan
482486
if isna(values).any():
@@ -488,12 +492,12 @@ def f(c, v):
488492
f = np.in1d
489493

490494
else:
491-
common = np.find_common_type([values.dtype, comps.dtype], [])
495+
common = np.find_common_type([values.dtype, comps_array.dtype], [])
492496
values = values.astype(common, copy=False)
493-
comps = comps.astype(common, copy=False)
497+
comps_array = comps_array.astype(common, copy=False)
494498
f = htable.ismember
495499

496-
return f(comps, values)
500+
return f(comps_array, values)
497501

498502

499503
def factorize_array(

pandas/core/arrays/categorical.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -446,9 +446,7 @@ def __init__(
446446
dtype = CategoricalDtype(categories, dtype.ordered)
447447

448448
elif is_categorical_dtype(values.dtype):
449-
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no
450-
# attribute "_codes"
451-
old_codes = extract_array(values)._codes # type: ignore[union-attr]
449+
old_codes = extract_array(values)._codes
452450
codes = recode_for_categories(
453451
old_codes, values.dtype.categories, dtype.categories, copy=copy
454452
)

pandas/core/arrays/datetimelike.py

+1-7
Original file line number | Diff line number | Diff line change
@@ -1258,13 +1258,7 @@ def _addsub_object_array(self, other: np.ndarray, op):
12581258
res_values = op(self.astype("O"), np.asarray(other))
12591259

12601260
result = pd_array(res_values.ravel())
1261-
# error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no attribute
1262-
# "reshape"
1263-
result = extract_array(
1264-
result, extract_numpy=True
1265-
).reshape( # type: ignore[union-attr]
1266-
self.shape
1267-
)
1261+
result = extract_array(result, extract_numpy=True).reshape(self.shape)
12681262
return result
12691263

12701264
def _time_shift(

pandas/core/construction.py

+29-6
Original file line number | Diff line number | Diff line change
@@ -9,20 +9,25 @@
99
from typing import (
1010
TYPE_CHECKING,
1111
Any,
12+
Optional,
1213
Sequence,
14+
Union,
1315
cast,
16+
overload,
1417
)
1518
import warnings
1619

1720
import numpy as np
1821
import numpy.ma as ma
1922

2023
from pandas._libs import lib
24+
from pandas._libs.tslibs.period import Period
2125
from pandas._typing import (
2226
AnyArrayLike,
2327
ArrayLike,
2428
Dtype,
2529
DtypeObj,
30+
T,
2631
)
2732
from pandas.errors import IntCastingNaNError
2833
from pandas.util._exceptions import find_stack_level
@@ -329,7 +334,8 @@ def array(
329334
if dtype is None:
330335
inferred_dtype = lib.infer_dtype(data, skipna=True)
331336
if inferred_dtype == "period":
332-
return PeriodArray._from_sequence(data, copy=copy)
337+
period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data)
338+
return PeriodArray._from_sequence(period_data, copy=copy)
333339

334340
elif inferred_dtype == "interval":
335341
return IntervalArray(data, copy=copy)
@@ -376,9 +382,23 @@ def array(
376382
return PandasArray._from_sequence(data, dtype=dtype, copy=copy)
377383

378384

385+
@overload
379386
def extract_array(
380-
obj: object, extract_numpy: bool = False, extract_range: bool = False
381-
) -> Any | ArrayLike:
387+
obj: Series | Index, extract_numpy: bool = ..., extract_range: bool = ...
388+
) -> ArrayLike:
389+
...
390+
391+
392+
@overload
393+
def extract_array(
394+
obj: T, extract_numpy: bool = ..., extract_range: bool = ...
395+
) -> T | ArrayLike:
396+
...
397+
398+
399+
def extract_array(
400+
obj: T, extract_numpy: bool = False, extract_range: bool = False
401+
) -> T | ArrayLike:
382402
"""
383403
Extract the ndarray or ExtensionArray from a Series or Index.
384404
@@ -425,12 +445,15 @@ def extract_array(
425445
if isinstance(obj, ABCRangeIndex):
426446
if extract_range:
427447
return obj._values
428-
return obj
448+
# https://github.com/python/mypy/issues/1081
449+
# error: Incompatible return value type (got "RangeIndex", expected
450+
# "Union[T, Union[ExtensionArray, ndarray[Any, Any]]]")
451+
return obj # type: ignore[return-value]
429452

430-
obj = obj._values
453+
return obj._values
431454

432455
elif extract_numpy and isinstance(obj, ABCPandasArray):
433-
obj = obj.to_numpy()
456+
return obj.to_numpy()
434457

435458
return obj
436459

pandas/core/internals/construction.py

+1-6
Original file line number | Diff line number | Diff line change
@@ -913,12 +913,7 @@ def _list_of_series_to_arrays(
913913
values = extract_array(s, extract_numpy=True)
914914
aligned_values.append(algorithms.take_nd(values, indexer))
915915

916-
# error: Argument 1 to "vstack" has incompatible type "List[ExtensionArray]";
917-
# expected "Sequence[Union[Union[int, float, complex, str, bytes, generic],
918-
# Sequence[Union[int, float, complex, str, bytes, generic]],
919-
# Sequence[Sequence[Any]], _SupportsArray]]"
920-
content = np.vstack(aligned_values) # type: ignore[arg-type]
921-
916+
content = np.vstack(aligned_values)
922917
return content, columns
923918

924919

pandas/io/formats/format.py

+1-3
Original file line number | Diff line number | Diff line change
@@ -1641,9 +1641,7 @@ def _format_strings(self) -> list[str]:
16411641

16421642
formatter = self.formatter
16431643
if formatter is None:
1644-
# error: Item "ndarray" of "Union[Any, Union[ExtensionArray, ndarray]]" has
1645-
# no attribute "_formatter"
1646-
formatter = values._formatter(boxed=True) # type: ignore[union-attr]
1644+
formatter = values._formatter(boxed=True)
16471645

16481646
if isinstance(values, Categorical):
16491647
# Categorical is special for now, so that we can preserve tzinfo

pandas/io/pytables.py

+2-1
Original file line number | Diff line number | Diff line change
@@ -38,6 +38,7 @@
3838
)
3939
from pandas._libs.tslibs import timezones
4040
from pandas._typing import (
41+
AnyArrayLike,
4142
ArrayLike,
4243
DtypeArg,
4344
Shape,
@@ -3042,7 +3043,7 @@ def write_array_empty(self, key: str, value: ArrayLike):
30423043
node._v_attrs.shape = value.shape
30433044

30443045
def write_array(
3045-
self, key: str, obj: DataFrame | Series, items: Index | None = None
3046+
self, key: str, obj: AnyArrayLike, items: Index | None = None
30463047
) -> None:
30473048
# TODO: we only have a few tests that get here, the only EA
30483049
# that gets passed is DatetimeArray, and we never have

0 commit comments

Comments
 (0)