Skip to content

Commit d37b34c

Browse files
authored
REF: support object dtype in libgroupby.group_add (#41294)
1 parent 459d440 commit d37b34c

File tree

1 file changed

+40
-9
lines changed

1 file changed

+40
-9
lines changed

pandas/_libs/groupby.pyx

+40-9
Original file line numberDiff line numberDiff line change
@@ -469,27 +469,28 @@ def group_any_all(int8_t[::1] out,
469469
# group_add, group_prod, group_var, group_mean, group_ohlc
470470
# ----------------------------------------------------------------------
471471

472-
ctypedef fused complexfloating_t:
472+
ctypedef fused add_t:
473473
float64_t
474474
float32_t
475475
complex64_t
476476
complex128_t
477+
object
477478

478479

479480
@cython.wraparound(False)
480481
@cython.boundscheck(False)
481-
def group_add(complexfloating_t[:, ::1] out,
482+
def group_add(add_t[:, ::1] out,
482483
int64_t[::1] counts,
483-
ndarray[complexfloating_t, ndim=2] values,
484+
ndarray[add_t, ndim=2] values,
484485
const intp_t[:] labels,
485486
Py_ssize_t min_count=0) -> None:
486487
"""
487488
Only aggregates on axis=0 using Kahan summation
488489
"""
489490
cdef:
490491
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
491-
complexfloating_t val, count, t, y
492-
complexfloating_t[:, ::1] sumx, compensation
492+
add_t val, t, y
493+
add_t[:, ::1] sumx, compensation
493494
int64_t[:, ::1] nobs
494495
Py_ssize_t len_values = len(values), len_labels = len(labels)
495496

@@ -503,7 +504,8 @@ def group_add(complexfloating_t[:, ::1] out,
503504

504505
N, K = (<object>values).shape
505506

506-
with nogil:
507+
if add_t is object:
508+
# NB: this does not use 'compensation' like the non-object track does.
507509
for i in range(N):
508510
lab = labels[i]
509511
if lab < 0:
@@ -516,9 +518,13 @@ def group_add(complexfloating_t[:, ::1] out,
516518
# not nan
517519
if val == val:
518520
nobs[lab, j] += 1
519-
y = val - compensation[lab, j]
520-
t = sumx[lab, j] + y
521-
compensation[lab, j] = t - sumx[lab, j] - y
521+
522+
if nobs[lab, j] == 1:
523+
# i.e. we havent added anything yet; avoid TypeError
524+
# if e.g. val is a str and sumx[lab, j] is 0
525+
t = val
526+
else:
527+
t = sumx[lab, j] + val
522528
sumx[lab, j] = t
523529

524530
for i in range(ncounts):
@@ -527,6 +533,31 @@ def group_add(complexfloating_t[:, ::1] out,
527533
out[i, j] = NAN
528534
else:
529535
out[i, j] = sumx[i, j]
536+
else:
537+
with nogil:
538+
for i in range(N):
539+
lab = labels[i]
540+
if lab < 0:
541+
continue
542+
543+
counts[lab] += 1
544+
for j in range(K):
545+
val = values[i, j]
546+
547+
# not nan
548+
if val == val:
549+
nobs[lab, j] += 1
550+
y = val - compensation[lab, j]
551+
t = sumx[lab, j] + y
552+
compensation[lab, j] = t - sumx[lab, j] - y
553+
sumx[lab, j] = t
554+
555+
for i in range(ncounts):
556+
for j in range(K):
557+
if nobs[i, j] < min_count:
558+
out[i, j] = NAN
559+
else:
560+
out[i, j] = sumx[i, j]
530561

531562

532563
@cython.wraparound(False)

0 commit comments

Comments
 (0)