Skip to content

Commit c1f7b39

Browse files
jbrockmendel authored and jreback committed
BUG: fix+test op(NaT, ndarray), also simplify (#27807)
1 parent a656d24 commit c1f7b39

File tree

5 files changed

+131
-74
lines changed

5 files changed

+131
-74
lines changed

pandas/_libs/tslibs/nattype.pyx

+74-30
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,9 @@ cdef class _NaT(datetime):
9292
# int64_t value
9393
# object freq
9494

95+
# higher than np.ndarray and np.matrix
96+
__array_priority__ = 100
97+
9598
def __hash__(_NaT self):
9699
# py3k needs this defined here
97100
return hash(self.value)
@@ -103,61 +106,102 @@ cdef class _NaT(datetime):
103106
if ndim == -1:
104107
return _nat_scalar_rules[op]
105108

106-
if ndim == 0:
109+
elif util.is_array(other):
110+
result = np.empty(other.shape, dtype=np.bool_)
111+
result.fill(_nat_scalar_rules[op])
112+
return result
113+
114+
elif ndim == 0:
107115
if is_datetime64_object(other):
108116
return _nat_scalar_rules[op]
109117
else:
110118
raise TypeError('Cannot compare type %r with type %r' %
111119
(type(self).__name__, type(other).__name__))
120+
112121
# Note: instead of passing "other, self, _reverse_ops[op]", we observe
113122
# that `_nat_scalar_rules` is invariant under `_reverse_ops`,
114123
# rendering it unnecessary.
115124
return PyObject_RichCompare(other, self, op)
116125

117126
def __add__(self, other):
127+
if self is not c_NaT:
128+
# cython __radd__ semantics
129+
self, other = other, self
130+
118131
if PyDateTime_Check(other):
119132
return c_NaT
120-
133+
elif PyDelta_Check(other):
134+
return c_NaT
135+
elif is_datetime64_object(other) or is_timedelta64_object(other):
136+
return c_NaT
121137
elif hasattr(other, 'delta'):
122138
# Timedelta, offsets.Tick, offsets.Week
123139
return c_NaT
124-
elif getattr(other, '_typ', None) in ['dateoffset', 'series',
125-
'period', 'datetimeindex',
126-
'datetimearray',
127-
'timedeltaindex',
128-
'timedeltaarray']:
129-
# Duplicate logic in _Timestamp.__add__ to avoid needing
130-
# to subclass; allows us to @final(_Timestamp.__add__)
131-
return NotImplemented
132-
return c_NaT
140+
141+
elif is_integer_object(other) or util.is_period_object(other):
142+
# For Period compat
143+
# TODO: the integer behavior is deprecated, remove it
144+
return c_NaT
145+
146+
elif util.is_array(other):
147+
if other.dtype.kind in 'mM':
148+
# If we are adding to datetime64, we treat NaT as timedelta
149+
# Either way, result dtype is datetime64
150+
result = np.empty(other.shape, dtype="datetime64[ns]")
151+
result.fill("NaT")
152+
return result
153+
154+
return NotImplemented
133155

134156
def __sub__(self, other):
135157
# Duplicate some logic from _Timestamp.__sub__ to avoid needing
136158
# to subclass; allows us to @final(_Timestamp.__sub__)
159+
cdef:
160+
bint is_rsub = False
161+
162+
if self is not c_NaT:
163+
# cython __rsub__ semantics
164+
self, other = other, self
165+
is_rsub = True
166+
137167
if PyDateTime_Check(other):
138-
return NaT
168+
return c_NaT
139169
elif PyDelta_Check(other):
140-
return NaT
170+
return c_NaT
171+
elif is_datetime64_object(other) or is_timedelta64_object(other):
172+
return c_NaT
173+
elif hasattr(other, 'delta'):
174+
# offsets.Tick, offsets.Week
175+
return c_NaT
141176

142-
elif getattr(other, '_typ', None) == 'datetimeindex':
143-
# a Timestamp-DatetimeIndex -> yields a negative TimedeltaIndex
144-
return -other.__sub__(self)
177+
elif is_integer_object(other) or util.is_period_object(other):
178+
# For Period compat
179+
# TODO: the integer behavior is deprecated, remove it
180+
return c_NaT
145181

146-
elif getattr(other, '_typ', None) == 'timedeltaindex':
147-
# a Timestamp-TimedeltaIndex -> yields a negative TimedeltaIndex
148-
return (-other).__add__(self)
182+
elif util.is_array(other):
183+
if other.dtype.kind == 'm':
184+
if not is_rsub:
185+
# NaT - timedelta64 we treat NaT as datetime64, so result
186+
# is datetime64
187+
result = np.empty(other.shape, dtype="datetime64[ns]")
188+
result.fill("NaT")
189+
return result
190+
191+
# timedelta64 - NaT we have to treat NaT as timedelta64
192+
# for this to be meaningful, and the result is timedelta64
193+
result = np.empty(other.shape, dtype="timedelta64[ns]")
194+
result.fill("NaT")
195+
return result
196+
197+
elif other.dtype.kind == 'M':
198+
# We treat NaT as a datetime, so regardless of whether this is
199+
# NaT - other or other - NaT, the result is timedelta64
200+
result = np.empty(other.shape, dtype="timedelta64[ns]")
201+
result.fill("NaT")
202+
return result
149203

150-
elif hasattr(other, 'delta'):
151-
# offsets.Tick, offsets.Week
152-
neg_other = -other
153-
return self + neg_other
154-
155-
elif getattr(other, '_typ', None) in ['period', 'series',
156-
'periodindex', 'dateoffset',
157-
'datetimearray',
158-
'timedeltaarray']:
159-
return NotImplemented
160-
return NaT
204+
return NotImplemented
161205

162206
def __pos__(self):
163207
return NaT

pandas/core/indexes/datetimes.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import numpy as np
66

7-
from pandas._libs import Timestamp, index as libindex, lib, tslib as libts
7+
from pandas._libs import NaT, Timestamp, index as libindex, lib, tslib as libts
88
import pandas._libs.join as libjoin
99
from pandas._libs.tslibs import ccalendar, fields, parsing, timezones
1010
from pandas.util._decorators import Appender, Substitution, cache_readonly
@@ -1281,7 +1281,9 @@ def insert(self, loc, item):
12811281
raise ValueError("Passed item and index have different timezone")
12821282
# check freq can be preserved on edge cases
12831283
if self.size and self.freq is not None:
1284-
if (loc == 0 or loc == -len(self)) and item + self.freq == self[0]:
1284+
if item is NaT:
1285+
pass
1286+
elif (loc == 0 or loc == -len(self)) and item + self.freq == self[0]:
12851287
freq = self.freq
12861288
elif (loc == len(self)) and item - self.freq == self[-1]:
12871289
freq = self.freq

pandas/tests/arithmetic/test_timedelta64.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1610,7 +1610,7 @@ def test_td64arr_div_nat_invalid(self, box_with_array):
16101610
rng = timedelta_range("1 days", "10 days", name="foo")
16111611
rng = tm.box_expected(rng, box_with_array)
16121612

1613-
with pytest.raises(TypeError, match="'?true_divide'? cannot use operands"):
1613+
with pytest.raises(TypeError, match="unsupported operand type"):
16141614
rng / pd.NaT
16151615
with pytest.raises(TypeError, match="Cannot divide NaTType by"):
16161616
pd.NaT / rng

pandas/tests/scalar/period/test_period.py

+7-39
Original file line numberDiff line numberDiff line change
@@ -1298,23 +1298,13 @@ def test_add_offset_nat(self):
12981298
timedelta(365),
12991299
]:
13001300
assert p + o is NaT
1301-
1302-
if isinstance(o, np.timedelta64):
1303-
with pytest.raises(TypeError):
1304-
o + p
1305-
else:
1306-
assert o + p is NaT
1301+
assert o + p is NaT
13071302

