Skip to content

Commit 447d2c5

Browse files
jbrockmendeljreback
authored andcommitted
BUG: fix+test PA+all-NaT TDA (#27739)
1 parent b6a8aee commit 447d2c5

File tree

3 files changed

+55
-33
lines changed

3 files changed

+55
-33
lines changed

pandas/core/arrays/period.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -714,7 +714,12 @@ def _add_delta_tdi(self, other):
714714
"""
715715
assert isinstance(self.freq, Tick) # checked by calling function
716716

717-
delta = self._check_timedeltalike_freq_compat(other)
717+
if not np.all(isna(other)):
718+
delta = self._check_timedeltalike_freq_compat(other)
719+
else:
720+
# all-NaT TimedeltaIndex is equivalent to a single scalar td64 NaT
721+
return self + np.timedelta64("NaT")
722+
718723
return self._addsub_int_array(delta, operator.add).asi8
719724

720725
def _add_delta(self, other):

pandas/core/ops/__init__.py

+21-32
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66
import datetime
77
import operator
8-
from typing import Any, Callable
8+
from typing import Any, Callable, Tuple
99

1010
import numpy as np
1111

@@ -42,7 +42,6 @@
4242
ABCSeries,
4343
ABCSparseArray,
4444
ABCSparseSeries,
45-
ABCTimedeltaArray,
4645
)
4746
from pandas.core.dtypes.missing import isna, notna
4847

@@ -134,14 +133,15 @@ def _maybe_match_name(a, b):
134133
return None
135134

136135

137-
def maybe_upcast_for_op(obj):
136+
def maybe_upcast_for_op(obj, shape: Tuple[int, ...]):
138137
"""
139138
Cast non-pandas objects to pandas types to unify behavior of arithmetic
140139
and comparison operations.
141140
142141
Parameters
143142
----------
144143
obj: object
144+
shape : tuple[int]
145145
146146
Returns
147147
-------
@@ -157,13 +157,22 @@ def maybe_upcast_for_op(obj):
157157
# implementation; otherwise operation against numeric-dtype
158158
# raises TypeError
159159
return Timedelta(obj)
160-
elif isinstance(obj, np.timedelta64) and not isna(obj):
160+
elif isinstance(obj, np.timedelta64):
161+
if isna(obj):
162+
# wrapping timedelta64("NaT") in Timedelta returns NaT,
163+
# which would incorrectly be treated as a datetime-NaT, so
164+
# we broadcast and wrap in a Series
165+
right = np.broadcast_to(obj, shape)
166+
167+
# Note: we use Series instead of TimedeltaIndex to avoid having
168+
# to worry about catching NullFrequencyError.
169+
return pd.Series(right)
170+
161171
# In particular non-nanosecond timedelta64 needs to be cast to
162172
# nanoseconds, or else we get undesired behavior like
163173
# np.timedelta64(3, 'D') / 2 == np.timedelta64(1, 'D')
164-
# The isna check is to avoid casting timedelta64("NaT"), which would
165-
# return NaT and incorrectly be treated as a datetime-NaT.
166174
return Timedelta(obj)
175+
167176
elif isinstance(obj, np.ndarray) and is_timedelta64_dtype(obj):
168177
# GH#22390 Unfortunately we need to special-case right-hand
169178
# timedelta64 dtypes because numpy casts integer dtypes to
@@ -975,7 +984,7 @@ def wrapper(left, right):
975984

976985
left, right = _align_method_SERIES(left, right)
977986
res_name = get_op_result_name(left, right)
978-
right = maybe_upcast_for_op(right)
987+
right = maybe_upcast_for_op(right, left.shape)
979988

980989
if is_categorical_dtype(left):
981990
raise TypeError(
@@ -1003,31 +1012,11 @@ def wrapper(left, right):
10031012
return construct_result(left, result, index=left.index, name=res_name)
10041013

10051014
elif is_timedelta64_dtype(right):
1006-
# We should only get here with non-scalar or timedelta64('NaT')
1007-
# values for right
1008-
# Note: we cannot use dispatch_to_index_op because
1009-
# that may incorrectly raise TypeError when we
1010-
# should get NullFrequencyError
1011-
orig_right = right
1012-
if is_scalar(right):
1013-
# broadcast and wrap in a TimedeltaIndex
1014-
assert np.isnat(right)
1015-
right = np.broadcast_to(right, left.shape)
1016-
right = pd.TimedeltaIndex(right)
1017-
1018-
assert isinstance(right, (pd.TimedeltaIndex, ABCTimedeltaArray, ABCSeries))
1019-
try:
1020-
result = op(left._values, right)
1021-
except NullFrequencyError:
1022-
if orig_right is not right:
1023-
# i.e. scalar timedelta64('NaT')
1024-
# We get a NullFrequencyError because we broadcast to
1025-
# TimedeltaIndex, but this should be TypeError.
1026-
raise TypeError(
1027-
"incompatible type for a datetime/timedelta "
1028-
"operation [{name}]".format(name=op.__name__)
1029-
)
1030-
raise
1015+
# We should only get here with non-scalar values for right
1016+
# upcast by maybe_upcast_for_op
1017+
assert not isinstance(right, (np.timedelta64, np.ndarray))
1018+
1019+
result = op(left._values, right)
10311020

10321021
# We do not pass dtype to ensure that the Series constructor
10331022
# does inference in the case where `result` has object-dtype.

pandas/tests/arithmetic/test_period.py

+28
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import pandas as pd
1313
from pandas import Period, PeriodIndex, Series, period_range
1414
from pandas.core import ops
15+
from pandas.core.arrays import TimedeltaArray
1516
import pandas.util.testing as tm
1617

1718
from pandas.tseries.frequencies import to_offset
@@ -1013,6 +1014,33 @@ def test_parr_add_sub_td64_nat(self, box_transpose_fail):
10131014
with pytest.raises(TypeError):
10141015
other - obj
10151016

1017+
@pytest.mark.parametrize(
1018+
"other",
1019+
[
1020+
np.array(["NaT"] * 9, dtype="m8[ns]"),
1021+
TimedeltaArray._from_sequence(["NaT"] * 9),
1022+
],
1023+
)
1024+
def test_parr_add_sub_tdt64_nat_array(self, box_df_fail, other):
1025+
# FIXME: DataFrame fails because when when operating column-wise
1026+
# timedelta64 entries become NaT and are treated like datetimes
1027+
box = box_df_fail
1028+
1029+
pi = pd.period_range("1994-04-01", periods=9, freq="19D")
1030+
expected = pd.PeriodIndex(["NaT"] * 9, freq="19D")
1031+
1032+
obj = tm.box_expected(pi, box)
1033+
expected = tm.box_expected(expected, box)
1034+
1035+
result = obj + other
1036+
tm.assert_equal(result, expected)
1037+
result = other + obj
1038+
tm.assert_equal(result, expected)
1039+
result = obj - other
1040+
tm.assert_equal(result, expected)
1041+
with pytest.raises(TypeError):
1042+
other - obj
1043+
10161044

10171045
class TestPeriodSeriesArithmetic:
10181046
def test_ops_series_timedelta(self):

0 commit comments

Comments
 (0)