@@ -513,6 +513,15 @@ ctypedef fused mean_t:
513
513
514
514
# Fused dtype set used for the ``values`` argument of group_sum: every
# member of mean_t, the signed and unsigned fixed-width integer types,
# and Python objects.
ctypedef fused sum_t:
515
515
mean_t
516
+ int8_t
517
+ int16_t
518
+ int32_t
519
+ int64_t
520
+
521
+ uint8_t
522
+ uint16_t
523
+ uint32_t
524
+ uint64_t
516
525
object
517
526
518
527
@@ -523,6 +532,8 @@ def group_sum(
523
532
int64_t[::1] counts ,
524
533
ndarray[sum_t , ndim = 2 ] values,
525
534
const intp_t[::1] labels ,
535
+ const uint8_t[:, :] mask ,
536
+ uint8_t[:, ::1] result_mask = None ,
526
537
Py_ssize_t min_count = 0 ,
527
538
bint is_datetimelike = False ,
528
539
) -> None:
@@ -535,6 +546,8 @@ def group_sum(
535
546
sum_t[:, ::1] sumx , compensation
536
547
int64_t[:, ::1] nobs
537
548
Py_ssize_t len_values = len (values), len_labels = len (labels)
549
+ bint uses_mask = mask is not None
550
+ bint isna_entry
538
551
539
552
if len_values != len_labels:
540
553
raise ValueError("len(index ) != len(labels )")
@@ -572,7 +585,8 @@ def group_sum(
572
585
for i in range (ncounts):
573
586
for j in range (K):
574
587
if nobs[i, j] < min_count:
575
- out[i, j] = NAN
588
+ out[i, j] = None
589
+
576
590
else :
577
591
out[i, j] = sumx[i, j]
578
592
else :
@@ -590,11 +604,18 @@ def group_sum(
590
604
# With dt64/td64 values, values have been cast to float64
591
605
# instead of int64 for group_sum, but the logic
592
606
# is otherwise the same as in _treat_as_na
593
- if val == val and not (
594
- sum_t is float64_t
595
- and is_datetimelike
596
- and val == < float64_t> NPY_NAT
597
- ):
607
+ if uses_mask:
608
+ isna_entry = mask[i, j]
609
+ # floating and complex dtypes (both complex widths) mark NA as NaN
+ elif (sum_t is float32_t or sum_t is float64_t
610
+ or sum_t is complex64_t or sum_t is complex128_t):
611
+ # avoid warnings because of equality comparison
612
+ isna_entry = not val == val
613
+ elif sum_t is int64_t and is_datetimelike and val == NPY_NAT:
614
+ isna_entry = True
615
+ else :
616
+ isna_entry = False
617
+
618
+ if not isna_entry:
598
619
nobs[lab, j] += 1
599
620
y = val - compensation[lab, j]
600
621
t = sumx[lab, j] + y
@@ -604,7 +625,23 @@ def group_sum(
604
625
for i in range (ncounts):
605
626
for j in range (K):
606
627
if nobs[i, j] < min_count:
607
- out[i, j] = NAN
628
+ # if we are integer dtype, not is_datetimelike, and
629
+ # not uses_mask, then getting here implies that
630
+ # counts[i] < min_count, which means we will
631
+ # be cast to float64 and masked at the end
632
+ # of WrappedCythonOp._call_cython_op. So we can safely
633
+ # set a placeholder value in out[i, j].
634
+ if uses_mask:
635
+ result_mask[i, j] = True
636
+ # floating and complex dtypes (both complex widths) can hold NaN
+ elif (sum_t is float32_t or sum_t is float64_t
637
+ or sum_t is complex64_t or sum_t is complex128_t):
638
+ out[i, j] = NAN
639
+ elif sum_t is int64_t:
640
+ out[i, j] = NPY_NAT
641
+ else :
642
+ # placeholder, see above
643
+ out[i, j] = 0
644
+
608
645
else :
609
646
out[i, j] = sumx[i, j]
610
647
0 commit comments