From ac1fb20b1fd4a91079087de49f8b61823f5473eb Mon Sep 17 00:00:00 2001 From: Ilion Beyst Date: Wed, 4 May 2022 22:06:03 +0200 Subject: [PATCH] TYP: narrow type bounds on extract_array --- pandas/core/algorithms.py | 34 ++++++++++++++------------ pandas/core/arrays/categorical.py | 4 +-- pandas/core/arrays/datetimelike.py | 8 +----- pandas/core/construction.py | 35 ++++++++++++++++++++++----- pandas/core/internals/construction.py | 7 +----- pandas/io/formats/format.py | 4 +-- pandas/io/pytables.py | 3 ++- 7 files changed, 54 insertions(+), 41 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 393eb2997f6f0..888e943488953 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -453,30 +453,34 @@ def isin(comps: AnyArrayLike, values: AnyArrayLike) -> npt.NDArray[np.bool_]: else: values = extract_array(values, extract_numpy=True, extract_range=True) - comps = _ensure_arraylike(comps) - comps = extract_array(comps, extract_numpy=True) - if not isinstance(comps, np.ndarray): + comps_array = _ensure_arraylike(comps) + comps_array = extract_array(comps_array, extract_numpy=True) + if not isinstance(comps_array, np.ndarray): # i.e. Extension Array - return comps.isin(values) + return comps_array.isin(values) - elif needs_i8_conversion(comps.dtype): + elif needs_i8_conversion(comps_array.dtype): # Dispatch to DatetimeLikeArrayMixin.isin - return pd_array(comps).isin(values) - elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps.dtype): - # e.g. comps are integers and values are datetime64s - return np.zeros(comps.shape, dtype=bool) + return pd_array(comps_array).isin(values) + elif needs_i8_conversion(values.dtype) and not is_object_dtype(comps_array.dtype): + # e.g. comps_array are integers and values are datetime64s + return np.zeros(comps_array.shape, dtype=bool) # TODO: not quite right ... Sparse/Categorical elif needs_i8_conversion(values.dtype): - return isin(comps, values.astype(object)) + return isin(comps_array, values.astype(object)) elif isinstance(values.dtype, ExtensionDtype): - return isin(np.asarray(comps), np.asarray(values)) + return isin(np.asarray(comps_array), np.asarray(values)) # GH16012 # Ensure np.in1d doesn't get object types or it *may* throw an exception # Albeit hashmap has O(1) look-up (vs. O(logn) in sorted array), # in1d is faster for small sizes - if len(comps) > 1_000_000 and len(values) <= 26 and not is_object_dtype(comps): + if ( + len(comps_array) > 1_000_000 + and len(values) <= 26 + and not is_object_dtype(comps_array) + ): # If the values include nan we need to check for nan explicitly # since np.nan it not equal to np.nan if isna(values).any(): @@ -488,12 +492,12 @@ def f(c, v): f = np.in1d else: - common = np.find_common_type([values.dtype, comps.dtype], []) + common = np.find_common_type([values.dtype, comps_array.dtype], []) values = values.astype(common, copy=False) - comps = comps.astype(common, copy=False) + comps_array = comps_array.astype(common, copy=False) f = htable.ismember - return f(comps, values) + return f(comps_array, values) def factorize_array( diff --git a/pandas/core/arrays/categorical.py b/pandas/core/arrays/categorical.py index 01a04b7aa63d9..4c8d3db7b4672 100644 --- a/pandas/core/arrays/categorical.py +++ b/pandas/core/arrays/categorical.py @@ -447,9 +447,7 @@ def __init__( dtype = CategoricalDtype(categories, dtype.ordered) elif is_categorical_dtype(values.dtype): - # error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no - # attribute "_codes" - old_codes = extract_array(values)._codes # type: ignore[union-attr] + old_codes = extract_array(values)._codes codes = recode_for_categories( old_codes, values.dtype.categories, dtype.categories, copy=copy ) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 9ced8f225c3a8..1930580b63b79 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1235,13 +1235,7 @@ def _addsub_object_array(self, other: np.ndarray, op): res_values = op(self.astype("O"), np.asarray(other)) result = pd_array(res_values.ravel()) - # error: Item "ExtensionArray" of "Union[Any, ExtensionArray]" has no attribute - # "reshape" - result = extract_array( - result, extract_numpy=True - ).reshape( # type: ignore[union-attr] - self.shape - ) + result = extract_array(result, extract_numpy=True).reshape(self.shape) return result def _time_shift( diff --git a/pandas/core/construction.py b/pandas/core/construction.py index 17cdf6665aa99..434302b39fef9 100644 --- a/pandas/core/construction.py +++ b/pandas/core/construction.py @@ -9,8 +9,11 @@ from typing import ( TYPE_CHECKING, Any, + Optional, Sequence, + Union, cast, + overload, ) import warnings @@ -18,11 +21,13 @@ import numpy.ma as ma from pandas._libs import lib +from pandas._libs.tslibs.period import Period from pandas._typing import ( AnyArrayLike, ArrayLike, Dtype, DtypeObj, + T, ) from pandas.errors import IntCastingNaNError from pandas.util._exceptions import find_stack_level @@ -329,7 +334,8 @@ def array( if dtype is None: inferred_dtype = lib.infer_dtype(data, skipna=True) if inferred_dtype == "period": - return PeriodArray._from_sequence(data, copy=copy) + period_data = cast(Union[Sequence[Optional[Period]], AnyArrayLike], data) + return PeriodArray._from_sequence(period_data, copy=copy) elif inferred_dtype == "interval": return IntervalArray(data, copy=copy) @@ -376,9 +382,23 @@ def array( return PandasArray._from_sequence(data, dtype=dtype, copy=copy) +@overload def extract_array( - obj: object, extract_numpy: bool = False, extract_range: bool = False -) -> Any | ArrayLike: + obj: Series | Index, extract_numpy: bool = ..., extract_range: bool = ... +) -> ArrayLike: + ... + + +@overload +def extract_array( + obj: T, extract_numpy: bool = ..., extract_range: bool = ... +) -> T | ArrayLike: + ... + + +def extract_array( + obj: T, extract_numpy: bool = False, extract_range: bool = False +) -> T | ArrayLike: """ Extract the ndarray or ExtensionArray from a Series or Index. @@ -425,12 +445,15 @@ def extract_array( if isinstance(obj, ABCRangeIndex): if extract_range: return obj._values - return obj + # https://github.com/python/mypy/issues/1081 + # error: Incompatible return value type (got "RangeIndex", expected + # "Union[T, Union[ExtensionArray, ndarray[Any, Any]]]") + return obj # type: ignore[return-value] - obj = obj._values + return obj._values elif extract_numpy and isinstance(obj, ABCPandasArray): - obj = obj.to_numpy() + return obj.to_numpy() return obj diff --git a/pandas/core/internals/construction.py b/pandas/core/internals/construction.py index 8451dcb6e412a..7a5db56cb48fe 100644 --- a/pandas/core/internals/construction.py +++ b/pandas/core/internals/construction.py @@ -913,12 +913,7 @@ def _list_of_series_to_arrays( values = extract_array(s, extract_numpy=True) aligned_values.append(algorithms.take_nd(values, indexer)) - # error: Argument 1 to "vstack" has incompatible type "List[ExtensionArray]"; - # expected "Sequence[Union[Union[int, float, complex, str, bytes, generic], - # Sequence[Union[int, float, complex, str, bytes, generic]], - # Sequence[Sequence[Any]], _SupportsArray]]" - content = np.vstack(aligned_values) # type: ignore[arg-type] - + content = np.vstack(aligned_values) return content, columns diff --git a/pandas/io/formats/format.py b/pandas/io/formats/format.py index ef25224e5a847..3019aa1fc2dc7 100644 --- a/pandas/io/formats/format.py +++ b/pandas/io/formats/format.py @@ -1641,9 +1641,7 @@ def _format_strings(self) -> list[str]: formatter = self.formatter if formatter is None: - # error: Item "ndarray" of "Union[Any, Union[ExtensionArray, ndarray]]" has - # no attribute "_formatter" - formatter = values._formatter(boxed=True) # type: ignore[union-attr] + formatter = values._formatter(boxed=True) if isinstance(values, Categorical): # Categorical is special for now, so that we can preserve tzinfo diff --git a/pandas/io/pytables.py b/pandas/io/pytables.py index fc9671c2fc973..c20ce0c847b61 100644 --- a/pandas/io/pytables.py +++ b/pandas/io/pytables.py @@ -38,6 +38,7 @@ ) from pandas._libs.tslibs import timezones from pandas._typing import ( + AnyArrayLike, ArrayLike, DtypeArg, Shape, @@ -3042,7 +3043,7 @@ def write_array_empty(self, key: str, value: ArrayLike): node._v_attrs.shape = value.shape def write_array( - self, key: str, obj: DataFrame | Series, items: Index | None = None + self, key: str, obj: AnyArrayLike, items: Index | None = None ) -> None: # TODO: we only have a few tests that get here, the only EA # that gets passed is DatetimeArray, and we never have