Skip to content

Commit 4248b23

Browse files
authored
ENH: ExtensionEngine (#45514)
1 parent 1fd31bd commit 4248b23

File tree

3 files changed

+337
-37
lines changed

3 files changed

+337
-37
lines changed

pandas/_libs/index.pyi

+19
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import numpy as np
33
from pandas._typing import npt
44

55
from pandas import MultiIndex
6+
from pandas.core.arrays import ExtensionArray
67

78
class IndexEngine:
89
over_size_threshold: bool
@@ -63,3 +64,21 @@ class BaseMultiIndexCodesEngine:
6364
method: str,
6465
limit: int | None,
6566
) -> npt.NDArray[np.intp]: ...
67+
68+
class ExtensionEngine:
69+
def __init__(self, values: "ExtensionArray"): ...
70+
def __contains__(self, val: object) -> bool: ...
71+
def get_loc(self, val: object) -> int | slice | np.ndarray: ...
72+
def get_indexer(self, values: np.ndarray) -> npt.NDArray[np.intp]: ...
73+
def get_indexer_non_unique(
74+
self,
75+
targets: np.ndarray,
76+
) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp]]: ...
77+
@property
78+
def is_unique(self) -> bool: ...
79+
@property
80+
def is_monotonic_increasing(self) -> bool: ...
81+
@property
82+
def is_monotonic_decreasing(self) -> bool: ...
83+
def sizeof(self, deep: bool = ...) -> int: ...
84+
def clear_mapping(self): ...

pandas/_libs/index.pyx

+271
Original file line numberDiff line numberDiff line change
@@ -797,3 +797,274 @@ cdef class BaseMultiIndexCodesEngine:
797797

