Skip to content

Commit 7f2aa8f

Browse files
authored
ENH: pd.NA comparison with time, date, timedelta (#50901)
* ENH: pd.NA comparison with time, date, timedelta * mypy fixup * fix on nullable dtypes
1 parent d50c3cc commit 7f2aa8f

File tree

8 files changed

+46
-49
lines changed

8 files changed

+46
-49
lines changed

pandas/_libs/missing.pyx

+9
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,11 @@ import numbers
33
from sys import maxsize
44

55
cimport cython
6+
from cpython.datetime cimport (
7+
date,
8+
time,
9+
timedelta,
10+
)
611
from cython cimport Py_ssize_t
712

813
import numpy as np
@@ -307,6 +312,7 @@ def is_numeric_na(values: ndarray) -> ndarray:
307312

308313

309314
def _create_binary_propagating_op(name, is_divmod=False):
315+
is_cmp = name.strip("_") in ["eq", "ne", "le", "lt", "ge", "gt"]
310316

311317
def method(self, other):
312318
if (other is C_NA or isinstance(other, (str, bytes))
@@ -329,6 +335,9 @@ def _create_binary_propagating_op(name, is_divmod=False):
329335
else:
330336
return out
331337

338+
elif is_cmp and isinstance(other, (date, time, timedelta)):
339+
return NA
340+
332341
return NotImplemented
333342

334343
method.__name__ = name

pandas/tests/extension/base/methods.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
import pytest
66

7+
from pandas._typing import Dtype
8+
79
from pandas.core.dtypes.common import is_bool_dtype
810
from pandas.core.dtypes.missing import na_value_for_dtype
911

@@ -260,6 +262,9 @@ def test_fillna_length_mismatch(self, data_missing):
260262
with pytest.raises(ValueError, match=msg):
261263
data_missing.fillna(data_missing.take([1]))
262264

265+
# Subclasses can override if we expect e.g Sparse[bool], boolean, pyarrow[bool]
266+
_combine_le_expected_dtype: Dtype = np.dtype(bool)
267+
263268
def test_combine_le(self, data_repeated):
264269
# GH 20825
265270
# Test that combine works when doing a <= (le) comparison
@@ -268,13 +273,17 @@ def test_combine_le(self, data_repeated):
268273
s2 = pd.Series(orig_data2)
269274
result = s1.combine(s2, lambda x1, x2: x1 <= x2)
270275
expected = pd.Series(
271-
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))]
276+
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))],
277+
dtype=self._combine_le_expected_dtype,
272278
)
273279
self.assert_series_equal(result, expected)
274280

275281
val = s1.iloc[0]
276282
result = s1.combine(val, lambda x1, x2: x1 <= x2)
277-
expected = pd.Series([a <= val for a in list(orig_data1)])
283+
expected = pd.Series(
284+
[a <= val for a in list(orig_data1)],
285+
dtype=self._combine_le_expected_dtype,
286+
)
278287
self.assert_series_equal(result, expected)
279288

280289
def test_combine_add(self, data_repeated):

pandas/tests/extension/test_arrow.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -972,11 +972,7 @@ def test_factorize(self, data_for_grouping, request):
972972
)
973973
super().test_factorize(data_for_grouping)
974974

975-
@pytest.mark.xfail(
976-
reason="result dtype pyarrow[bool] better than expected dtype object"
977-
)
978-
def test_combine_le(self, data_repeated):
979-
super().test_combine_le(data_repeated)
975+
_combine_le_expected_dtype = "bool[pyarrow]"
980976

981977
def test_combine_add(self, data_repeated, request):
982978
pa_dtype = next(data_repeated(1)).dtype.pyarrow_dtype

pandas/tests/extension/test_boolean.py

+2-17
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ class TestReshaping(base.BaseReshapingTests):
176176

177177

178178
class TestMethods(base.BaseMethodsTests):
179+
_combine_le_expected_dtype = "boolean"
180+
179181
def test_factorize(self, data_for_grouping):
180182
# override because we only have 2 unique values
181183
labels, uniques = pd.factorize(data_for_grouping, use_na_sentinel=True)
@@ -185,23 +187,6 @@ def test_factorize(self, data_for_grouping):
185187
tm.assert_numpy_array_equal(labels, expected_labels)
186188
self.assert_extension_array_equal(uniques, expected_uniques)
187189

