Skip to content

Commit df6b665

Browse files
committed
CLN: overriding IndexEngine subclasses (#35498)
1 parent 7a00be7 commit df6b665

File tree

2 files changed

+49
-29
lines changed

2 files changed

+49
-29
lines changed

pandas/_libs/index.pyx

+39-29
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ from pandas._libs import (
3636
)
3737
from pandas._libs.missing import (
3838
checknull,
39+
is_matching_na,
3940
isnaobj,
4041
)
4142

@@ -270,6 +271,25 @@ cdef class IndexEngine:
270271
self._ensure_mapping_populated()
271272
return self.mapping.lookup(values)
272273

274+
def get_stargets(self, ndarray targets) -> set:
275+
return set(targets)
276+
277+
def convert_val_if_nan(self, object val) -> object:
278+
# unable to utilize val if nan when updating
279+
# hashable data structures (ie. sets, dict)
280+
if checknull(val):
281+
return -1
282+
else:
283+
return val
284+
285+
def should_update_d(self, object target, object val) -> bool:
286+
# -1 in targets could be either -1 or nan
287+
# ensures values in d[-1] to be included only once
288+
if target == val or is_matching_na(target, val):
289+
return True
290+
291+
return False
292+
273293
def get_indexer_non_unique(self, ndarray targets):
274294
"""
275295
Return an indexer suitable for taking from a non unique index
@@ -293,11 +313,7 @@ cdef class IndexEngine:
293313

294314
self._ensure_mapping_populated()
295315
values = np.array(self._get_index_values(), copy=False)
296-
values_mask = isnaobj(values)
297-
targets_mask = isnaobj(targets)
298-
stargets = set(targets)
299-
if -1 not in stargets and targets_mask.any():
300-
stargets.add(-1)
316+
stargets = self.get_stargets(targets)
301317
n = len(values)
302318
n_t = len(targets)
303319
if n > 10_000:
@@ -328,22 +344,15 @@ cdef class IndexEngine:
328344
if stargets:
329345
# otherwise, map by iterating through all items in the index
330346
for i in range(n):
331-
if values_mask[i]:
332-
val = -1
333-
else:
334-
val = values[i]
347+
val = self.convert_val_if_nan(values[i])
335348

336349
if val in stargets:
337350
if val not in d:
338351
d[val] = []
339352
d[val].append(i)
340353

341354
for i in range(n_t):
342-
nan_target = targets_mask[i]
343-
if nan_target:
344-
val = -1
345-
else:
346-
val = targets[i]
355+
val = self.convert_val_if_nan(targets[i])
347356

348357
# found
349358
if val in d:
@@ -354,21 +363,7 @@ cdef class IndexEngine:
354363
n_alloc += 10_000
355364
result = np.resize(result, n_alloc)
356365

357-
# -1 in targets could be either -1 or nan
358-
# ensures values in d[-1] to be included only once
359-
if val == -1:
360-
nan_val = values_mask[j]
361-
# nan
362-
if nan_target:
363-
if nan_val:
364-
result[count] = j
365-
count += 1
366-
# -1
367-
else:
368-
if not nan_val:
369-
result[count] = j
370-
count += 1
371-
else:
366+
if self.should_update_d(targets[i], values[j]):
372367
result[count] = j
373368
count += 1
374369

@@ -419,6 +414,13 @@ cdef class ObjectEngine(IndexEngine):
419414
cdef _make_hash_table(self, Py_ssize_t n):
420415
return _hash.PyObjectHashTable(n)
421416

417+
def get_stargets(self, ndarray targets) -> set:
418+
stargets = set(targets)
419+
# account for NA-like targets
420+
if -1 not in stargets and isnaobj(targets).any():
421+
stargets.add(-1)
422+
423+
return stargets
422424

423425
cdef class DatetimeEngine(Int64Engine):
424426

@@ -490,6 +492,14 @@ cdef class DatetimeEngine(Int64Engine):
490492
except KeyError:
491493
raise KeyError(val)
492494

495+
def get_stargets(self, ndarray targets) -> set:
496+
stargets = set(targets)
497+
# account for NaTs
498+
if -1 not in stargets and isnaobj(targets).any():
499+
stargets.add(-1)
500+
501+
return stargets
502+
493503
def get_indexer_non_unique(self, ndarray targets):
494504
# we may get datetime64[ns] or timedelta64[ns], cast these to int64
495505
return super().get_indexer_non_unique(targets.view("i8"))

pandas/_libs/index_class_helper.pxi.in

+10
Original file line numberDiff line numberDiff line change
@@ -77,4 +77,14 @@ cdef class {{name}}Engine(IndexEngine):
7777

7878
return self._unpack_bool_indexer(indexer, val)
7979

80+
{{if name in {'Float64', 'Float32'} }}
81+
def get_stargets(self, ndarray targets):
82+
stargets = set(targets)
83+
# account for nans
84+
if -1 not in stargets and np.isnan(targets).any():
85+
stargets.add(-1)
86+
87+
return stargets
88+
{{endif}}
89+
8090
{{endfor}}

0 commit comments

Comments
 (0)