Skip to content

Commit 950408e

Browse files
authored
ENH: dt64/td64 comparison support non-nano (pandas-dev#47691)
* ENH: dt64/td64 comparison support non-nano * mypy fixup
1 parent c711be0 commit 950408e

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

pandas/_libs/tslibs/np_datetime.pyi

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import numpy as np
22

3+
from pandas._typing import npt
4+
35
class OutOfBoundsDatetime(ValueError): ...
46
class OutOfBoundsTimedelta(ValueError): ...
57

@@ -10,3 +12,6 @@ def astype_overflowsafe(
1012
arr: np.ndarray, dtype: np.dtype, copy: bool = ...
1113
) -> np.ndarray: ...
1214
def is_unitless(dtype: np.dtype) -> bool: ...
15+
def compare_mismatched_resolutions(
16+
left: np.ndarray, right: np.ndarray, op
17+
) -> npt.NDArray[np.bool_]: ...

pandas/_libs/tslibs/np_datetime.pyx

+80
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,14 @@ from cpython.object cimport (
2020
import_datetime()
2121

2222
import numpy as np
23+
2324
cimport numpy as cnp
2425

2526
cnp.import_array()
2627
from numpy cimport (
2728
int64_t,
2829
ndarray,
30+
uint8_t,
2931
)
3032

3133
from pandas._libs.tslibs.util cimport get_c_string_buf_and_size
@@ -370,3 +372,81 @@ cpdef ndarray astype_overflowsafe(
370372
cnp.PyArray_MultiIter_NEXT(mi)
371373

372374
return iresult.view(dtype)
375+
376+
377+
# TODO: try to upstream this fix to numpy
378+
def compare_mismatched_resolutions(ndarray left, ndarray right, op):
379+
"""
380+
Overflow-safe comparison of timedelta64/datetime64 with mismatched resolutions.
381+
382+
>>> left = np.array([500], dtype="M8[Y]")
383+
>>> right = np.array([0], dtype="M8[ns]")
384+
>>> left < right # <- wrong!
385+
array([ True])
386+
"""
387+
388+
if left.dtype.kind != right.dtype.kind or left.dtype.kind not in ["m", "M"]:
389+
raise ValueError("left and right must both be timedelta64 or both datetime64")
390+
391+
cdef:
392+
int op_code = op_to_op_code(op)
393+
NPY_DATETIMEUNIT left_unit = get_unit_from_dtype(left.dtype)
394+
NPY_DATETIMEUNIT right_unit = get_unit_from_dtype(right.dtype)
395+
396+
# equiv: result = np.empty((<object>left).shape, dtype="bool")
397+
ndarray result = cnp.PyArray_EMPTY(
398+
left.ndim, left.shape, cnp.NPY_BOOL, 0
399+
)
400+
401+
ndarray lvalues = left.view("i8")
402+
ndarray rvalues = right.view("i8")
403+
404+
cnp.broadcast mi = cnp.PyArray_MultiIterNew3(result, lvalues, rvalues)
405+
int64_t lval, rval
406+
bint res_value
407+
408+
Py_ssize_t i, N = left.size
409+
npy_datetimestruct ldts, rdts
410+
411+
412+
for i in range(N):
413+
# Analogous to: lval = lvalues[i]
414+
lval = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 1))[0]
415+
416+
# Analogous to: rval = rvalues[i]
417+
rval = (<int64_t*>cnp.PyArray_MultiIter_DATA(mi, 2))[0]
418+
419+
if lval == NPY_DATETIME_NAT or rval == NPY_DATETIME_NAT:
420+
res_value = op_code == Py_NE
421+
422+
else:
423+
pandas_datetime_to_datetimestruct(lval, left_unit, &ldts)
424+
pandas_datetime_to_datetimestruct(rval, right_unit, &rdts)
425+
426+
res_value = cmp_dtstructs(&ldts, &rdts, op_code)
427+
428+
# Analogous to: result[i] = res_value
429+
(<uint8_t*>cnp.PyArray_MultiIter_DATA(mi, 0))[0] = res_value
430+
431+
cnp.PyArray_MultiIter_NEXT(mi)
432+
433+
return result
434+
435+
436+
import operator
437+
438+
439+
cdef int op_to_op_code(op):
440+
# TODO: should exist somewhere?
441+
if op is operator.eq:
442+
return Py_EQ
443+
if op is operator.ne:
444+
return Py_NE
445+
if op is operator.le:
446+
return Py_LE
447+
if op is operator.lt:
448+
return Py_LT
449+
if op is operator.ge:
450+
return Py_GE
451+
if op is operator.gt:
452+
return Py_GT

pandas/core/arrays/datetimelike.py

+19
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
RoundTo,
4747
round_nsint64,
4848
)
49+
from pandas._libs.tslibs.np_datetime import compare_mismatched_resolutions
4950
from pandas._libs.tslibs.timestamps import integer_op_not_supported
5051
from pandas._typing import (
5152
ArrayLike,
@@ -1065,6 +1066,24 @@ def _cmp_method(self, other, op):
10651066
)
10661067
return result
10671068