188-
def test_combine_le(self, data_repeated):
189-
# override because expected needs to be boolean instead of bool dtype
190-
orig_data1, orig_data2 = data_repeated(2)
191-
s1 = pd.Series(orig_data1)
192-
s2 = pd.Series(orig_data2)
193-
result = s1.combine(s2, lambda x1, x2: x1 <= x2)
194-
expected = pd.Series(
195-
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))],
196-
dtype="boolean",
197-
)
198-
self.assert_series_equal(result, expected)
199-
200-
val = s1.iloc[0]
201-
result = s1.combine(val, lambda x1, x2: x1 <= x2)
202-
expected = pd.Series([a <= val for a in list(orig_data1)], dtype="boolean")
203-
self.assert_series_equal(result, expected)
204-
205190
def test_searchsorted(self, data_for_sorting, as_series):
206191
# override because we only have 2 unique values
207192
data_for_sorting = pd.array([True, False], dtype="boolean")

pandas/tests/extension/test_floating.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ class TestMissing(base.BaseMissingTests):
173173

174174

175175
class TestMethods(base.BaseMethodsTests):
176-
pass
176+
_combine_le_expected_dtype = object # TODO: can we make this boolean?
177177

178178

179179
class TestCasting(base.BaseCastingTests):

pandas/tests/extension/test_integer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ class TestMissing(base.BaseMissingTests):
201201

202202

203203
class TestMethods(base.BaseMethodsTests):
204-
pass
204+
_combine_le_expected_dtype = object # TODO: can we make this boolean?
205205

206206

207207
class TestCasting(base.BaseCastingTests):

pandas/tests/extension/test_sparse.py

+1-22
Original file line numberDiff line numberDiff line change
@@ -270,28 +270,7 @@ def test_fillna_frame(self, data_missing):
270270

271271

272272
class TestMethods(BaseSparseTests, base.BaseMethodsTests):
273-
def test_combine_le(self, data_repeated):
274-
# We return a Series[SparseArray].__le__ returns a
275-
# Series[Sparse[bool]]
276-
# rather than Series[bool]
277-
orig_data1, orig_data2 = data_repeated(2)
278-
s1 = pd.Series(orig_data1)
279-
s2 = pd.Series(orig_data2)
280-
result = s1.combine(s2, lambda x1, x2: x1 <= x2)
281-
expected = pd.Series(
282-
SparseArray(
283-
[a <= b for (a, b) in zip(list(orig_data1), list(orig_data2))],
284-
fill_value=False,
285-
)
286-
)
287-
self.assert_series_equal(result, expected)
288-
289-
val = s1.iloc[0]
290-
result = s1.combine(val, lambda x1, x2: x1 <= x2)
291-
expected = pd.Series(
292-
SparseArray([a <= val for a in list(orig_data1)], fill_value=False)
293-
)
294-
self.assert_series_equal(result, expected)
273+
_combine_le_expected_dtype = "Sparse[bool]"
295274

296275
def test_fillna_copy_frame(self, data_missing):
297276
arr = data_missing.take([1, 1])

pandas/tests/scalar/test_na_scalar.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
from datetime import (
2+
date,
3+
time,
4+
timedelta,
5+
)
16
import pickle
27

38
import numpy as np
@@ -67,7 +72,21 @@ def test_arithmetic_ops(all_arithmetic_functions, other):
6772

6873

6974
@pytest.mark.parametrize(
70-
"other", [NA, 1, 1.0, "a", b"a", np.int64(1), np.nan, np.bool_(True)]
75+
"other",
76+
[
77+
NA,
78+
1,
79+
1.0,
80+
"a",
81+
b"a",
82+
np.int64(1),
83+
np.nan,
84+
np.bool_(True),
85+
time(0),
86+
date(1, 2, 3),
87+
timedelta(1),
88+
pd.NaT,
89+
],
7190
)
7291
def test_comparison_ops(comparison_op, other):
7392
assert comparison_op(NA, other) is NA

0 commit comments

Comments
 (0)