Skip to content

Commit 2d2db8d

Browse files
authored
BUG: get_indexer_non_unique with np.datetime64("NaT") and np.timedelta64("NaT") (#43870)
1 parent a4b9da7 commit 2d2db8d

File tree

9 files changed

+160
-25
lines changed

9 files changed

+160
-25
lines changed

doc/source/whatsnew/v1.4.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -498,6 +498,7 @@ Indexing
498498
- Bug in :meth:`DataFrame.nlargest` and :meth:`Series.nlargest` where sorted result did not count indexes containing ``np.nan`` (:issue:`28984`)
499499
- Bug in indexing on a non-unique object-dtype :class:`Index` with an NA scalar (e.g. ``np.nan``) (:issue:`43711`)
500500
- Bug in :meth:`Series.__setitem__` with object dtype when setting an array with matching size and dtype='datetime64[ns]' or dtype='timedelta64[ns]' incorrectly converting the datetime/timedeltas to integers (:issue:`43868`)
501+
- Bug in :meth:`Index.get_indexer_non_unique` when index contains multiple ``np.datetime64("NaT")`` and ``np.timedelta64("NaT")`` (:issue:`43869`)
501502
-
502503

503504
Missing

pandas/_libs/index.pyx

+44-19
Original file line numberDiff line numberDiff line change
@@ -315,14 +315,14 @@ cdef class IndexEngine:
315315
missing : np.ndarray[np.intp]
316316
"""
317317
cdef:
318-
ndarray values, x
318+
ndarray values
319319
ndarray[intp_t] result, missing
320-
set stargets, remaining_stargets
320+
set stargets, remaining_stargets, found_nas
321321
dict d = {}
322322
object val
323323
Py_ssize_t count = 0, count_missing = 0
324324
Py_ssize_t i, j, n, n_t, n_alloc, start, end
325-
bint d_has_nan = False, stargets_has_nan = False, need_nan_check = True
325+
bint check_na_values = False
326326

327327
values = self.values
328328
stargets = set(targets)
@@ -357,33 +357,58 @@ cdef class IndexEngine:
357357
if stargets:
358358
# otherwise, map by iterating through all items in the index
359359

360+
# short-circuit na check
361+
if values.dtype == object:
362+
check_na_values = True
363+
# keep track of nas in values
364+
found_nas = set()
365+
360366
for i in range(n):
361367
val = values[i]
368+
369+
# GH#43870
370+
# handle lookup for nas
371+
# (i.e. np.nan, float("NaN"), Decimal("NaN"), dt64nat, td64nat)
372+
if check_na_values and checknull(val):
373+
match = [na for na in found_nas if is_matching_na(val, na)]
374+
375+
# matching na not found
376+
if not len(match):
377+
found_nas.add(val)
378+
379+
# add na to stargets to utilize `in` for stargets/d lookup
380+
match_stargets = [
381+
x for x in stargets if is_matching_na(val, x)
382+
]
383+
384+
if len(match_stargets):
385+
# add our 'standardized' na
386+
stargets.add(val)
387+
388+
# matching na found
389+
else:
390+
assert len(match) == 1
391+
val = match[0]
392+
362393
if val in stargets:
363394
if val not in d:
364395
d[val] = []
365396
d[val].append(i)
366397

367-
elif util.is_nan(val):
368-
# GH#35392
369-
if need_nan_check:
370-
# Do this check only once
371-
stargets_has_nan = any(util.is_nan(val) for x in stargets)
372-
need_nan_check = False
373-
374-
if stargets_has_nan:
375-
if not d_has_nan:
376-
# use a canonical nan object
377-
d[np.nan] = []
378-
d_has_nan = True
379-
d[np.nan].append(i)
380-
381398
for i in range(n_t):
382399
val = targets[i]
383400

401+
# ensure there are nas in values before looking for a matching na
402+
if check_na_values and checknull(val):
403+
match = [na for na in found_nas if is_matching_na(val, na)]
404+
if len(match):
405+
assert len(match) == 1
406+
val = match[0]
407+
384408
# found
385-
if val in d or (d_has_nan and util.is_nan(val)):
386-
key = val if not util.is_nan(val) else np.nan
409+
if val in d:
410+
key = val
411+
387412
for j in d[key]:
388413

389414
# realloc if needed

pandas/_libs/missing.pyx

+7-6
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@ from pandas._libs cimport util
2020
from pandas._libs.tslibs.nattype cimport (
2121
c_NaT as NaT,
2222
checknull_with_nat,
23+
is_dt64nat,
2324
is_null_datetimelike,
25+
is_td64nat,
2426
)
2527
from pandas._libs.tslibs.np_datetime cimport (
28+
get_datetime64_unit,
2629
get_datetime64_value,
2730
get_timedelta64_value,
2831
)
@@ -82,12 +85,14 @@ cpdef bint is_matching_na(object left, object right, bint nan_matches_none=False
8285
get_datetime64_value(left) == NPY_NAT
8386
and util.is_datetime64_object(right)
8487
and get_datetime64_value(right) == NPY_NAT
88+
and get_datetime64_unit(left) == get_datetime64_unit(right)
8589
)
8690
elif util.is_timedelta64_object(left):
8791
return (
8892
get_timedelta64_value(left) == NPY_NAT
8993
and util.is_timedelta64_object(right)
9094
and get_timedelta64_value(right) == NPY_NAT
95+
and get_datetime64_unit(left) == get_datetime64_unit(right)
9196
)
9297
elif is_decimal_na(left):
9398
return is_decimal_na(right)
@@ -345,20 +350,16 @@ def isneginf_scalar(val: object) -> bool:
345350
cdef inline bint is_null_datetime64(v):
346351
# determine if we have a null for a datetime (or integer versions),
347352
# excluding np.timedelta64('nat')
348-
if checknull_with_nat(v):
353+
if checknull_with_nat(v) or is_dt64nat(v):
349354
return True
350-
elif util.is_datetime64_object(v):
351-
return get_datetime64_value(v) == NPY_NAT
352355
return False
353356

354357

355358
cdef inline bint is_null_timedelta64(v):
356359
# determine if we have a null for a timedelta (or integer versions),
357360
# excluding np.datetime64('nat')
358-
if checknull_with_nat(v):
361+
if checknull_with_nat(v) or is_td64nat(v):
359362
return True
360-
elif util.is_timedelta64_object(v):
361-
return get_timedelta64_value(v) == NPY_NAT
362363
return False
363364

364365

pandas/_libs/tslibs/nattype.pxd

+2
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@ cdef _NaT c_NaT
1616

1717

1818
cdef bint checknull_with_nat(object val)
19+
cdef bint is_dt64nat(object val)
20+
cdef bint is_td64nat(object val)
1921
cpdef bint is_null_datetimelike(object val, bint inat_is_null=*)

pandas/_libs/tslibs/nattype.pyx

+16
Original file line numberDiff line numberDiff line change
@@ -1133,6 +1133,22 @@ cdef inline bint checknull_with_nat(object val):
11331133
"""
11341134
return val is None or util.is_nan(val) or val is c_NaT
11351135

1136+
cdef inline bint is_dt64nat(object val):
1137+
"""
1138+
Return True if val is a np.datetime64 object whose value is NaT, i.e. np.datetime64("NaT").
1139+
"""
1140+
if util.is_datetime64_object(val):
1141+
return get_datetime64_value(val) == NPY_NAT
1142+
return False
1143+
1144+
cdef inline bint is_td64nat(object val):
1145+
"""
1146+
Return True if val is a np.timedelta64 object whose value is NaT, i.e. np.timedelta64("NaT").
1147+
"""
1148+
if util.is_timedelta64_object(val):
1149+
return get_timedelta64_value(val) == NPY_NAT
1150+
return False
1151+
11361152

11371153
cpdef bint is_null_datetimelike(object val, bint inat_is_null=True):
11381154
"""

pandas/_testing/__init__.py

+19
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,25 @@
157157
)
158158

159159
NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")]
160+
NP_NAT_OBJECTS = [
161+
cls("NaT", unit)
162+
for cls in [np.datetime64, np.timedelta64]
163+
for unit in [
164+
"Y",
165+
"M",
166+
"W",
167+
"D",
168+
"h",
169+
"m",
170+
"s",
171+
"ms",
172+
"us",
173+
"ns",
174+
"ps",
175+
"fs",
176+
"as",
177+
]
178+
]
160179

161180
EMPTY_STRING_PATTERN = re.compile("^$")
162181

pandas/conftest.py

+13
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,19 @@ def unique_nulls_fixture(request):
332332
# Generate cartesian product of unique_nulls_fixture:
333333
unique_nulls_fixture2 = unique_nulls_fixture
334334

335+
336+
@pytest.fixture(params=tm.NP_NAT_OBJECTS, ids=lambda x: type(x).__name__)
337+
def np_nat_fixture(request):
338+
"""
339+
Fixture for each NaT type in numpy.
340+
"""
341+
return request.param
342+
343+
344+
# Generate cartesian product of np_nat_fixture:
345+
np_nat_fixture2 = np_nat_fixture
346+
347+
335348
# ----------------------------------------------------------------
336349
# Classes
337350
# ----------------------------------------------------------------

pandas/tests/indexes/object/test_indexing.py

+53
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from decimal import Decimal
2+
13
import numpy as np
24
import pytest
35

@@ -90,12 +92,63 @@ def test_get_indexer_non_unique_nas(self, nulls_fixture):
9092
# matching-but-not-identical nans
9193
if is_matching_na(nulls_fixture, float("NaN")):
9294
index = Index(["a", float("NaN"), "b", float("NaN")])
95+
match_but_not_identical = True
96+
elif is_matching_na(nulls_fixture, Decimal("NaN")):
97+
index = Index(["a", Decimal("NaN"), "b", Decimal("NaN")])
98+
match_but_not_identical = True
99+
else:
100+
match_but_not_identical = False
101+
102+
if match_but_not_identical:
93103
indexer, missing = index.get_indexer_non_unique([nulls_fixture])
94104

95105
expected_indexer = np.array([1, 3], dtype=np.intp)
96106
tm.assert_numpy_array_equal(indexer, expected_indexer)
97107
tm.assert_numpy_array_equal(missing, expected_missing)
98108

109+
@pytest.mark.filterwarnings("ignore:elementwise comp:DeprecationWarning")
110+
def test_get_indexer_non_unique_np_nats(self, np_nat_fixture, np_nat_fixture2):
111+
expected_missing = np.array([], dtype=np.intp)
112+
# matching-but-not-identical nats
113+
if is_matching_na(np_nat_fixture, np_nat_fixture2):
114+
# ensure nats are different objects
115+
index = Index(
116+
np.array(
117+
["2021-10-02", np_nat_fixture.copy(), np_nat_fixture2.copy()],
118+
dtype=object,
119+
),
120+
dtype=object,
121+
)
122+
# pass as index to prevent target from being cast to DatetimeIndex
123+
indexer, missing = index.get_indexer_non_unique(
124+
Index([np_nat_fixture], dtype=object)
125+
)
126+
expected_indexer = np.array([1, 2], dtype=np.intp)
127+
tm.assert_numpy_array_equal(indexer, expected_indexer)
128+
tm.assert_numpy_array_equal(missing, expected_missing)
129+
# dt64nat vs td64nat
130+
else:
131+
index = Index(
132+
np.array(
133+
[
134+
"2021-10-02",
135+
np_nat_fixture,
136+
np_nat_fixture2,
137+
np_nat_fixture,
138+
np_nat_fixture2,
139+
],
140+
dtype=object,
141+
),
142+
dtype=object,
143+
)
144+
# pass as index to prevent target from being cast to DatetimeIndex
145+
indexer, missing = index.get_indexer_non_unique(
146+
Index([np_nat_fixture], dtype=object)
147+
)
148+
expected_indexer = np.array([1, 3], dtype=np.intp)
149+
tm.assert_numpy_array_equal(indexer, expected_indexer)
150+
tm.assert_numpy_array_equal(missing, expected_missing)
151+
99152

100153
class TestSliceLocs:
101154
@pytest.mark.parametrize(

pandas/tests/indexes/test_indexing.py

+5
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,11 @@ def test_maybe_cast_slice_bound_kind_deprecated(index):
320320
np.array([1, 2], dtype=np.intp),
321321
),
322322
(["a", "b", "a", np.nan], [np.nan], np.array([3], dtype=np.intp)),
323+
(
324+
np.array(["b", np.nan, float("NaN"), "b"], dtype=object),
325+
Index([np.nan], dtype=object),
326+
np.array([1, 2], dtype=np.intp),
327+
),
323328
],
324329
)
325330
def test_get_indexer_non_unique_multiple_nans(idx, target, expected):

0 commit comments

Comments
 (0)