Skip to content

Commit 3136124

Browse files
Fix comparison between Datetime/Timedelta columns and NULL scalars (#7504)
Fixes #6897 Authors: - @brandon-b-miller Approvers: - GALI PREM SAGAR (@galipremsagar) - Ram (Ramakrishna Prabhu) (@rgsl888prabhu) URL: #7504
1 parent f38daf3 commit 3136124

File tree

3 files changed

+49
-0
lines changed

3 files changed

+49
-0
lines changed

python/cudf/cudf/core/column/datetime.py

+2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def normalize_binop_value(self, other: DatetimeLikeScalar) -> ScalarLike:
178178
return cudf.Scalar(None, dtype=other.dtype)
179179

180180
return cudf.Scalar(other)
181+
elif other is None:
182+
return cudf.Scalar(other, dtype=self.dtype)
181183
else:
182184
raise TypeError(f"cannot normalize {type(other)}")
183185

python/cudf/cudf/core/column/timedelta.py

+2
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ def normalize_binop_value(self, other) -> BinaryOperand:
275275
return cudf.Scalar(other)
276276
elif np.isscalar(other):
277277
return cudf.Scalar(other)
278+
elif other is None:
279+
return cudf.Scalar(other, dtype=self.dtype)
278280
else:
279281
raise TypeError(f"cannot normalize {type(other)}")
280282

python/cudf/cudf/tests/test_binops.py

+45
Original file line numberDiff line numberDiff line change
@@ -1773,6 +1773,51 @@ def decimal_series(input, dtype):
17731773
utils.assert_eq(expect, got)
17741774

17751775

1776+
@pytest.mark.parametrize(
1777+
"dtype",
1778+
[
1779+
"uint8",
1780+
"uint16",
1781+
"uint32",
1782+
"uint64",
1783+
"int8",
1784+
"int16",
1785+
"int32",
1786+
"int64",
1787+
"float32",
1788+
"float64",
1789+
"str",
1790+
"datetime64[ns]",
1791+
"datetime64[us]",
1792+
"datetime64[ms]",
1793+
"datetime64[s]",
1794+
"timedelta64[ns]",
1795+
"timedelta64[us]",
1796+
"timedelta64[ms]",
1797+
"timedelta64[s]",
1798+
],
1799+
)
1800+
@pytest.mark.parametrize("null_scalar", [None, cudf.NA, np.datetime64("NaT")])
1801+
@pytest.mark.parametrize("cmpop", _cmpops)
1802+
def test_column_null_scalar_comparison(dtype, null_scalar, cmpop):
1803+
# This test is meant to validate that comparing
1804+
# a series of any dtype with a null scalar produces
1805+
# a new series where all the elements are <NA>.
1806+
1807+
if isinstance(null_scalar, np.datetime64):
1808+
if np.dtype(dtype).kind not in "mM":
1809+
pytest.skip()
1810+
null_scalar = null_scalar.astype(dtype)
1811+
1812+
dtype = np.dtype(dtype)
1813+
1814+
data = [1, 2, 3, 4, 5]
1815+
sr = cudf.Series(data, dtype=dtype)
1816+
result = cmpop(sr, null_scalar)
1817+
1818+
assert result.isnull().all()
1819+
1820+
17761821
@pytest.mark.parametrize("fn", ["eq", "ne", "lt", "gt", "le", "ge"])
17771822
def test_equality_ops_index_mismatch(fn):
17781823
a = cudf.Series(

0 commit comments

Comments
 (0)