Skip to content

Commit 8e21e60

Browse files
committed
BUG: Prevent addition overflow with TimedeltaIndex
Expands the checked-add array addition introduced in pandas-dev/pandas#14237 to cover all other addition cases (i.e. TimedeltaIndex and Timedelta). Follow-up to pandas-dev/pandas#14453.
1 parent 2466ecb commit 8e21e60

File tree

5 files changed

+87
-8
lines changed

5 files changed

+87
-8
lines changed

asv_bench/benchmarks/algorithms.py

+15
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def setup(self):
2424
self.arrneg = np.arange(-1000000, 0)
2525
self.arrmixed = np.array([1, -1]).repeat(500000)
2626

27+
self.arr_nan = np.random.choice([True, False], size=1000000)
28+
self.arrpos_nan = np.random.choice([True, False], size=1000000)
29+
self.arrneg_nan = np.random.choice([True, False], size=1000000)
30+
self.arrmixed_nan = np.random.choice([True, False], size=1000000)
31+
2732
def time_int_factorize(self):
2833
self.int.factorize()
2934

@@ -57,6 +62,16 @@ def time_add_overflow_neg_arr(self):
5762
def time_add_overflow_mixed_arr(self):
5863
self.checked_add(self.arr, self.arrmixed)
5964

65+
def time_add_overflow_first_arg_nan(self):
    """Time the checked add when only the first addend is given a NaN mask."""
    checked_add = self.checked_add
    first_addend_mask = {'arr_nans': self.arr_nan}
    checked_add(self.arr, self.arrmixed, **first_addend_mask)
67+
68+
def time_add_overflow_second_arg_nan(self):
    """Time the checked add when only the second addend is given a NaN mask."""
    # The mask attribute for ``arrmixed`` is named ``arrmixed_nan``; the
    # previous ``arrmixed_arr_nan`` spelling is not among the mask attributes
    # created for this benchmark and raised AttributeError when run.
    self.checked_add(self.arr, self.arrmixed, b_nans=self.arrmixed_nan)
70+
71+
def time_add_overflow_both_arg_nan(self):
    """Time the checked add when both addends are given NaN masks."""
    # The mask attribute for ``arrmixed`` is named ``arrmixed_nan``; the
    # previous ``arrmixed_arr_nan`` spelling is not among the mask attributes
    # created for this benchmark and raised AttributeError when run.
    self.checked_add(self.arr, self.arrmixed, arr_nans=self.arr_nan,
                     b_nans=self.arrmixed_nan)
74+
6075

6176
class hashing(object):
6277
goal_time = 0.2

pandas/core/nanops.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -812,15 +812,21 @@ def unique1d(values):
812812
return uniques
813813

814814

