diff --git a/pandas/_libs/index.pyi b/pandas/_libs/index.pyi
index 446a980487cde..86f2429575ebb 100644
--- a/pandas/_libs/index.pyi
+++ b/pandas/_libs/index.pyi
@@ -3,6 +3,7 @@ import numpy as np
 from pandas._typing import npt
 
 from pandas import MultiIndex
+from pandas.core.arrays import ExtensionArray
 
 class IndexEngine:
     over_size_threshold: bool
@@ -63,3 +64,21 @@ class BaseMultiIndexCodesEngine:
         method: str,
         limit: int | None,
     ) -> npt.NDArray[np.intp]: ...
+
+class ExtensionEngine:
+    def __init__(self, values: "ExtensionArray"): ...
+    def __contains__(self, val: object) -> bool: ...
+    def get_loc(self, val: object) -> int | slice | np.ndarray: ...
+    def get_indexer(self, values: np.ndarray) -> npt.NDArray[np.intp]: ...
+    def get_indexer_non_unique(
+        self,
+        targets: np.ndarray,
+    ) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
+    @property
+    def is_unique(self) -> bool: ...
+    @property
+    def is_monotonic_increasing(self) -> bool: ...
+    @property
+    def is_monotonic_decreasing(self) -> bool: ...
+    def sizeof(self, deep: bool = ...) -> int: ...
+    def clear_mapping(self): ...
diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx
index c3b86165e6d2c..fe8e8d92f699a 100644
--- a/pandas/_libs/index.pyx
+++ b/pandas/_libs/index.pyx
@@ -797,3 +797,273 @@ cdef class BaseMultiIndexCodesEngine:
 
 # Generated from template.
 include "index_class_helper.pxi"
+
+
+@cython.internal
+@cython.freelist(32)
+cdef class SharedEngine:
+    cdef readonly:
+        object values  # ExtensionArray
+        bint over_size_threshold
+
+    cdef:
+        bint unique, monotonic_inc, monotonic_dec
+        bint need_monotonic_check, need_unique_check
+
+    def __contains__(self, val: object) -> bool:
+        # We assume before we get here:
+        #  - val is hashable
+        try:
+            self.get_loc(val)
+            return True
+        except KeyError:
+            return False
+
+    def clear_mapping(self):
+        # for compat with IndexEngine
+        pass
+
+    @property
+    def is_unique(self) -> bool:
+        if self.need_unique_check:
+            arr = self.values.unique()
+            self.unique = len(arr) == len(self.values)
+
+            self.need_unique_check = False
+        return self.unique
+
+    cdef _do_monotonic_check(self):
+        raise NotImplementedError
+
+    @property
+    def is_monotonic_increasing(self) -> bool:
+        if self.need_monotonic_check:
+            self._do_monotonic_check()
+
+        return self.monotonic_inc == 1
+
+    @property
+    def is_monotonic_decreasing(self) -> bool:
+        if self.need_monotonic_check:
+            self._do_monotonic_check()
+
+        return self.monotonic_dec == 1
+
+    cdef _call_monotonic(self, values):
+        return algos.is_monotonic(values, timelike=False)
+
+    def sizeof(self, deep: bool = False) -> int:
+        """Return the size of our mapping."""
+        return 0
+
+    def __sizeof__(self) -> int:
+        return self.sizeof()
+
+    cdef _check_type(self, object obj):
+        raise NotImplementedError
+
+    cpdef get_loc(self, object val):
+        # -> Py_ssize_t | slice | ndarray[bool]
+        cdef:
+            Py_ssize_t loc
+
+        if is_definitely_invalid_key(val):
+            raise TypeError(f"'{val}' is an invalid key")
+
+        self._check_type(val)
+
+        if self.over_size_threshold and self.is_monotonic_increasing:
+            if not self.is_unique:
+                return self._get_loc_duplicates(val)
+
+            values = self.values
+
+            loc = self._searchsorted_left(val)
+            if loc >= len(values):
+                raise KeyError(val)
+            if values[loc] != val:
+                raise KeyError(val)
+            return loc
+
+        if not self.unique:
+            return self._get_loc_duplicates(val)
+
+        return self._get_loc_duplicates(val)
+
+    cdef inline _get_loc_duplicates(self, object val):
+        # -> Py_ssize_t | slice | ndarray[bool]
+        cdef:
+            Py_ssize_t diff
+
+        if self.is_monotonic_increasing:
+            values = self.values
+            try:
+                left = values.searchsorted(val, side="left")
+                right = values.searchsorted(val, side="right")
+            except TypeError:
+                # e.g. GH#29189 get_loc(None) with a Float64Index
+                raise KeyError(val)
+
+            diff = right - left
+            if diff == 0:
+                raise KeyError(val)
+            elif diff == 1:
+                return left
+            else:
+                return slice(left, right)
+
+        return self._maybe_get_bool_indexer(val)
+
+    cdef Py_ssize_t _searchsorted_left(self, val) except? -1:
+        """
+        See ObjectEngine._searchsorted_left.__doc__.
+        """
+        try:
+            loc = self.values.searchsorted(val, side="left")
+        except TypeError as err:
+            # GH#35788 e.g. val=None with float64 values
+            raise KeyError(val) from err
+        return loc
+
+    cdef ndarray _get_bool_indexer(self, val):
+        raise NotImplementedError
+
+    cdef _maybe_get_bool_indexer(self, object val):
+        # Returns ndarray[bool] or int
+        cdef:
+            ndarray[uint8_t, ndim=1, cast=True] indexer
+
+        indexer = self._get_bool_indexer(val)
+        return _unpack_bool_indexer(indexer, val)
+
+    def get_indexer(self, values) -> np.ndarray:
+        # values : type(self.values)
+        # Note: we only get here with self.is_unique
+        cdef:
+            Py_ssize_t i, N = len(values)
+
+        res = np.empty(N, dtype=np.intp)
+
+        for i in range(N):
+            val = values[i]
+            try:
+                loc = self.get_loc(val)
+                # Because we are unique, loc should always be an integer
+            except KeyError:
+                loc = -1
+            else:
+                assert util.is_integer_object(loc), (loc, val)
+            res[i] = loc
+
+        return res
+
+    def get_indexer_non_unique(self, targets):
+        """
+        Return an indexer suitable for taking from a non-unique index.
+        Return the labels in the same order as the target, and a missing
+        indexer into the targets (which corresponds to the -1 indices
+        in the results).
+        Parameters
+        ----------
+        targets : type(self.values)
+        Returns
+        -------
+        indexer : np.ndarray[np.intp]
+        missing : np.ndarray[np.intp]
+        """
+        cdef:
+            Py_ssize_t i, N = len(targets)
+
+        indexer = []
+        missing = []
+
+        # See also IntervalIndex.get_indexer_pointwise
+        for i in range(N):
+            val = targets[i]
+
+            try:
+                locs = self.get_loc(val)
+            except KeyError:
+                locs = np.array([-1], dtype=np.intp)
+                missing.append(i)
+            else:
+                if isinstance(locs, slice):
+                    # Only needed for get_indexer_non_unique
+                    locs = np.arange(locs.start, locs.stop, locs.step, dtype=np.intp)
+                elif util.is_integer_object(locs):
+                    locs = np.array([locs], dtype=np.intp)
+                else:
+                    assert locs.dtype.kind == "b"
+                    locs = locs.nonzero()[0]
+
+            indexer.append(locs)
+
+        try:
+            indexer = np.concatenate(indexer, dtype=np.intp)
+        except TypeError:
+            # numpy<1.20 doesn't accept dtype keyword
+            indexer = np.concatenate(indexer).astype(np.intp, copy=False)
+        missing = np.array(missing, dtype=np.intp)
+
+        return indexer, missing
+
+
+cdef class ExtensionEngine(SharedEngine):
+    def __init__(self, values: "ExtensionArray"):
+        self.values = values
+
+        self.over_size_threshold = len(values) >= _SIZE_CUTOFF
+        self.need_unique_check = True
+        self.need_monotonic_check = True
+
+    cdef _do_monotonic_check(self):
+        cdef:
+            bint is_unique
+
+        values = self.values
+        if values._hasna:
+            self.monotonic_inc = 0
+            self.monotonic_dec = 0
+
+            nunique = len(values.unique())
+            self.unique = nunique == len(values)
+            self.need_unique_check = 0
+            return
+
+        try:
+            ranks = values._rank()
+
+        except TypeError:
+            self.monotonic_inc = 0
+            self.monotonic_dec = 0
+            is_unique = 0
+        else:
+            self.monotonic_inc, self.monotonic_dec, is_unique = \
+                self._call_monotonic(ranks)
+
+        self.need_monotonic_check = 0
+
+        # we can only be sure of uniqueness if is_unique=1
+        if is_unique:
+            self.unique = 1
+            self.need_unique_check = 0
+
+    cdef ndarray _get_bool_indexer(self, val):
+        if checknull(val):
+            return self.values.isna().view("uint8")
+
+        try:
+            return self.values == val
+        except TypeError:
+            # e.g. if __eq__ returns a BooleanArray instead of ndarray[bool]
+            try:
+                return (self.values == val).to_numpy(dtype=bool, na_value=False)
+            except (TypeError, AttributeError) as err:
+                # e.g. (self.values == val) returned a bool
+                #  see test_get_loc_generator[string[pyarrow]]
+                # e.g. self.values == val raises TypeError because the generator
+                #  has no len; see test_get_loc_generator[string[python]]
+                raise KeyError from err
+
+    cdef _check_type(self, object val):
+        hash(val)
diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py
index 26ddf89adaea2..e0bd548eff70a 100644
--- a/pandas/core/indexes/base.py
+++ b/pandas/core/indexes/base.py
@@ -144,7 +144,6 @@
     tz_to_dtype,
     validate_tz_from_dtype,
 )
-from pandas.core.arrays.masked import BaseMaskedArray
 from pandas.core.arrays.sparse import SparseDtype
 from pandas.core.base import (
     IndexOpsMixin,
@@ -329,6 +328,9 @@ def _left_indexer_unique(self: _IndexT, other: _IndexT) -> npt.NDArray[np.intp]
         # Caller is responsible for ensuring other.dtype == self.dtype
         sv = self._get_engine_target()
         ov = other._get_engine_target()
+        # can_use_libjoin ensures sv and ov are ndarrays
+        sv = cast(np.ndarray, sv)
+        ov = cast(np.ndarray, ov)
         return libjoin.left_join_indexer_unique(sv, ov)
 
     @final
@@ -338,6 +340,9 @@ def _left_indexer(
         # Caller is responsible for ensuring other.dtype == self.dtype
         sv = self._get_engine_target()
         ov = other._get_engine_target()
+        # can_use_libjoin ensures sv and ov are ndarrays
+        sv = cast(np.ndarray, sv)
+        ov = cast(np.ndarray, ov)
         joined_ndarray, lidx, ridx = libjoin.left_join_indexer(sv, ov)
         joined = self._from_join_target(joined_ndarray)
         return joined, lidx, ridx
@@ -349,6 +354,9 @@ def _inner_indexer(
         # Caller is responsible for ensuring other.dtype == self.dtype
         sv = self._get_engine_target()
         ov = other._get_engine_target()
+        # can_use_libjoin ensures sv and ov are ndarrays
+        sv = cast(np.ndarray, sv)
+        ov = cast(np.ndarray, ov)
         joined_ndarray, lidx, ridx = libjoin.inner_join_indexer(sv, ov)
         joined = self._from_join_target(joined_ndarray)
         return joined, lidx, ridx
@@ -360,6 +368,9 @@ def _outer_indexer(
         # Caller is responsible for ensuring other.dtype == self.dtype
         sv = self._get_engine_target()
         ov = other._get_engine_target()
+        # can_use_libjoin ensures sv and ov are ndarrays
+        sv = cast(np.ndarray, sv)
+        ov = cast(np.ndarray, ov)
         joined_ndarray, lidx, ridx = libjoin.outer_join_indexer(sv, ov)
         joined = self._from_join_target(joined_ndarray)
         return joined, lidx, ridx
@@ -387,7 +398,9 @@ def _outer_indexer(
     # associated code in pandas 2.0.
     _is_backward_compat_public_numeric_index: bool = False
 
-    _engine_type: type[libindex.IndexEngine] = libindex.ObjectEngine
+    _engine_type: type[libindex.IndexEngine] | type[
+        libindex.ExtensionEngine
+    ] = libindex.ObjectEngine
     # whether we support partial string indexing. Overridden
     # in DatetimeIndex and PeriodIndex
     _supports_partial_string_indexing = False
@@ -850,23 +863,21 @@ def _cleanup(self) -> None:
     @cache_readonly
     def _engine(
         self,
-    ) -> libindex.IndexEngine:
+    ) -> libindex.IndexEngine | libindex.ExtensionEngine:
         # For base class (object dtype) we get ObjectEngine
-
-        if isinstance(self._values, BaseMaskedArray):
-            # TODO(ExtensionIndex): use libindex.NullableEngine(self._values)
-            return libindex.ObjectEngine(self._get_engine_target())
-        elif (
-            isinstance(self._values, ExtensionArray)
+        target_values = self._get_engine_target()
+        if (
+            isinstance(target_values, ExtensionArray)
             and self._engine_type is libindex.ObjectEngine
         ):
-            # TODO(ExtensionIndex): use libindex.ExtensionEngine(self._values)
-            return libindex.ObjectEngine(self._get_engine_target())
+            return libindex.ExtensionEngine(target_values)
 
         # to avoid a reference cycle, bind `target_values` to a local variable, so
         # `self` is not passed into the lambda.
-        target_values = self._get_engine_target()
-        return self._engine_type(target_values)
+        target_values = cast(np.ndarray, target_values)
+        # error: Argument 1 to "ExtensionEngine" has incompatible type
+        # "ndarray[Any, Any]"; expected "ExtensionArray"
+        return self._engine_type(target_values)  # type: ignore[arg-type]
 
     @final
     @cache_readonly
@@ -3322,12 +3333,6 @@ def _wrap_setop_result(self, other: Index, result) -> Index:
             result = result.rename(name)
         else:
             result = self._shallow_copy(result, name=name)
-
-        if type(self) is Index and self.dtype != _dtype_obj:
-            # i.e. ExtensionArray-backed
-            # TODO(ExtensionIndex): revert this astype; it is a kludge to make
-            # it possible to split ExtensionEngine from ExtensionIndex PR.
-            return result.astype(self.dtype, copy=False)
         return result
 
     @final
@@ -3874,12 +3879,16 @@ def _get_indexer(
         tgt_values = target._get_engine_target()
         if target._is_multi and self._is_multi:
             engine = self._engine
-            # error: "IndexEngine" has no attribute "_extract_level_codes"
-            tgt_values = engine._extract_level_codes(  # type: ignore[attr-defined]
+            # error: Item "IndexEngine" of "Union[IndexEngine, ExtensionEngine]"
+            # has no attribute "_extract_level_codes"
+            tgt_values = engine._extract_level_codes(  # type: ignore[union-attr]
                 target
             )
 
-        indexer = self._engine.get_indexer(tgt_values)
+        # error: Argument 1 to "get_indexer" of "IndexEngine" has incompatible
+        # type "Union[ExtensionArray, ndarray[Any, Any]]"; expected
+        # "ndarray[Any, Any]"
+        indexer = self._engine.get_indexer(tgt_values)  # type: ignore[arg-type]
 
         return ensure_platform_int(indexer)
 
@@ -3955,13 +3964,18 @@ def _get_fill_indexer(
             # TODO: get_indexer_with_fill docstring says values must be _sorted_
             # but that doesn't appear to be enforced
            # error: "IndexEngine" has no attribute "get_indexer_with_fill"
-            return self._engine.get_indexer_with_fill(  # type: ignore[attr-defined]
+            engine = self._engine
+            return engine.get_indexer_with_fill(  # type: ignore[union-attr]
                 target=target._values, values=self._values, method=method, limit=limit
             )
 
         if self.is_monotonic_increasing and target.is_monotonic_increasing:
             target_values = target._get_engine_target()
             own_values = self._get_engine_target()
+            if not isinstance(target_values, np.ndarray) or not isinstance(
+                own_values, np.ndarray
+            ):
+                raise NotImplementedError
 
             if method == "pad":
                 indexer = libalgos.pad(own_values, target_values, limit=limit)
@@ -4889,8 +4903,9 @@ def _can_use_libjoin(self) -> bool:
         """
         Whether we can use the fastpaths implement in _libs.join
         """
-        # Note: this will need to be updated when e.g. Nullable dtypes
-        # are supported in Indexes.
+        if type(self) is Index:
+            # excludes EAs
+            return isinstance(self.dtype, np.dtype)
         return not is_interval_dtype(self.dtype)
 
     # --------------------------------------------------------------------
@@ -4956,16 +4971,12 @@ def _values(self) -> ExtensionArray | np.ndarray:
         """
         return self._data
 
-    def _get_engine_target(self) -> np.ndarray:
+    def _get_engine_target(self) -> ArrayLike:
         """
-        Get the ndarray that we can pass to the IndexEngine constructor.
+        Get the ndarray or ExtensionArray that we can pass to the IndexEngine
+        constructor.
         """
-        # error: Incompatible return value type (got "Union[ExtensionArray,
-        # ndarray]", expected "ndarray")
-        if type(self) is Index and isinstance(self._values, ExtensionArray):
-            # TODO(ExtensionIndex): remove special-case, just use self._values
-            return self._values.astype(object)
-        return self._values  # type: ignore[return-value]
+        return self._values
 
     def _from_join_target(self, result: np.ndarray) -> ArrayLike:
         """
@@ -5854,10 +5865,9 @@ def get_indexer_non_unique(
         tgt_values = target._get_engine_target()
         if self._is_multi and target._is_multi:
             engine = self._engine
-            # error: "IndexEngine" has no attribute "_extract_level_codes"
-            tgt_values = engine._extract_level_codes(  # type: ignore[attr-defined]
-                target
-            )
+            # error: Item "IndexEngine" of "Union[IndexEngine, ExtensionEngine]"
+            # has no attribute "_extract_level_codes"
+            tgt_values = engine._extract_level_codes(target)  # type: ignore[union-attr]
 
         indexer, missing = self._engine.get_indexer_non_unique(tgt_values)
         return ensure_platform_int(indexer), ensure_platform_int(missing)
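
A minimal sketch (not part of the patch) of the behavior this enables, assuming a build that includes these changes. `"Int64"` is just one example of an ExtensionArray-backed dtype, and `_engine` / `_get_engine_target` are private attributes, inspected here only to show the new wiring:

```python
import pandas as pd
from pandas._libs import index as libindex

# An ExtensionArray-backed Index (masked Int64 dtype).
idx = pd.Index([1, 2, 2, 4], dtype="Int64")

# The cached engine is now an ExtensionEngine operating on the EA itself,
# not an ObjectEngine operating on an astype(object) copy of the values.
assert isinstance(idx._engine, libindex.ExtensionEngine)
assert idx._get_engine_target() is idx._values

# SharedEngine.get_loc returns an integer for a unique label, a slice for
# a contiguous run of duplicates, and a boolean mask otherwise.
assert idx.get_loc(1) == 0
assert idx.get_loc(2) == slice(1, 3)

# Missing labels still raise KeyError.
try:
    idx.get_loc(3)
except KeyError:
    pass
else:
    raise AssertionError("expected KeyError")
```

Design note: unlike `IndexEngine`, `SharedEngine` builds no hashtable (`sizeof` returns 0 and `clear_mapping` is a no-op). Lookups go through `searchsorted` when the values are monotonic, or through the boolean mask from `_get_bool_indexer` otherwise, trading the one-time cost of populating a mapping for O(log n) or O(n) work per lookup.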