Skip to content

Commit a0d01b8

Browse files
jbrockmendel authored and jreback committed
add uint64 support for some libgroupby funcs (#28931)
1 parent 79a5f7c commit a0d01b8

File tree

3 files changed

+69
-3
lines changed

3 files changed

+69
-3
lines changed

pandas/_libs/groupby_helper.pxi.in

+60 -2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ ctypedef fused rank_t:
1616
float64_t
1717
float32_t
1818
int64_t
19+
uint64_t
1920
object
2021

2122

@@ -34,6 +35,7 @@ def group_last(rank_t[:, :] out,
3435
rank_t val
3536
ndarray[rank_t, ndim=2] resx
3637
ndarray[int64_t, ndim=2] nobs
38+
bint runtime_error = False
3739

3840
assert min_count == -1, "'min_count' only used in add and prod"
3941

@@ -106,11 +108,20 @@ def group_last(rank_t[:, :] out,
106108
if nobs[i, j] == 0:
107109
if rank_t is int64_t:
108110
out[i, j] = NPY_NAT
111+
elif rank_t is uint64_t:
112+
runtime_error = True
113+
break
109114
else:
110115
out[i, j] = NAN
111116
else:
112117
out[i, j] = resx[i, j]
113118

119+
if runtime_error:
120+
# We cannot raise directly above because that is within a nogil
121+
# block.
122+
raise RuntimeError("empty group with uint64_t")
123+
124+
114125
group_last_float64 = group_last["float64_t"]
115126
group_last_float32 = group_last["float32_t"]
116127
group_last_int64 = group_last["int64_t"]
@@ -132,6 +143,7 @@ def group_nth(rank_t[:, :] out,
132143
rank_t val
133144
ndarray[rank_t, ndim=2] resx
134145
ndarray[int64_t, ndim=2] nobs
146+
bint runtime_error = False
135147

136148
assert min_count == -1, "'min_count' only used in add and prod"
137149

@@ -199,11 +211,19 @@ def group_nth(rank_t[:, :] out,
199211
if nobs[i, j] == 0:
200212
if rank_t is int64_t:
201213
out[i, j] = NPY_NAT
214+
elif rank_t is uint64_t:
215+
runtime_error = True
216+
break
202217
else:
203218
out[i, j] = NAN
204219
else:
205220
out[i, j] = resx[i, j]
206221

222+
if runtime_error:
223+
# We cannot raise directly above because that is within a nogil
224+
# block.
225+
raise RuntimeError("empty group with uint64_t")
226+
207227

208228
group_nth_float64 = group_nth["float64_t"]
209229
group_nth_float32 = group_nth["float32_t"]
@@ -282,12 +302,16 @@ def group_rank(float64_t[:, :] out,
282302
if ascending ^ (na_option == 'top'):
283303
if rank_t is int64_t:
284304
nan_fill_val = np.iinfo(np.int64).max
305+
elif rank_t is uint64_t:
306+
nan_fill_val = np.iinfo(np.uint64).max
285307
else:
286308
nan_fill_val = np.inf
287309
order = (masked_vals, mask, labels)
288310
else:
289311
if rank_t is int64_t:
290312
nan_fill_val = np.iinfo(np.int64).min
313+
elif rank_t is uint64_t:
314+
nan_fill_val = 0
291315
else:
292316
nan_fill_val = -np.inf
293317

@@ -397,6 +421,7 @@ def group_rank(float64_t[:, :] out,
397421
group_rank_float64 = group_rank["float64_t"]
398422
group_rank_float32 = group_rank["float32_t"]
399423
group_rank_int64 = group_rank["int64_t"]
424+
group_rank_uint64 = group_rank["uint64_t"]
400425
# Note: we do not have a group_rank_object because that would require a
401426
# not-nogil implementation, see GH#19560
402427

@@ -410,6 +435,7 @@ ctypedef fused groupby_t:
410435
float64_t
411436
float32_t
412437
int64_t
438+
uint64_t
413439

414440

415441
@cython.wraparound(False)
@@ -426,6 +452,7 @@ def group_max(groupby_t[:, :] out,
426452
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
427453
groupby_t val, count, nan_val
428454
ndarray[groupby_t, ndim=2] maxx, nobs
455+
bint runtime_error = False
429456

430457
assert min_count == -1, "'min_count' only used in add and prod"
431458

@@ -439,6 +466,11 @@ def group_max(groupby_t[:, :] out,
439466
# Note: evaluated at compile-time
440467
maxx[:] = -_int64_max
441468
nan_val = NPY_NAT
469+
elif groupby_t is uint64_t:
470+
# NB: We do not define nan_val because there is no such thing
471+
# for uint64_t. We carefully avoid having to reference it in this
472+
# case.
473+
maxx[:] = 0
442474
else:
443475
maxx[:] = -np.inf
444476
nan_val = NAN
@@ -462,18 +494,26 @@ def group_max(groupby_t[:, :] out,
462494
if val > maxx[lab, j]:
463495
maxx[lab, j] = val
464496
else:
465-
if val == val and val != nan_val:
497+
if val == val:
466498
nobs[lab, j] += 1
467499
if val > maxx[lab, j]:
468500
maxx[lab, j] = val
469501

470502
for i in range(ncounts):
471503
for j in range(K):
472504
if nobs[i, j] == 0:
505+
if groupby_t is uint64_t:
506+
runtime_error = True
507+
break
473508
out[i, j] = nan_val
474509
else:
475510
out[i, j] = maxx[i, j]
476511

512+
if runtime_error:
513+
# We cannot raise directly above because that is within a nogil
514+
# block.
515+
raise RuntimeError("empty group with uint64_t")
516+
477517

478518
@cython.wraparound(False)
479519
@cython.boundscheck(False)
@@ -489,6 +529,7 @@ def group_min(groupby_t[:, :] out,
489529
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
490530
groupby_t val, count, nan_val
491531
ndarray[groupby_t, ndim=2] minx, nobs
532+
bint runtime_error = False
492533

493534
assert min_count == -1, "'min_count' only used in add and prod"
494535

@@ -501,6 +542,11 @@ def group_min(groupby_t[:, :] out,
501542
if groupby_t is int64_t:
502543
minx[:] = _int64_max
503544
nan_val = NPY_NAT
545+
elif groupby_t is uint64_t:
546+
# NB: We do not define nan_val because there is no such thing
547+
# for uint64_t. We carefully avoid having to reference it in this
548+
# case.
549+
minx[:] = np.iinfo(np.uint64).max
504550
else:
505551
minx[:] = np.inf
506552
nan_val = NAN
@@ -524,18 +570,26 @@ def group_min(groupby_t[:, :] out,
524570
if val < minx[lab, j]:
525571
minx[lab, j] = val
526572
else:
527-
if val == val and val != nan_val:
573+
if val == val:
528574
nobs[lab, j] += 1
529575
if val < minx[lab, j]:
530576
minx[lab, j] = val
531577

532578
for i in range(ncounts):
533579
for j in range(K):
534580
if nobs[i, j] == 0:
581+
if groupby_t is uint64_t:
582+
runtime_error = True
583+
break
535584
out[i, j] = nan_val
536585
else:
537586
out[i, j] = minx[i, j]
538587

588+
if runtime_error:
589+
# We cannot raise directly above because that is within a nogil
590+
# block.
591+
raise RuntimeError("empty group with uint64_t")
592+
539593

540594
@cython.boundscheck(False)
541595
@cython.wraparound(False)
@@ -575,6 +629,8 @@ def group_cummin(groupby_t[:, :] out,
575629
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
576630
if groupby_t is int64_t:
577631
accum[:] = _int64_max
632+
elif groupby_t is uint64_t:
633+
accum[:] = np.iinfo(np.uint64).max
578634
else:
579635
accum[:] = np.inf
580636

@@ -642,6 +698,8 @@ def group_cummax(groupby_t[:, :] out,
642698
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
643699
if groupby_t is int64_t:
644700
accum[:] = -_int64_max
701+
elif groupby_t is uint64_t:
702+
accum[:] = 0
645703
else:
646704
accum[:] = -np.inf
647705

pandas/core/groupby/groupby.py

+8
Original file line numberDiff line numberDiff line change
@@ -1355,7 +1355,15 @@ def f(self, **kwargs):
13551355
return self._cython_agg_general(alias, alt=npfunc, **kwargs)
13561356
except AssertionError as e:
13571357
raise SpecificationError(str(e))
1358+
except DataError:
1359+
pass
13581360
except Exception:
1361+
# TODO: the remaining test cases that get here are from:
1362+
# - AttributeError from _cython_agg_blocks bug passing
1363+
# DataFrame to make_block; see GH#28275
1364+
# - TypeError in _cython_operation calling ensure_float64
1365+
# on object array containing complex numbers;
1366+
# see test_groupby_complex, test_max_nan_bug
13591367
pass
13601368

13611369
# apply a non-cython aggregation

pandas/tests/groupby/test_function.py

+1 -1
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,7 @@ def test_median_empty_bins(observed):
378378

379379

380380
@pytest.mark.parametrize(
381-
"dtype", ["int8", "int16", "int32", "int64", "float32", "float64"]
381+
"dtype", ["int8", "int16", "int32", "int64", "float32", "float64", "uint64"]
382382
)
383383
@pytest.mark.parametrize(
384384
"method,data",

0 commit comments

Comments
 (0)