Skip to content

REF: share _get_bool_indexer #43803

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 29, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 38 additions & 35 deletions pandas/_libs/index.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,40 @@ cdef inline bint is_definitely_invalid_key(object val):
return False


cdef ndarray _get_bool_indexer(ndarray values, object val):
    """
    Build a boolean mask flagging each position of ``values`` that matches
    ``val``.

    For a ``val`` that is not NA the result is the same as ``values == val``.
    For object-dtype ``values`` and an NA ``val``, each element is compared
    with ``is_matching_na``; for other dtypes a NaN ``val`` is located with
    ``np.isnan``.
    """
    # The caller must already have run _check_type on ``val``.
    cdef:
        ndarray[uint8_t, ndim=1, cast=True] mask
        Py_ssize_t pos, n

    if values.descr.type_num != cnp.NPY_OBJECT:
        # Non-object dtype: only a NaN key needs something other than ``==``.
        mask = np.isnan(values) if util.is_nan(val) else values == val

    elif checknull(val):
        # Object dtype with an NA key: look for the _matching_ NA value
        # element by element instead of relying on ``==``.
        n = len(values)
        mask = np.empty(n, dtype=np.uint8)

        for pos in range(n):
            mask[pos] = is_matching_na(values[pos], val)

    else:
        # Object dtype with a regular key: plain elementwise comparison.
        mask = values == val

    return mask.view(bool)


# Don't populate hash tables in monotonic indexes larger than this
_SIZE_CUTOFF = 1_000_000

Expand Down Expand Up @@ -83,12 +117,13 @@ cdef class IndexEngine:
if is_definitely_invalid_key(val):
raise TypeError(f"'{val}' is an invalid key")

self._check_type(val)

if self.over_size_threshold and self.is_monotonic_increasing:
if not self.is_unique:
return self._get_loc_duplicates(val)
values = self.values

self._check_type(val)
loc = self._searchsorted_left(val)
if loc >= len(values):
raise KeyError(val)
Expand All @@ -100,8 +135,6 @@ cdef class IndexEngine:
if not self.unique:
return self._get_loc_duplicates(val)

self._check_type(val)

try:
return self.mapping.get_item(val)
except (TypeError, ValueError, OverflowError):
Expand Down Expand Up @@ -148,17 +181,9 @@ cdef class IndexEngine:
cdef:
ndarray[uint8_t, ndim=1, cast=True] indexer

indexer = self._get_bool_indexer(val)
indexer = _get_bool_indexer(self.values, val)
return self._unpack_bool_indexer(indexer, val)

cdef ndarray _get_bool_indexer(self, object val):
"""
Return a ndarray[bool] of locations where val matches self.values.

If val is not NA, this is equivalent to `self.values == val`
"""
raise NotImplementedError("Implemented by subclasses")

cdef _unpack_bool_indexer(self,
ndarray[uint8_t, ndim=1, cast=True] indexer,
object val):
Expand Down Expand Up @@ -253,16 +278,13 @@ cdef class IndexEngine:

values = self.values
self.mapping = self._make_hash_table(len(values))
self._call_map_locations(values)
self.mapping.map_locations(values)

if len(self.mapping) == len(values):
self.unique = 1

self.need_unique_check = 0

cdef void _call_map_locations(self, ndarray values):
self.mapping.map_locations(values)

def clear_mapping(self):
self.mapping = None
self.need_monotonic_check = 1
Expand Down Expand Up @@ -430,25 +452,6 @@ cdef class ObjectEngine(IndexEngine):
raise KeyError(val) from err
return loc

cdef ndarray _get_bool_indexer(self, object val):
# We need to check for equality and for matching NAs
cdef:
ndarray values = self.values

if not checknull(val):
return values == val

cdef:
ndarray[uint8_t] result = np.empty(len(values), dtype=np.uint8)
Py_ssize_t i
object item

for i in range(len(values)):
item = values[i]
result[i] = is_matching_na(item, val)

return result.view(bool)


cdef class DatetimeEngine(Int64Engine):

Expand Down
23 changes: 0 additions & 23 deletions pandas/_libs/index_class_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,5 @@ cdef class {{name}}Engine(IndexEngine):
raise KeyError(val)
{{endif}}

cdef void _call_map_locations(self, ndarray[{{dtype}}_t] values):
self.mapping.map_locations(values)

cdef ndarray _get_bool_indexer(self, object val):
cdef:
ndarray[uint8_t, ndim=1, cast=True] indexer
ndarray[{{dtype}}_t, ndim=1] values

self._check_type(val)

values = self.values

{{if name in {'Float64', 'Float32'} }}
if util.is_nan(val):
indexer = np.isnan(values)
else:
indexer = values == val
{{else}}
indexer = values == val
{{endif}}

return indexer


{{endfor}}