Skip to content

Commit 6f6da9a

Browse files
committed
Fix Kahan summation for the inf case
1 parent 9d55b50 commit 6f6da9a

File tree

3 files changed

+59
-8
lines changed

3 files changed

+59
-8
lines changed

pandas/_libs/groupby.pyx

+23-8
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ from numpy cimport (
2626
uint32_t,
2727
uint64_t,
2828
)
29-
from numpy.math cimport NAN
29+
from numpy.math cimport NAN, isinf
3030

3131
cnp.import_array()
3232

@@ -51,7 +51,14 @@ from pandas._libs.missing cimport checknull
5151
cdef int64_t NPY_NAT = get_nat()
5252
_int64_max = np.iinfo(np.int64).max
5353

54-
cdef float64_t NaN = <float64_t>np.NaN
54+
cdef:
55+
float32_t MINfloat32 = np.NINF
56+
float64_t MINfloat64 = np.NINF
57+
58+
float32_t MAXfloat32 = np.inf
59+
float64_t MAXfloat64 = np.inf
60+
61+
float64_t NaN = <float64_t>np.NaN
5562

5663
cdef enum InterpolationEnumType:
5764
INTERPOLATION_LINEAR,
@@ -251,13 +258,18 @@ def group_cumsum(numeric_t[:, ::1] out,
251258

252259
# For floats, use Kahan summation to reduce floating-point
253260
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
254-
if numeric_t == float32_t or numeric_t == float64_t:
261+
if numeric_t is float32_t or numeric_t is float64_t:
255262
if val == val:
256-
y = val - compensation[lab, j]
257-
t = accum[lab, j] + y
258-
compensation[lab, j] = t - accum[lab, j] - y
259-
accum[lab, j] = t
260-
out[i, j] = t
263+
# if val or accum are inf/-inf don't use kahan
264+
if isinf(val) or isinf(accum[lab, j]):
265+
accum[lab, j] += val
266+
out[i, j] = accum[lab, j]
267+
else:
268+
y = val - compensation[lab, j]
269+
t = accum[lab, j] + y
270+
compensation[lab, j] = t - accum[lab, j] - y
271+
accum[lab, j] = t
272+
out[i, j] = t
261273
else:
262274
out[i, j] = NaN
263275
if not skipna:
@@ -556,6 +568,9 @@ def group_add(add_t[:, ::1] out,
556568
for j in range(K):
557569
val = values[i, j]
558570

571+
if (val == MAXfloat64) or (val == MINfloat64):
572+
sumx[lab, j] = val
573+
break
559574
# not nan
560575
if val == val:
561576
nobs[lab, j] += 1

pandas/_libs/window/aggregations.pyx

+8
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,10 @@ cdef inline void add_sum(float64_t val, int64_t *nobs, float64_t *sum_x,
100100
t = sum_x[0] + y
101101
compensation[0] = t - sum_x[0] - y
102102
sum_x[0] = t
103+
if (val == MINfloat64) or (val == MAXfloat64):
104+
sum_x[0] = val
105+
nobs[0] = nobs[0] + 1
106+
compensation[0] = 0
103107

104108

105109
cdef inline void remove_sum(float64_t val, int64_t *nobs, float64_t *sum_x,
@@ -116,6 +120,10 @@ cdef inline void remove_sum(float64_t val, int64_t *nobs, float64_t *sum_x,
116120
t = sum_x[0] + y
117121
compensation[0] = t - sum_x[0] - y
118122
sum_x[0] = t
123+
if (val == MINfloat64) or (val == MAXfloat64):
124+
sum_x[0] = val
125+
nobs[0] = nobs[0] - 1
126+
compensation[0] = 0
119127

120128

121129
def roll_sum(const float64_t[:] values, ndarray[int64_t] start,

pandas/tests/groupby/test_function.py

+28
Original file line numberDiff line numberDiff line change
@@ -1162,3 +1162,31 @@ def test_mean_on_timedelta():
11621162
pd.to_timedelta([4, 5]), name="time", index=Index(["A", "B"], name="cat")
11631163
)
11641164
tm.assert_series_equal(result, expected)
1165+
1166+
1167+
def test_sum_with_nan_inf():
1168+
df = DataFrame(
1169+
{"a": ["hello", "hello", "world", "world"], "b": [np.inf, 10, np.nan, 10]}
1170+
)
1171+
gb = df.groupby("a")
1172+
result = gb.sum()
1173+
expected = DataFrame(
1174+
[np.inf, 10], index=Index(["hello", "world"], name="a"), columns=["b"]
1175+
)
1176+
tm.assert_frame_equal(result, expected)
1177+
1178+
1179+
def test_cumsum_inf():
1180+
ser = Series([np.inf, 1, 1])
1181+
1182+
result = ser.groupby([1, 1, 1]).cumsum()
1183+
expected = Series([np.inf, np.inf, np.inf])
1184+
tm.assert_series_equal(result, expected)
1185+
1186+
1187+
def test_cumsum_ninf_inf():
1188+
ser = Series([np.inf, 1, 1, -np.inf, 1])
1189+
1190+
result = ser.groupby([1, 1, 1, 1, 1]).cumsum()
1191+
expected = Series([np.inf, np.inf, np.inf, np.nan, np.nan])
1192+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)