Skip to content

Commit bff90a3

Browse files
jbrockmendel authored and WillAyd committed
REF: de-duplicate groupby_helper code (pandas-dev#28934)
1 parent b63f829 commit bff90a3

File tree

2 files changed

+54
-88
lines changed

2 files changed

+54
-88
lines changed

pandas/_libs/groupby.pyx

+2 -1
Original file line number | Diff line number | Diff line change
@@ -372,7 +372,8 @@ def group_any_all(uint8_t[:] out,
372372
const uint8_t[:] mask,
373373
object val_test,
374374
bint skipna):
375-
"""Aggregated boolean values to show truthfulness of group elements
375+
"""
376+
Aggregated boolean values to show truthfulness of group elements.
376377
377378
Parameters
378379
----------

pandas/_libs/groupby_helper.pxi.in

+52 -87
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,18 @@ ctypedef fused rank_t:
2020
object
2121

2222

23+
cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
24+
if rank_t is object:
25+
# Should never be used, but we need to avoid the `val != val` below
26+
# or else cython will raise about gil acquisition.
27+
raise NotImplementedError
28+
29+
elif rank_t is int64_t:
30+
return is_datetimelike and val == NPY_NAT
31+
else:
32+
return val != val
33+
34+
2335
@cython.wraparound(False)
2436
@cython.boundscheck(False)
2537
def group_last(rank_t[:, :] out,
@@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
6173
for j in range(K):
6274
val = values[i, j]
6375

64-
# not nan
65-
if rank_t is int64_t:
66-
# need a special notna check
67-
if val != NPY_NAT:
68-
nobs[lab, j] += 1
69-
resx[lab, j] = val
70-
else:
71-
if val == val:
72-
nobs[lab, j] += 1
73-
resx[lab, j] = val
76+
if val == val:
77+
# NB: use _treat_as_na here once
78+
# conditional-nogil is available.
79+
nobs[lab, j] += 1
80+
resx[lab, j] = val
7481

7582
for i in range(ncounts):
7683
for j in range(K):
7784
if nobs[i, j] == 0:
78-
if rank_t is int64_t:
79-
out[i, j] = NPY_NAT
80-
else:
81-
out[i, j] = NAN
85+
out[i, j] = NAN
8286
else:
8387
out[i, j] = resx[i, j]
8488
else:
@@ -92,16 +96,10 @@ def group_last(rank_t[:, :] out,
9296
for j in range(K):
9397
val = values[i, j]
9498

95-
# not nan
96-
if rank_t is int64_t:
97-
# need a special notna check
98-
if val != NPY_NAT:
99-
nobs[lab, j] += 1
100-
resx[lab, j] = val
101-
else:
102-
if val == val:
103-
nobs[lab, j] += 1
104-
resx[lab, j] = val
99+
if not _treat_as_na(val, True):
100+
# TODO: Sure we always want is_datetimelike=True?
101+
nobs[lab, j] += 1
102+
resx[lab, j] = val
105103

106104
for i in range(ncounts):
107105
for j in range(K):
@@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
113111
break
114112
else:
115113
out[i, j] = NAN
114+
116115
else:
117116
out[i, j] = resx[i, j]
118117

@@ -121,7 +120,6 @@ def group_last(rank_t[:, :] out,
121120
# block.
122121
raise RuntimeError("empty group with uint64_t")
123122

124-
125123
group_last_float64 = group_last["float64_t"]
126124
group_last_float32 = group_last["float32_t"]
127125
group_last_int64 = group_last["int64_t"]
@@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
169167
for j in range(K):
170168
val = values[i, j]
171169

172-
# not nan
173170
if val == val:
171+
# NB: use _treat_as_na here once
172+
# conditional-nogil is available.
174173
nobs[lab, j] += 1
175174
if nobs[lab, j] == rank:
176175
resx[lab, j] = val
@@ -193,18 +192,11 @@ def group_nth(rank_t[:, :] out,
193192
for j in range(K):
194193
val = values[i, j]
195194

196-
# not nan
197-
if rank_t is int64_t:
198-
# need a special notna check
199-
if val != NPY_NAT:
200-
nobs[lab, j] += 1
201-
if nobs[lab, j] == rank:
202-
resx[lab, j] = val
203-
else:
204-
if val == val:
205-
nobs[lab, j] += 1
206-
if nobs[lab, j] == rank:
207-
resx[lab, j] = val
195+
if not _treat_as_na(val, True):
196+
# TODO: Sure we always want is_datetimelike=True?
197+
nobs[lab, j] += 1
198+
if nobs[lab, j] == rank:
199+
resx[lab, j] = val
208200

209201
for i in range(ncounts):
210202
for j in range(K):
@@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
487479
for j in range(K):
488480
val = values[i, j]
489481

490-
# not nan
491-
if groupby_t is int64_t:
492-
if val != nan_val:
493-
nobs[lab, j] += 1
494-
if val > maxx[lab, j]:
495-
maxx[lab, j] = val
496-
else:
497-
if val == val:
498-
nobs[lab, j] += 1
499-
if val > maxx[lab, j]:
500-
maxx[lab, j] = val
482+
if not _treat_as_na(val, True):
483+
# TODO: Sure we always want is_datetimelike=True?
484+
nobs[lab, j] += 1
485+
if val > maxx[lab, j]:
486+
maxx[lab, j] = val
501487

502488
for i in range(ncounts):
503489
for j in range(K):
@@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
563549
for j in range(K):
564550
val = values[i, j]
565551

566-
# not nan
567-
if groupby_t is int64_t:
568-
if val != nan_val:
569-
nobs[lab, j] += 1
570-
if val < minx[lab, j]:
571-
minx[lab, j] = val
572-
else:
573-
if val == val:
574-
nobs[lab, j] += 1
575-
if val < minx[lab, j]:
576-
minx[lab, j] = val
552+
if not _treat_as_na(val, True):
553+
# TODO: Sure we always want is_datetimelike=True?
554+
nobs[lab, j] += 1
555+
if val < minx[lab, j]:
556+
minx[lab, j] = val
577557

578558
for i in range(ncounts):
579559
for j in range(K):
@@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
643623
for j in range(K):
644624
val = values[i, j]
645625

646-
# val = nan
647-
if groupby_t is int64_t:
648-
if is_datetimelike and val == NPY_NAT:
649-
out[i, j] = NPY_NAT
650-
else:
651-
mval = accum[lab, j]
652-
if val < mval:
653-
accum[lab, j] = mval = val
654-
out[i, j] = mval
626+
if _treat_as_na(val, is_datetimelike):
627+
out[i, j] = val
655628
else:
656-
if val == val:
657-
mval = accum[lab, j]
658-
if val < mval:
659-
accum[lab, j] = mval = val
660-
out[i, j] = mval
629+
mval = accum[lab, j]
630+
if val < mval:
631+
accum[lab, j] = mval = val
632+
out[i, j] = mval
661633

662634

663635
@cython.boundscheck(False)
@@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
712684
for j in range(K):
713685
val = values[i, j]
714686

715-
if groupby_t is int64_t:
716-
if is_datetimelike and val == NPY_NAT:
717-
out[i, j] = NPY_NAT
718-
else:
719-
mval = accum[lab, j]
720-
if val > mval:
721-
accum[lab, j] = mval = val
722-
out[i, j] = mval
687+
if _treat_as_na(val, is_datetimelike):
688+
out[i, j] = val
723689
else:
724-
if val == val:
725-
mval = accum[lab, j]
726-
if val > mval:
727-
accum[lab, j] = mval = val
728-
out[i, j] = mval
690+
mval = accum[lab, j]
691+
if val > mval:
692+
accum[lab, j] = mval = val
693+
out[i, j] = mval

0 commit comments

Comments
 (0)