@@ -20,6 +20,18 @@ ctypedef fused rank_t:
20
20
object
21
21
22
22
23
+ cdef inline bint _treat_as_na(rank_t val, bint is_datetimelike) nogil:
24
+ if rank_t is object:
25
+ # Should never be used, but we need to avoid the `val != val` below
26
+ # or else cython will raise about gil acquisition.
27
+ raise NotImplementedError
28
+
29
+ elif rank_t is int64_t:
30
+ return is_datetimelike and val == NPY_NAT
31
+ else:
32
+ return val != val
33
+
34
+
23
35
@cython.wraparound(False)
24
36
@cython.boundscheck(False)
25
37
def group_last(rank_t[:, :] out,
@@ -61,24 +73,16 @@ def group_last(rank_t[:, :] out,
61
73
for j in range(K):
62
74
val = values[i, j]
63
75
64
- # not nan
65
- if rank_t is int64_t:
66
- # need a special notna check
67
- if val != NPY_NAT:
68
- nobs[lab, j] += 1
69
- resx[lab, j] = val
70
- else:
71
- if val == val:
72
- nobs[lab, j] += 1
73
- resx[lab, j] = val
76
+ if val == val:
77
+ # NB: use _treat_as_na here once
78
+ # conditional-nogil is available.
79
+ nobs[lab, j] += 1
80
+ resx[lab, j] = val
74
81
75
82
for i in range(ncounts):
76
83
for j in range(K):
77
84
if nobs[i, j] == 0:
78
- if rank_t is int64_t:
79
- out[i, j] = NPY_NAT
80
- else:
81
- out[i, j] = NAN
85
+ out[i, j] = NAN
82
86
else:
83
87
out[i, j] = resx[i, j]
84
88
else:
@@ -92,16 +96,10 @@ def group_last(rank_t[:, :] out,
92
96
for j in range(K):
93
97
val = values[i, j]
94
98
95
- # not nan
96
- if rank_t is int64_t:
97
- # need a special notna check
98
- if val != NPY_NAT:
99
- nobs[lab, j] += 1
100
- resx[lab, j] = val
101
- else:
102
- if val == val:
103
- nobs[lab, j] += 1
104
- resx[lab, j] = val
99
+ if not _treat_as_na(val, True):
100
+ # TODO: Sure we always want is_datetimelike=True?
101
+ nobs[lab, j] += 1
102
+ resx[lab, j] = val
105
103
106
104
for i in range(ncounts):
107
105
for j in range(K):
@@ -113,6 +111,7 @@ def group_last(rank_t[:, :] out,
113
111
break
114
112
else:
115
113
out[i, j] = NAN
114
+
116
115
else:
117
116
out[i, j] = resx[i, j]
118
117
@@ -121,7 +120,6 @@ def group_last(rank_t[:, :] out,
121
120
# block.
122
121
raise RuntimeError("empty group with uint64_t")
123
122
124
-
125
123
group_last_float64 = group_last["float64_t"]
126
124
group_last_float32 = group_last["float32_t"]
127
125
group_last_int64 = group_last["int64_t"]
@@ -169,8 +167,9 @@ def group_nth(rank_t[:, :] out,
169
167
for j in range(K):
170
168
val = values[i, j]
171
169
172
- # not nan
173
170
if val == val:
171
+ # NB: use _treat_as_na here once
172
+ # conditional-nogil is available.
174
173
nobs[lab, j] += 1
175
174
if nobs[lab, j] == rank:
176
175
resx[lab, j] = val
@@ -193,18 +192,11 @@ def group_nth(rank_t[:, :] out,
193
192
for j in range(K):
194
193
val = values[i, j]
195
194
196
- # not nan
197
- if rank_t is int64_t:
198
- # need a special notna check
199
- if val != NPY_NAT:
200
- nobs[lab, j] += 1
201
- if nobs[lab, j] == rank:
202
- resx[lab, j] = val
203
- else:
204
- if val == val:
205
- nobs[lab, j] += 1
206
- if nobs[lab, j] == rank:
207
- resx[lab, j] = val
195
+ if not _treat_as_na(val, True):
196
+ # TODO: Sure we always want is_datetimelike=True?
197
+ nobs[lab, j] += 1
198
+ if nobs[lab, j] == rank:
199
+ resx[lab, j] = val
208
200
209
201
for i in range(ncounts):
210
202
for j in range(K):
@@ -487,17 +479,11 @@ def group_max(groupby_t[:, :] out,
487
479
for j in range(K):
488
480
val = values[i, j]
489
481
490
- # not nan
491
- if groupby_t is int64_t:
492
- if val != nan_val:
493
- nobs[lab, j] += 1
494
- if val > maxx[lab, j]:
495
- maxx[lab, j] = val
496
- else:
497
- if val == val:
498
- nobs[lab, j] += 1
499
- if val > maxx[lab, j]:
500
- maxx[lab, j] = val
482
+ if not _treat_as_na(val, True):
483
+ # TODO: Sure we always want is_datetimelike=True?
484
+ nobs[lab, j] += 1
485
+ if val > maxx[lab, j]:
486
+ maxx[lab, j] = val
501
487
502
488
for i in range(ncounts):
503
489
for j in range(K):
@@ -563,17 +549,11 @@ def group_min(groupby_t[:, :] out,
563
549
for j in range(K):
564
550
val = values[i, j]
565
551
566
- # not nan
567
- if groupby_t is int64_t:
568
- if val != nan_val:
569
- nobs[lab, j] += 1
570
- if val < minx[lab, j]:
571
- minx[lab, j] = val
572
- else:
573
- if val == val:
574
- nobs[lab, j] += 1
575
- if val < minx[lab, j]:
576
- minx[lab, j] = val
552
+ if not _treat_as_na(val, True):
553
+ # TODO: Sure we always want is_datetimelike=True?
554
+ nobs[lab, j] += 1
555
+ if val < minx[lab, j]:
556
+ minx[lab, j] = val
577
557
578
558
for i in range(ncounts):
579
559
for j in range(K):
@@ -643,21 +623,13 @@ def group_cummin(groupby_t[:, :] out,
643
623
for j in range(K):
644
624
val = values[i, j]
645
625
646
- # val = nan
647
- if groupby_t is int64_t:
648
- if is_datetimelike and val == NPY_NAT:
649
- out[i, j] = NPY_NAT
650
- else:
651
- mval = accum[lab, j]
652
- if val < mval:
653
- accum[lab, j] = mval = val
654
- out[i, j] = mval
626
+ if _treat_as_na(val, is_datetimelike):
627
+ out[i, j] = val
655
628
else:
656
- if val == val:
657
- mval = accum[lab, j]
658
- if val < mval:
659
- accum[lab, j] = mval = val
660
- out[i, j] = mval
629
+ mval = accum[lab, j]
630
+ if val < mval:
631
+ accum[lab, j] = mval = val
632
+ out[i, j] = mval
661
633
662
634
663
635
@cython.boundscheck(False)
@@ -712,17 +684,10 @@ def group_cummax(groupby_t[:, :] out,
712
684
for j in range(K):
713
685
val = values[i, j]
714
686
715
- if groupby_t is int64_t:
716
- if is_datetimelike and val == NPY_NAT:
717
- out[i, j] = NPY_NAT
718
- else:
719
- mval = accum[lab, j]
720
- if val > mval:
721
- accum[lab, j] = mval = val
722
- out[i, j] = mval
687
+ if _treat_as_na(val, is_datetimelike):
688
+ out[i, j] = val
723
689
else:
724
- if val == val:
725
- mval = accum[lab, j]
726
- if val > mval:
727
- accum[lab, j] = mval = val
728
- out[i, j] = mval
690
+ mval = accum[lab, j]
691
+ if val > mval:
692
+ accum[lab, j] = mval = val
693
+ out[i, j] = mval
0 commit comments