diff --git a/doc/source/whatsnew/v1.4.0.rst b/doc/source/whatsnew/v1.4.0.rst index d6ad5eb2003ce..583180d834010 100644 --- a/doc/source/whatsnew/v1.4.0.rst +++ b/doc/source/whatsnew/v1.4.0.rst @@ -498,6 +498,7 @@ Indexing - Bug in :meth:`DataFrame.nlargest` and :meth:`Series.nlargest` where sorted result did not count indexes containing ``np.nan`` (:issue:`28984`) - Bug in indexing on a non-unique object-dtype :class:`Index` with an NA scalar (e.g. ``np.nan``) (:issue:`43711`) - Bug in :meth:`Series.__setitem__` with object dtype when setting an array with matching size and dtype='datetime64[ns]' or dtype='timedelta64[ns]' incorrectly converting the datetime/timedeltas to integers (:issue:`43868`) +- Bug in :meth:`Index.get_indexer_non_unique` when index contains multiple ``np.datetime64("NaT")`` and ``np.timedelta64("NaT")`` (:issue:`43869`) - Missing diff --git a/pandas/_libs/index.pyx b/pandas/_libs/index.pyx index 43fe9f1d091c8..92837a43e2b69 100644 --- a/pandas/_libs/index.pyx +++ b/pandas/_libs/index.pyx @@ -315,14 +315,14 @@ cdef class IndexEngine: missing : np.ndarray[np.intp] """ cdef: - ndarray values, x + ndarray values ndarray[intp_t] result, missing - set stargets, remaining_stargets + set stargets, remaining_stargets, found_nas dict d = {} object val Py_ssize_t count = 0, count_missing = 0 Py_ssize_t i, j, n, n_t, n_alloc, start, end - bint d_has_nan = False, stargets_has_nan = False, need_nan_check = True + bint check_na_values = False values = self.values stargets = set(targets) @@ -357,33 +357,58 @@ cdef class IndexEngine: if stargets: # otherwise, map by iterating through all items in the index + # short-circuit na check + if values.dtype == object: + check_na_values = True + # keep track of nas in values + found_nas = set() + for i in range(n): val = values[i] + + # GH#43870 + # handle lookup for nas + # (i.e. 
np.nan, float("NaN"), Decimal("NaN"), dt64nat, td64nat) + if check_na_values and checknull(val): + match = [na for na in found_nas if is_matching_na(val, na)] + + # matching na not found + if not len(match): + found_nas.add(val) + + # add na to stargets to utilize `in` for stargets/d lookup + match_stargets = [ + x for x in stargets if is_matching_na(val, x) + ] + + if len(match_stargets): + # add our 'standardized' na + stargets.add(val) + + # matching na found + else: + assert len(match) == 1 + val = match[0] + if val in stargets: if val not in d: d[val] = [] d[val].append(i) - elif util.is_nan(val): - # GH#35392 - if need_nan_check: - # Do this check only once - stargets_has_nan = any(util.is_nan(val) for x in stargets) - need_nan_check = False - - if stargets_has_nan: - if not d_has_nan: - # use a canonical nan object - d[np.nan] = [] - d_has_nan = True - d[np.nan].append(i) - for i in range(n_t): val = targets[i] + # ensure there are nas in values before looking for a matching na + if check_na_values and checknull(val): + match = [na for na in found_nas if is_matching_na(val, na)] + if len(match): + assert len(match) == 1 + val = match[0] + # found - if val in d or (d_has_nan and util.is_nan(val)): - key = val if not util.is_nan(val) else np.nan + if val in d: + key = val + for j in d[key]: # realloc if needed diff --git a/pandas/_libs/missing.pyx b/pandas/_libs/missing.pyx index cbe79d11fbfc9..90f409d371e6b 100644 --- a/pandas/_libs/missing.pyx +++ b/pandas/_libs/missing.pyx @@ -20,9 +20,12 @@ from pandas._libs cimport util from pandas._libs.tslibs.nattype cimport ( c_NaT as NaT, checknull_with_nat, + is_dt64nat, is_null_datetimelike, + is_td64nat, ) from pandas._libs.tslibs.np_datetime cimport ( + get_datetime64_unit, get_datetime64_value, get_timedelta64_value, ) @@ -82,12 +85,14 @@ cpdef bint is_matching_na(object left, object right, bint nan_matches_none=False get_datetime64_value(left) == NPY_NAT and util.is_datetime64_object(right) and 
get_datetime64_value(right) == NPY_NAT + and get_datetime64_unit(left) == get_datetime64_unit(right) ) elif util.is_timedelta64_object(left): return ( get_timedelta64_value(left) == NPY_NAT and util.is_timedelta64_object(right) and get_timedelta64_value(right) == NPY_NAT + and get_datetime64_unit(left) == get_datetime64_unit(right) ) elif is_decimal_na(left): return is_decimal_na(right) @@ -345,20 +350,16 @@ def isneginf_scalar(val: object) -> bool: cdef inline bint is_null_datetime64(v): # determine if we have a null for a datetime (or integer versions), # excluding np.timedelta64('nat') - if checknull_with_nat(v): + if checknull_with_nat(v) or is_dt64nat(v): return True - elif util.is_datetime64_object(v): - return get_datetime64_value(v) == NPY_NAT return False cdef inline bint is_null_timedelta64(v): # determine if we have a null for a timedelta (or integer versions), # excluding np.datetime64('nat') - if checknull_with_nat(v): + if checknull_with_nat(v) or is_td64nat(v): return True - elif util.is_timedelta64_object(v): - return get_timedelta64_value(v) == NPY_NAT return False diff --git a/pandas/_libs/tslibs/nattype.pxd b/pandas/_libs/tslibs/nattype.pxd index d38f4518f9bf0..35319bd88053a 100644 --- a/pandas/_libs/tslibs/nattype.pxd +++ b/pandas/_libs/tslibs/nattype.pxd @@ -16,4 +16,6 @@ cdef _NaT c_NaT cdef bint checknull_with_nat(object val) +cdef bint is_dt64nat(object val) +cdef bint is_td64nat(object val) cpdef bint is_null_datetimelike(object val, bint inat_is_null=*) diff --git a/pandas/_libs/tslibs/nattype.pyx b/pandas/_libs/tslibs/nattype.pyx index 521927cd910ec..23094bdb90483 100644 --- a/pandas/_libs/tslibs/nattype.pyx +++ b/pandas/_libs/tslibs/nattype.pyx @@ -1133,6 +1133,22 @@ cdef inline bint checknull_with_nat(object val): """ return val is None or util.is_nan(val) or val is c_NaT +cdef inline bint is_dt64nat(object val): + """ + Return True if val is a NaT-valued np.datetime64, i.e. np.datetime64("NaT"). 
+ """ + if util.is_datetime64_object(val): + return get_datetime64_value(val) == NPY_NAT + return False + +cdef inline bint is_td64nat(object val): + """ + Return True if val is a NaT-valued np.timedelta64, i.e. np.timedelta64("NaT"). + """ + if util.is_timedelta64_object(val): + return get_timedelta64_value(val) == NPY_NAT + return False + cpdef bint is_null_datetimelike(object val, bint inat_is_null=True): """ diff --git a/pandas/_testing/__init__.py b/pandas/_testing/__init__.py index c54185e324646..e8283a222d86a 100644 --- a/pandas/_testing/__init__.py +++ b/pandas/_testing/__init__.py @@ -157,6 +157,25 @@ ) NULL_OBJECTS = [None, np.nan, pd.NaT, float("nan"), pd.NA, Decimal("NaN")] +NP_NAT_OBJECTS = [ + cls("NaT", unit) + for cls in [np.datetime64, np.timedelta64] + for unit in [ + "Y", + "M", + "W", + "D", + "h", + "m", + "s", + "ms", + "us", + "ns", + "ps", + "fs", + "as", + ] +] EMPTY_STRING_PATTERN = re.compile("^$") diff --git a/pandas/conftest.py b/pandas/conftest.py index 44b805c632723..75711b19dfcfd 100644 --- a/pandas/conftest.py +++ b/pandas/conftest.py @@ -332,6 +332,19 @@ def unique_nulls_fixture(request): # Generate cartesian product of unique_nulls_fixture: unique_nulls_fixture2 = unique_nulls_fixture + +@pytest.fixture(params=tm.NP_NAT_OBJECTS, ids=lambda x: type(x).__name__) +def np_nat_fixture(request): + """ + Fixture for each NaT type in numpy. 
+ """ + return request.param + + +# Generate cartesian product of np_nat_fixture: +np_nat_fixture2 = np_nat_fixture + + # ---------------------------------------------------------------- # Classes # ---------------------------------------------------------------- diff --git a/pandas/tests/indexes/object/test_indexing.py b/pandas/tests/indexes/object/test_indexing.py index 039483cc948df..38bd96921b991 100644 --- a/pandas/tests/indexes/object/test_indexing.py +++ b/pandas/tests/indexes/object/test_indexing.py @@ -1,3 +1,5 @@ +from decimal import Decimal + import numpy as np import pytest @@ -90,12 +92,63 @@ def test_get_indexer_non_unique_nas(self, nulls_fixture): # matching-but-not-identical nans if is_matching_na(nulls_fixture, float("NaN")): index = Index(["a", float("NaN"), "b", float("NaN")]) + match_but_not_identical = True + elif is_matching_na(nulls_fixture, Decimal("NaN")): + index = Index(["a", Decimal("NaN"), "b", Decimal("NaN")]) + match_but_not_identical = True + else: + match_but_not_identical = False + + if match_but_not_identical: indexer, missing = index.get_indexer_non_unique([nulls_fixture]) expected_indexer = np.array([1, 3], dtype=np.intp) tm.assert_numpy_array_equal(indexer, expected_indexer) tm.assert_numpy_array_equal(missing, expected_missing) + @pytest.mark.filterwarnings("ignore:elementwise comp:DeprecationWarning") + def test_get_indexer_non_unique_np_nats(self, np_nat_fixture, np_nat_fixture2): + expected_missing = np.array([], dtype=np.intp) + # matching-but-not-identical nats + if is_matching_na(np_nat_fixture, np_nat_fixture2): + # ensure nats are different objects + index = Index( + np.array( + ["2021-10-02", np_nat_fixture.copy(), np_nat_fixture2.copy()], + dtype=object, + ), + dtype=object, + ) + # pass as index to prevent target from being cast to DatetimeIndex + indexer, missing = index.get_indexer_non_unique( + Index([np_nat_fixture], dtype=object) + ) + expected_indexer = np.array([1, 2], dtype=np.intp) + 
tm.assert_numpy_array_equal(indexer, expected_indexer) + tm.assert_numpy_array_equal(missing, expected_missing) + # non-matching nats (dt64nat vs td64nat, or same type with different units) + else: + index = Index( + np.array( + [ + "2021-10-02", + np_nat_fixture, + np_nat_fixture2, + np_nat_fixture, + np_nat_fixture2, + ], + dtype=object, + ), + dtype=object, + ) + # pass as index to prevent target from being cast to DatetimeIndex + indexer, missing = index.get_indexer_non_unique( + Index([np_nat_fixture], dtype=object) + ) + expected_indexer = np.array([1, 3], dtype=np.intp) + tm.assert_numpy_array_equal(indexer, expected_indexer) + tm.assert_numpy_array_equal(missing, expected_missing) + class TestSliceLocs: @pytest.mark.parametrize( diff --git a/pandas/tests/indexes/test_indexing.py b/pandas/tests/indexes/test_indexing.py index ff2cd76ab6377..0a001008c2f1b 100644 --- a/pandas/tests/indexes/test_indexing.py +++ b/pandas/tests/indexes/test_indexing.py @@ -320,6 +320,11 @@ def test_maybe_cast_slice_bound_kind_deprecated(index): np.array([1, 2], dtype=np.intp), ), (["a", "b", "a", np.nan], [np.nan], np.array([3], dtype=np.intp)), + ( + np.array(["b", np.nan, float("NaN"), "b"], dtype=object), + Index([np.nan], dtype=object), + np.array([1, 2], dtype=np.intp), + ), ], ) def test_get_indexer_non_unique_multiple_nans(idx, target, expected):