@@ -67,6 +67,14 @@ cdef enum InterpolationEnumType:
67
67
INTERPOLATION_NEAREST,
68
68
INTERPOLATION_MIDPOINT
69
69
70
+ cdef inline bint check_inf(numeric_t ai) nogil:
71
+
72
+ if numeric_t == cython.float:
73
+ if (ai == MINfloat32) or (ai == MAXfloat32):
74
+ return True
75
+ else :
76
+ if (ai == MINfloat64) or (ai == MAXfloat64):
77
+ return True
70
78
71
79
cdef inline float64_t median_linear(float64_t* a, int n) nogil:
72
80
cdef:
@@ -258,40 +266,18 @@ def group_cumsum(numeric_t[:, ::1] out,
258
266
259
267
# For floats, use Kahan summation to reduce floating-point
260
268
# error (https://en.wikipedia.org/wiki/Kahan_summation_algorithm)
261
- if numeric_t == float32_t:
262
- if (val == MAXfloat32) or (val == MINfloat32):
263
- if (t == MAXfloat32) or (t == MINfloat32):
264
- val = t
265
- out[i, j] = val
266
- accum[lab, j] = val
267
- break
268
- elif val == val:
269
- y = val - compensation[lab, j]
270
- t = accum[lab, j] + y
271
- compensation[lab, j] = t - accum[lab, j] - y
272
- accum[lab, j] = t
273
- out[i, j] = t
274
- else :
275
- out[i, j] = NaN
276
- if not skipna:
277
- accum[lab, j] = NaN
278
- break
279
- elif numeric_t == float64_t:
280
- if (val == MAXfloat64) or (val == MINfloat64):
281
- if (t == MAXfloat64) or (t == MINfloat64):
282
- val = t
283
- out[i, j] = val
284
- accum[lab, j] = val
285
- break
286
- elif val == val:
287
- y = val - compensation[lab, j]
288
- t = accum[lab, j] + y
289
- compensation[lab, j] = t - accum[lab, j] - y
290
- accum[lab, j] = t
291
- out[i, j] = t
292
- if (t == MAXfloat64) or (t == MINfloat64):
293
- compensation[lab, j] = 0
294
- break
269
+ if (numeric_t == float32_t) or (numeric_t == float64_t):
270
+ if val == val:
271
+ # if val or accum are inf/-inf don't use kahan
272
+ if check_inf(val) or check_inf(accum[lab, j]):
273
+ accum[lab, j] += val
274
+ out[i, j] = accum[lab, j]
275
+ else :
276
+ y = val - compensation[lab, j]
277
+ t = accum[lab, j] + y
278
+ compensation[lab, j] = t - accum[lab, j] - y
279
+ accum[lab, j] = t
280
+ out[i, j] = t
295
281
else :
296
282
out[i, j] = NaN
297
283
if not skipna:
0 commit comments