13081303
for freq in ["M", "2M", "3M"]:
13091304
p = Period("NaT", freq=freq)
13101305
for o in [offsets.MonthEnd(2), offsets.MonthEnd(12)]:
13111306
assert p + o is NaT
1312-
1313-
if isinstance(o, np.timedelta64):
1314-
with pytest.raises(TypeError):
1315-
o + p
1316-
else:
1317-
assert o + p is NaT
1307+
assert o + p is NaT
13181308

13191309
for o in [
13201310
offsets.YearBegin(2),
@@ -1324,12 +1314,7 @@ def test_add_offset_nat(self):
13241314
timedelta(365),
13251315
]:
13261316
assert p + o is NaT
1327-
1328-
if isinstance(o, np.timedelta64):
1329-
with pytest.raises(TypeError):
1330-
o + p
1331-
else:
1332-
assert o + p is NaT
1317+
assert o + p is NaT
13331318

13341319
# freq is Tick
13351320
for freq in ["D", "2D", "3D"]:
@@ -1343,12 +1328,7 @@ def test_add_offset_nat(self):
13431328
timedelta(hours=48),
13441329
]:
13451330
assert p + o is NaT
1346-
1347-
if isinstance(o, np.timedelta64):
1348-
with pytest.raises(TypeError):
1349-
o + p
1350-
else:
1351-
assert o + p is NaT
1331+
assert o + p is NaT
13521332

13531333
for o in [
13541334
offsets.YearBegin(2),
@@ -1358,12 +1338,7 @@ def test_add_offset_nat(self):
13581338
timedelta(hours=23),
13591339
]:
13601340
assert p + o is NaT
1361-
1362-
if isinstance(o, np.timedelta64):
1363-
with pytest.raises(TypeError):
1364-
o + p
1365-
else:
1366-
assert o + p is NaT
1341+
assert o + p is NaT
13671342