815-
def _checked_add_with_arr(arr, b):
815+
def _checked_add_with_arr(arr, b, arr_nans=None, b_nans=None):
816816
"""
817+
Perform array addition that checks for underflow and overflow.
818+
817819
Performs the addition of an int64 array and an int64 integer (or array)
818-
but checks that they do not result in overflow first.
820+
but checks that they do not result in overflow first. For elements that
821+
are indicated to be NaN, whether or not there is overflow for that element
822+
is automatically ignored.
819823
820824
Parameters
821825
----------
822826
arr : array addend.
823827
b : array or scalar addend.
828+
arr_nans : array indicating which elements are NaN
829+
b_nans : array or scalar indicating which elements are NaN
824830
825831
Returns
826832
-------
@@ -843,6 +849,17 @@ def _checked_add_with_arr(arr, b):
843849
else:
844850
b2 = np.broadcast_to(b, arr.shape)
845851

852+
# For elements that are NaN, regardless of their value, we should
853+
# ignore whether they overflow or not when doing the checked add.
854+
if arr_nans is not None and b_nans is not None:
855+
not_nan = np.logical_not(arr_nans | b_nans)
856+
elif arr_nans is not None:
857+
not_nan = np.logical_not(arr_nans)
858+
elif b_nans is not None:
859+
not_nan = np.logical_not(b_nans)
860+
else:
861+
not_nan = np.array([True])
862+
846863
# gh-14324: For each element in 'arr' and its corresponding element
847864
# in 'b2', we check the sign of the element in 'b2'. If it is positive,
848865
# we then check whether its sum with the element in 'arr' exceeds
@@ -854,12 +871,14 @@ def _checked_add_with_arr(arr, b):
854871
mask2 = b2 < 0
855872

856873
if not mask1.any():
857-
to_raise = (np.iinfo(np.int64).min - b2 > arr).any()
874+
to_raise = ((np.iinfo(np.int64).min - b2 > arr) & not_nan).any()
858875
elif not mask2.any():
859-
to_raise = (np.iinfo(np.int64).max - b2 < arr).any()
876+
to_raise = ((np.iinfo(np.int64).max - b2 < arr) & not_nan).any()
860877
else:
861-
to_raise = ((np.iinfo(np.int64).max - b2[mask1] < arr[mask1]).any() or
862-
(np.iinfo(np.int64).min - b2[mask2] > arr[mask2]).any())
878+
to_raise = (((np.iinfo(np.int64).max -
879+
b2[mask1] < arr[mask1]) & not_nan[mask1]).any() or
880+
((np.iinfo(np.int64).min -
881+
b2[mask2] > arr[mask2]) & not_nan[mask2]).any())
863882

864883
if to_raise:
865884
raise OverflowError("Overflow in int64 addition")

pandas/tests/test_nanops.py

+27
Original file line numberDiff line numberDiff line change
@@ -1018,11 +1018,38 @@ def test_int64_add_overflow():
10181018
nanops._checked_add_with_arr(np.array([n, n]), np.array([n, n]))
10191019
with tm.assertRaisesRegexp(OverflowError, msg):
10201020
nanops._checked_add_with_arr(np.array([m, n]), np.array([n, n]))
1021+
with tm.assertRaisesRegexp(OverflowError, msg):
1022+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1023+
arr_nans=np.array([False, True]))
1024+
with tm.assertRaisesRegexp(OverflowError, msg):
1025+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1026+
b_nans=np.array([False, True]))
1027+
with tm.assertRaisesRegexp(OverflowError, msg):
1028+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1029+
arr_nans=np.array([False, True]),
1030+
b_nans=np.array([False, True]))
10211031
with tm.assertRaisesRegexp(OverflowError, msg):
10221032
with tm.assert_produces_warning(RuntimeWarning):
10231033
nanops._checked_add_with_arr(np.array([m, m]),
10241034
np.array([np.nan, m]))
10251035

1036+
# Check that the nan boolean arrays override whether or not
1037+
# the addition overflows. We don't check the result but just
1038+
# the fact that an OverflowError is not raised.
1039+
with tm.assertRaises(AssertionError):
1040+
with tm.assertRaisesRegexp(OverflowError, msg):
1041+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1042+
arr_nans=np.array([True, True]))
1043+
with tm.assertRaises(AssertionError):
1044+
with tm.assertRaisesRegexp(OverflowError, msg):
1045+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1046+
b_nans=np.array([True, True]))
1047+
with tm.assertRaises(AssertionError):
1048+
with tm.assertRaisesRegexp(OverflowError, msg):
1049+
nanops._checked_add_with_arr(np.array([m, m]), np.array([m, m]),
1050+
arr_nans=np.array([True, False]),
1051+
b_nans=np.array([False, True]))
1052+
10261053

10271054
if __name__ == '__main__':
10281055
import nose

pandas/tseries/base.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from pandas.types.missing import isnull
1818
from pandas.core import common as com, algorithms
1919
from pandas.core.common import AbstractMethodError
20+
from pandas.core.nanops import _checked_add_with_arr
2021

2122
import pandas.formats.printing as printing
2223
import pandas.tslib as tslib
@@ -684,7 +685,8 @@ def _add_delta_td(self, other):
684685
# return the i8 result view
685686

686687
inc = tslib._delta_to_nanoseconds(other)
687-
new_values = (self.asi8 + inc).view('i8')
688+
new_values = _checked_add_with_arr(self.asi8, inc,
689+
arr_nans=self._isnan).view('i8')
688690
if self.hasnans:
689691
new_values[self._isnan] = tslib.iNaT
690692
return new_values.view('i8')
@@ -699,7 +701,9 @@ def _add_delta_tdi(self, other):
699701

700702
self_i8 = self.asi8
701703
other_i8 = other.asi8
702-
new_values = self_i8 + other_i8
704+
new_values = _checked_add_with_arr(self_i8, other_i8,
705+
arr_nans=self._isnan,
706+
b_nans=other._isnan)
703707
if self.hasnans or other.hasnans:
704708
mask = (self._isnan) | (other._isnan)
705709
new_values[mask] = tslib.iNaT

pandas/tseries/tests/test_timedeltas.py

+14
Original file line numberDiff line numberDiff line change
@@ -1964,6 +1964,20 @@ def test_add_overflow(self):
19641964
with tm.assertRaisesRegexp(OverflowError, msg):
19651965
Timestamp('2000') + to_timedelta([106580], 'D')
19661966

1967+
# These should not overflow!
1968+
exp = TimedeltaIndex([pd.NaT])
1969+
result = to_timedelta([pd.NaT]) - Timedelta('1 days')
1970+
tm.assert_index_equal(result, exp)
1971+
1972+
exp = TimedeltaIndex(['4 days', pd.NaT])
1973+
result = to_timedelta(['5 days', pd.NaT]) - Timedelta('1 days')
1974+
tm.assert_index_equal(result, exp)
1975+
1976+
exp = TimedeltaIndex([pd.NaT, pd.NaT, '5 hours'])
1977+
result = (to_timedelta([pd.NaT, '5 days', '1 hours']) +
1978+
to_timedelta(['7 seconds', pd.NaT, '4 hours']))
1979+
tm.assert_index_equal(result, exp)
1980+
19671981
if __name__ == '__main__':
19681982
nose.runmodule(argv=[__file__, '-vvs', '-x', '--pdb', '--pdb-failure'],
19691983
exit=False)

0 commit comments

Comments
 (0)