798798
# Generated from template.
799799
include "index_class_helper.pxi"
800+
801+
802+
@cython.internal
803+
@cython.freelist(32)
804+
cdef class SharedEngine:
805+
cdef readonly:
806+
object values # ExtensionArray
807+
bint over_size_threshold
808+
809+
cdef:
810+
bint unique, monotonic_inc, monotonic_dec
811+
bint need_monotonic_check, need_unique_check
812+
813+
def __contains__(self, val: object) -> bool:
814+
# We assume before we get here:
815+
# - val is hashable
816+
try:
817+
self.get_loc(val)
818+
return True
819+
except KeyError:
820+
return False
821+
822+
def clear_mapping(self):
823+
# for compat with IndexEngine
824+
pass
825+
826+
@property
827+
def is_unique(self) -> bool:
828+
if self.need_unique_check:
829+
arr = self.values.unique()
830+
self.unique = len(arr) == len(self.values)
831+
832+
self.need_unique_check = False
833+
return self.unique
834+
835+
cdef _do_monotonic_check(self):
836+
raise NotImplementedError
837+
838+
@property
839+
def is_monotonic_increasing(self) -> bool:
840+
if self.need_monotonic_check:
841+
self._do_monotonic_check()
842+
843+
return self.monotonic_inc == 1
844+
845+
@property
846+
def is_monotonic_decreasing(self) -> bool:
847+
if self.need_monotonic_check:
848+
self._do_monotonic_check()
849+
850+
return self.monotonic_dec == 1
851+
852+
cdef _call_monotonic(self, values):
853+
return algos.is_monotonic(values, timelike=False)
854+
855+
def sizeof(self, deep: bool = False) -> int:
856+
""" return the sizeof our mapping """
857+
return 0
858+
859+
def __sizeof__(self) -> int:
860+
return self.sizeof()
861+
862+
cdef _check_type(self, object obj):
863+
raise NotImplementedError
864+
865+
cpdef get_loc(self, object val):
866+
# -> Py_ssize_t | slice | ndarray[bool]
867+
cdef:
868+
Py_ssize_t loc
869+
870+
if is_definitely_invalid_key(val):
871+
raise TypeError(f"'{val}' is an invalid key")
872+
873+
self._check_type(val)
874+
875+
if self.over_size_threshold and self.is_monotonic_increasing:
876+
if not self.is_unique:
877+
return self._get_loc_duplicates(val)
878+
879+
values = self.values
880+
881+
loc = self._searchsorted_left(val)
882+
if loc >= len(values):
883+
raise KeyError(val)
884+
if values[loc] != val:
885+
raise KeyError(val)
886+
return loc
887+
888+
if not self.unique:
889+
return self._get_loc_duplicates(val)
890+
891+
return self._get_loc_duplicates(val)
892+
893+
cdef inline _get_loc_duplicates(self, object val):
894+
# -> Py_ssize_t | slice | ndarray[bool]
895+
cdef:
896+
Py_ssize_t diff
897+
898+
if self.is_monotonic_increasing:
899+
values = self.values
900+
try:
901+
left = values.searchsorted(val, side='left')
902+
right = values.searchsorted(val, side='right')
903+
except TypeError:
904+
# e.g. GH#29189 get_loc(None) with a Float64Index
905+
raise KeyError(val)
906+
907+
diff = right - left
908+
if diff == 0:
909+
raise KeyError(val)
910+
elif diff == 1:
911+
return left
912+
else:
913+
return slice(left, right)
914+
915+
return self._maybe_get_bool_indexer(val)
916+
917+
cdef Py_ssize_t _searchsorted_left(self, val) except? -1:
918+
"""
919+
See ObjectEngine._searchsorted_left.__doc__.
920+
"""
921+
try:
922+
loc = self.values.searchsorted(val, side="left")
923+
except TypeError as err:
924+
# GH#35788 e.g. val=None with float64 values
925+
raise KeyError(val)
926+
return loc
927+
928+
cdef ndarray _get_bool_indexer(self, val):
929+
raise NotImplementedError
930+
931+
cdef _maybe_get_bool_indexer(self, object val):
932+
# Returns ndarray[bool] or int
933+
cdef:
934+
ndarray[uint8_t, ndim=1, cast=True] indexer
935+
936+
indexer = self._get_bool_indexer(val)
937+
return _unpack_bool_indexer(indexer, val)
938+
939+
def get_indexer(self, values) -> np.ndarray:
940+
# values : type(self.values)
941+
# Note: we only get here with self.is_unique
942+
cdef:
943+
Py_ssize_t i, N = len(values)
944+
945+
res = np.empty(N, dtype=np.intp)
946+
947+
for i in range(N):
948+
val = values[i]
949+
try:
950+
loc = self.get_loc(val)
951+
# Because we are unique, loc should always be an integer
952+
except KeyError:
953+
loc = -1
954+
else:
955+
assert util.is_integer_object(loc), (loc, val)
956+
res[i] = loc
957+
958+
return res
959+
960+
def get_indexer_non_unique(self, targets):
961+
"""
962+
Return an indexer suitable for taking from a non unique index
963+
return the labels in the same order as the target
964+
and a missing indexer into the targets (which correspond
965+
to the -1 indices in the results
966+
Parameters
967+
----------
968+
targets : type(self.values)
969+
Returns
970+
-------
971+
indexer : np.ndarray[np.intp]
972+
missing : np.ndarray[np.intp]
973+
"""
974+
cdef:
975+
Py_ssize_t i, N = len(targets)
976+
977+
indexer = []
978+
missing = []
979+
980+
# See also IntervalIndex.get_indexer_pointwise
981+
for i in range(N):
982+
val = targets[i]
983+
984+
try:
985+
locs = self.get_loc(val)
986+
except KeyError:
987+
locs = np.array([-1], dtype=np.intp)
988+
missing.append(i)
989+
else:
990+
if isinstance(locs, slice):
991+
# Only needed for get_indexer_non_unique
992+
locs = np.arange(locs.start, locs.stop, locs.step, dtype=np.intp)
993+
elif util.is_integer_object(locs):
994+
locs = np.array([locs], dtype=np.intp)
995+
else:
996+
assert locs.dtype.kind == "b"
997+
locs = locs.nonzero()[0]
998+
999+
indexer.append(locs)
1000+
1001+
try:
1002+
indexer = np.concatenate(indexer, dtype=np.intp)
1003+
except TypeError:
1004+
# numpy<1.20 doesn't accept dtype keyword
1005+
indexer = np.concatenate(indexer).astype(np.intp, copy=False)
1006+
missing = np.array(missing, dtype=np.intp)
1007+
1008+
return indexer, missing
1009+
1010+
1011+
cdef class ExtensionEngine(SharedEngine):
1012+
def __init__(self, values: "ExtensionArray"):
1013+
self.values = values
1014+
1015+
self.over_size_threshold = len(values) >= _SIZE_CUTOFF
1016+
self.need_unique_check = True
1017+
self.need_monotonic_check = True
1018+
self.need_unique_check = True
1019+
1020+
cdef _do_monotonic_check(self):
1021+
cdef:
1022+
bint is_unique
1023+
1024+
values = self.values
1025+
if values._hasna:
1026+
self.monotonic_inc = 0
1027+
self.monotonic_dec = 0
1028+
1029+
nunique = len(values.unique())
1030+
self.unique = nunique == len(values)
1031+
self.need_unique_check = 0
1032+
return
1033+
1034+
try:
1035+
ranks = values._rank()
1036+
1037+
except TypeError:
1038+
self.monotonic_inc = 0
1039+
self.monotonic_dec = 0
1040+
is_unique = 0
1041+
else:
1042+
self.monotonic_inc, self.monotonic_dec, is_unique = \
1043+
self._call_monotonic(ranks)
1044+
1045+
self.need_monotonic_check = 0
1046+
1047+
# we can only be sure of uniqueness if is_unique=1
1048+
if is_unique:
1049+
self.unique = 1
1050+
self.need_unique_check = 0
1051+
1052+
cdef ndarray _get_bool_indexer(self, val):
1053+
if checknull(val):
1054+
return self.values.isna().view("uint8")
1055+
1056+
try:
1057+
return self.values == val
1058+
except TypeError:
1059+
# e.g. if __eq__ returns a BooleanArray instead of ndarry[bool]
1060+
try:
1061+
return (self.values == val).to_numpy(dtype=bool, na_value=False)
1062+
except (TypeError, AttributeError) as err:
1063+
# e.g. (self.values == val) returned a bool
1064+
# see test_get_loc_generator[string[pyarrow]]
1065+
# e.g. self.value == val raises TypeError bc generator has no len
1066+
# see test_get_loc_generator[string[python]]
1067+
raise KeyError from err
1068+
1069+
cdef _check_type(self, object val):
1070+
hash(val)

0 commit comments

Comments
 (0)