13681343
for freq in ["H", "2H", "3H"]:
13691344
p = Period("NaT", freq=freq)
@@ -1376,9 +1351,7 @@ def test_add_offset_nat(self):
13761351
timedelta(days=4, minutes=180),
13771352
]:
13781353
assert p + o is NaT
1379-
1380-
if not isinstance(o, np.timedelta64):
1381-
assert o + p is NaT
1354+
assert o + p is NaT
13821355

13831356
for o in [
13841357
offsets.YearBegin(2),
@@ -1388,12 +1361,7 @@ def test_add_offset_nat(self):
13881361
timedelta(hours=23, minutes=30),
13891362
]:
13901363
assert p + o is NaT
1391-
1392-
if isinstance(o, np.timedelta64):
1393-
with pytest.raises(TypeError):
1394-
o + p
1395-
else:
1396-
assert o + p is NaT
1364+
assert o + p is NaT
13971365

13981366
def test_sub_offset(self):
13991367
# freq is DateOffset

pandas/tests/scalar/test_nat.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from datetime import datetime, timedelta
2+
import operator
23

34
import numpy as np
45
import pytest
@@ -21,6 +22,7 @@
2122
isna,
2223
)
2324
from pandas.core.arrays import DatetimeArray, PeriodArray, TimedeltaArray
25+
from pandas.core.ops import roperator
2426
from pandas.util import testing as tm
2527

2628

@@ -333,8 +335,9 @@ def test_nat_doc_strings(compare):
333335
"value,val_type",
334336
[
335337
(2, "scalar"),
336-
(1.5, "scalar"),
337-
(np.nan, "scalar"),
338+
(1.5, "floating"),
339+
(np.nan, "floating"),
340+
("foo", "str"),
338341
(timedelta(3600), "timedelta"),
339342
(Timedelta("5s"), "timedelta"),
340343
(datetime(2014, 1, 1), "timestamp"),
@@ -348,6 +351,14 @@ def test_nat_arithmetic_scalar(op_name, value, val_type):
348351
# see gh-6873
349352
invalid_ops = {
350353
"scalar": {"right_div_left"},
354+
"floating": {
355+
"right_div_left",
356+
"left_minus_right",
357+
"right_minus_left",
358+
"left_plus_right",
359+
"right_plus_left",
360+
},
361+
"str": set(_ops.keys()),
351362
"timedelta": {"left_times_right", "right_times_left"},
352363
"timestamp": {
353364
"left_times_right",
@@ -366,6 +377,16 @@ def test_nat_arithmetic_scalar(op_name, value, val_type):
366377
and isinstance(value, Timedelta)
367378
):
368379
msg = "Cannot multiply"
380+
elif val_type == "str":
381+
# un-specific check here because the message comes from str
382+
# and varies by method
383+
msg = (
384+
"can only concatenate str|"
385+
"unsupported operand type|"
386+
"can't multiply sequence|"
387+
"Can't convert 'NaTType'|"
388+
"must be str, not NaTType"
389+
)
369390
else:
370391
msg = "unsupported operand type"
371392

@@ -435,6 +456,28 @@ def test_nat_arithmetic_td64_vector(op_name, box):
435456
tm.assert_equal(_ops[op_name](vec, NaT), box_nat)
436457

437458

459+
@pytest.mark.parametrize(
460+
"dtype,op,out_dtype",
461+
[
462+
("datetime64[ns]", operator.add, "datetime64[ns]"),
463+
("datetime64[ns]", roperator.radd, "datetime64[ns]"),
464+
("datetime64[ns]", operator.sub, "timedelta64[ns]"),
465+
("datetime64[ns]", roperator.rsub, "timedelta64[ns]"),
466+
("timedelta64[ns]", operator.add, "datetime64[ns]"),
467+
("timedelta64[ns]", roperator.radd, "datetime64[ns]"),
468+
("timedelta64[ns]", operator.sub, "datetime64[ns]"),
469+
("timedelta64[ns]", roperator.rsub, "timedelta64[ns]"),
470+
],
471+
)
472+
def test_nat_arithmetic_ndarray(dtype, op, out_dtype):
473+
other = np.arange(10).astype(dtype)
474+
result = op(NaT, other)
475+
476+
expected = np.empty(other.shape, dtype=out_dtype)
477+
expected.fill("NaT")
478+
tm.assert_numpy_array_equal(result, expected)
479+
480+
438481
def test_nat_pinned_docstrings():
439482
# see gh-17327
440483
assert NaT.ctime.__doc__ == datetime.ctime.__doc__

0 commit comments

Comments
 (0)