1069+
if other is NaT:
1070+
if op is operator.ne:
1071+
result = np.ones(self.shape, dtype=bool)
1072+
else:
1073+
result = np.zeros(self.shape, dtype=bool)
1074+
return result
1075+
1076+
if not is_period_dtype(self.dtype):
1077+
self = cast(TimelikeOps, self)
1078+
if self._reso != other._reso:
1079+
if not isinstance(other, type(self)):
1080+
# i.e. Timedelta/Timestamp, cast to ndarray and let
1081+
# compare_mismatched_resolutions handle broadcasting
1082+
other_arr = np.array(other.asm8)
1083+
else:
1084+
other_arr = other._ndarray
1085+
return compare_mismatched_resolutions(self._ndarray, other_arr, op)
1086+
10681087
other_vals = self._unbox(other)
10691088
# GH#37462 comparison on i8 values is almost 2x faster than M8/m8
10701089
result = op(self._ndarray.view("i8"), other_vals.view("i8"))

pandas/tests/arrays/test_datetimes.py

+38
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Tests for DatetimeArray
33
"""
4+
import operator
5+
46
import numpy as np
57
import pytest
68

@@ -169,6 +171,42 @@ def test_repr(self, dta_dti, unit):
169171

170172
assert repr(dta) == repr(dti._data).replace("[ns", f"[{unit}")
171173

174+
# TODO: tests with td64
175+
def test_compare_mismatched_resolutions(self, comparison_op):
176+
# comparison that numpy gets wrong bc of silent overflows
177+
op = comparison_op
178+
179+
iinfo = np.iinfo(np.int64)
180+
vals = np.array([iinfo.min, iinfo.min + 1, iinfo.max], dtype=np.int64)
181+
182+
# Construct so that arr2[1] < arr[1] < arr[2] < arr2[2]
183+
arr = np.array(vals).view("M8[ns]")
184+
arr2 = arr.view("M8[s]")
185+
186+
left = DatetimeArray._simple_new(arr, dtype=arr.dtype)
187+
right = DatetimeArray._simple_new(arr2, dtype=arr2.dtype)
188+
189+
if comparison_op is operator.eq:
190+
expected = np.array([False, False, False])
191+
elif comparison_op is operator.ne:
192+
expected = np.array([True, True, True])
193+
elif comparison_op in [operator.lt, operator.le]:
194+
expected = np.array([False, False, True])
195+
else:
196+
expected = np.array([False, True, False])
197+
198+
result = op(left, right)
199+
tm.assert_numpy_array_equal(result, expected)
200+
201+
result = op(left[1], right)
202+
tm.assert_numpy_array_equal(result, expected)
203+
204+
if op not in [operator.eq, operator.ne]:
205+
# check that numpy still gets this wrong; if it is fixed we may be
206+
# able to remove compare_mismatched_resolutions
207+
np_res = op(left._ndarray, right._ndarray)
208+
tm.assert_numpy_array_equal(np_res[1:], ~expected[1:])
209+
172210

173211
class TestDatetimeArrayComparisons:
174212
# TODO: merge this into tests/arithmetic/test_datetime64 once it is

0 commit comments

Comments
 (0)