Skip to content

Commit 76c21ba

Browse files
jbrockmendelyehoshuadimarsky
authored andcommitted
PERF: avoid casting in groupby.min/max (pandas-dev#46745)
1 parent 68a66a2 commit 76c21ba

File tree

3 files changed

+69
-83
lines changed

3 files changed

+69
-83
lines changed

pandas/_libs/dtypes.pxd

-12
Original file line numberDiff line numberDiff line change
@@ -34,15 +34,3 @@ ctypedef fused numeric_t:
3434
ctypedef fused numeric_object_t:
3535
numeric_t
3636
object
37-
38-
# i64 + u64 + all float types
39-
ctypedef fused iu_64_floating_t:
40-
float64_t
41-
float32_t
42-
int64_t
43-
uint64_t
44-
45-
# i64 + u64 + all float types + object
46-
ctypedef fused iu_64_floating_obj_t:
47-
iu_64_floating_t
48-
object

pandas/_libs/groupby.pyx

+64-62
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,6 @@ from pandas._libs.algos import (
4444
)
4545

4646
from pandas._libs.dtypes cimport (
47-
iu_64_floating_obj_t,
48-
iu_64_floating_t,
4947
numeric_object_t,
5048
numeric_t,
5149
)
@@ -1019,13 +1017,13 @@ cdef numeric_t _get_na_val(numeric_t val, bint is_datetimelike):
10191017

10201018

10211019
# TODO(cython3): GH#31710 use memorviews once cython 0.30 is released so we can
1022-
# use `const iu_64_floating_obj_t[:, :] values`
1020+
# use `const numeric_object_t[:, :] values`
10231021
@cython.wraparound(False)
10241022
@cython.boundscheck(False)
10251023
def group_last(
1026-
iu_64_floating_obj_t[:, ::1] out,
1024+
numeric_object_t[:, ::1] out,
10271025
int64_t[::1] counts,
1028-
ndarray[iu_64_floating_obj_t, ndim=2] values,
1026+
ndarray[numeric_object_t, ndim=2] values,
10291027
const intp_t[::1] labels,
10301028
const uint8_t[:, :] mask,
10311029
uint8_t[:, ::1] result_mask=None,
@@ -1037,8 +1035,8 @@ def group_last(
10371035
"""
10381036
cdef:
10391037
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
1040-
iu_64_floating_obj_t val
1041-
ndarray[iu_64_floating_obj_t, ndim=2] resx
1038+
numeric_object_t val
1039+
ndarray[numeric_object_t, ndim=2] resx
10421040
ndarray[int64_t, ndim=2] nobs
10431041
bint uses_mask = mask is not None
10441042
bint isna_entry
@@ -1050,14 +1048,14 @@ def group_last(
10501048

10511049
min_count = max(min_count, 1)
10521050
nobs = np.zeros((<object>out).shape, dtype=np.int64)
1053-
if iu_64_floating_obj_t is object:
1051+
if numeric_object_t is object:
10541052
resx = np.empty((<object>out).shape, dtype=object)
10551053
else:
10561054
resx = np.empty_like(out)
10571055

10581056
N, K = (<object>values).shape
10591057

1060-
if iu_64_floating_obj_t is object:
1058+
if numeric_object_t is object:
10611059
# TODO(cython3): De-duplicate once conditional-nogil is available
10621060
for i in range(N):
10631061
lab = labels[i]
@@ -1118,28 +1116,27 @@ def group_last(
11181116
# set a placeholder value in out[i, j].
11191117
if uses_mask:
11201118
result_mask[i, j] = True
1121-
elif iu_64_floating_obj_t is int64_t:
1119+
elif numeric_object_t is float32_t or numeric_object_t is float64_t:
1120+
out[i, j] = NAN
1121+
elif numeric_object_t is int64_t:
11221122
# Per above, this is a placeholder in
11231123
# non-is_datetimelike cases.
11241124
out[i, j] = NPY_NAT
1125-
elif iu_64_floating_obj_t is uint64_t:
1125+
else:
11261126
# placeholder, see above
11271127
out[i, j] = 0
1128-
else:
1129-
out[i, j] = NAN
1130-
11311128
else:
11321129
out[i, j] = resx[i, j]
11331130

11341131

11351132
# TODO(cython3): GH#31710 use memorviews once cython 0.30 is released so we can
1136-
# use `const iu_64_floating_obj_t[:, :] values`
1133+
# use `const numeric_object_t[:, :] values`
11371134
@cython.wraparound(False)
11381135
@cython.boundscheck(False)
11391136
def group_nth(
1140-
iu_64_floating_obj_t[:, ::1] out,
1137+
numeric_object_t[:, ::1] out,
11411138
int64_t[::1] counts,
1142-
ndarray[iu_64_floating_obj_t, ndim=2] values,
1139+
ndarray[numeric_object_t, ndim=2] values,
11431140
const intp_t[::1] labels,
11441141
const uint8_t[:, :] mask,
11451142
uint8_t[:, ::1] result_mask=None,
@@ -1152,8 +1149,8 @@ def group_nth(
11521149
"""
11531150
cdef:
11541151
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
1155-
iu_64_floating_obj_t val
1156-
ndarray[iu_64_floating_obj_t, ndim=2] resx
1152+
numeric_object_t val
1153+
ndarray[numeric_object_t, ndim=2] resx
11571154
ndarray[int64_t, ndim=2] nobs
11581155
bint uses_mask = mask is not None
11591156
bint isna_entry
@@ -1165,14 +1162,14 @@ def group_nth(
11651162

11661163
min_count = max(min_count, 1)
11671164
nobs = np.zeros((<object>out).shape, dtype=np.int64)
1168-
if iu_64_floating_obj_t is object:
1165+
if numeric_object_t is object:
11691166
resx = np.empty((<object>out).shape, dtype=object)
11701167
else:
11711168
resx = np.empty_like(out)
11721169

11731170
N, K = (<object>values).shape
11741171

1175-
if iu_64_floating_obj_t is object:
1172+
if numeric_object_t is object:
11761173
# TODO(cython3): De-duplicate once conditional-nogil is available
11771174
for i in range(N):
11781175
lab = labels[i]
@@ -1223,6 +1220,7 @@ def group_nth(
12231220
if nobs[lab, j] == rank:
12241221
resx[lab, j] = val
12251222

1223+
# TODO: de-dup this whoel block with group_last?
12261224
for i in range(ncounts):
12271225
for j in range(K):
12281226
if nobs[i, j] < min_count:
@@ -1235,15 +1233,16 @@ def group_nth(
12351233
if uses_mask:
12361234
result_mask[i, j] = True
12371235
out[i, j] = 0
1238-
elif iu_64_floating_obj_t is int64_t:
1236+
elif numeric_object_t is float32_t or numeric_object_t is float64_t:
1237+
out[i, j] = NAN
1238+
elif numeric_object_t is int64_t:
12391239
# Per above, this is a placeholder in
12401240
# non-is_datetimelike cases.
12411241
out[i, j] = NPY_NAT
1242-
elif iu_64_floating_obj_t is uint64_t:
1242+
else:
12431243
# placeholder, see above
12441244
out[i, j] = 0
1245-
else:
1246-
out[i, j] = NAN
1245+
12471246
else:
12481247
out[i, j] = resx[i, j]
12491248

@@ -1252,7 +1251,7 @@ def group_nth(
12521251
@cython.wraparound(False)
12531252
def group_rank(
12541253
float64_t[:, ::1] out,
1255-
ndarray[iu_64_floating_obj_t, ndim=2] values,
1254+
ndarray[numeric_object_t, ndim=2] values,
12561255
const intp_t[::1] labels,
12571256
int ngroups,
12581257
bint is_datetimelike,
@@ -1268,7 +1267,7 @@ def group_rank(
12681267
----------
12691268
out : np.ndarray[np.float64, ndim=2]
12701269
Values to which this method will write its results.
1271-
values : np.ndarray of iu_64_floating_obj_t values to be ranked
1270+
values : np.ndarray of numeric_object_t values to be ranked
12721271
labels : np.ndarray[np.intp]
12731272
Array containing unique label for each group, with its ordering
12741273
matching up to the corresponding record in `values`
@@ -1322,14 +1321,13 @@ def group_rank(
13221321
# group_min, group_max
13231322
# ----------------------------------------------------------------------
13241323

1325-
# TODO: consider implementing for more dtypes
13261324

13271325
@cython.wraparound(False)
13281326
@cython.boundscheck(False)
13291327
cdef group_min_max(
1330-
iu_64_floating_t[:, ::1] out,
1328+
numeric_t[:, ::1] out,
13311329
int64_t[::1] counts,
1332-
ndarray[iu_64_floating_t, ndim=2] values,
1330+
ndarray[numeric_t, ndim=2] values,
13331331
const intp_t[::1] labels,
13341332
Py_ssize_t min_count=-1,
13351333
bint is_datetimelike=False,
@@ -1342,7 +1340,7 @@ cdef group_min_max(
13421340
13431341
Parameters
13441342
----------
1345-
out : np.ndarray[iu_64_floating_t, ndim=2]
1343+
out : np.ndarray[numeric_t, ndim=2]
13461344
Array to store result in.
13471345
counts : np.ndarray[int64]
13481346
Input as a zeroed array, populated by group sizes during algorithm
@@ -1371,8 +1369,8 @@ cdef group_min_max(
13711369
"""
13721370
cdef:
13731371
Py_ssize_t i, j, N, K, lab, ngroups = len(counts)
1374-
iu_64_floating_t val, nan_val
1375-
ndarray[iu_64_floating_t, ndim=2] group_min_or_max
1372+
numeric_t val, nan_val
1373+
ndarray[numeric_t, ndim=2] group_min_or_max
13761374
int64_t[:, ::1] nobs
13771375
bint uses_mask = mask is not None
13781376
bint isna_entry
@@ -1386,16 +1384,20 @@ cdef group_min_max(
13861384
nobs = np.zeros((<object>out).shape, dtype=np.int64)
13871385

13881386
group_min_or_max = np.empty_like(out)
1389-
group_min_or_max[:] = _get_min_or_max(<iu_64_floating_t>0, compute_max, is_datetimelike)
1387+
group_min_or_max[:] = _get_min_or_max(<numeric_t>0, compute_max, is_datetimelike)
13901388

1391-
if iu_64_floating_t is int64_t:
1389+
# NB: We do not define nan_val because there is no such thing
1390+
# for uint64_t. We carefully avoid having to reference it in this
1391+
# case.
1392+
if numeric_t is int64_t:
13921393
nan_val = NPY_NAT
1393-
elif iu_64_floating_t is uint64_t:
1394-
# NB: We do not define nan_val because there is no such thing
1395-
# for uint64_t. We carefully avoid having to reference it in this
1396-
# case.
1397-
pass
1398-
else:
1394+
elif numeric_t is int32_t:
1395+
nan_val = util.INT32_MIN
1396+
elif numeric_t is int16_t:
1397+
nan_val = util.INT16_MIN
1398+
elif numeric_t is int8_t:
1399+
nan_val = util.INT8_MIN
1400+
elif numeric_t is float64_t or numeric_t is float32_t:
13991401
nan_val = NAN
14001402

14011403
N, K = (<object>values).shape
@@ -1439,25 +1441,25 @@ cdef group_min_max(
14391441
# it was initialized with np.empty. Also ensures
14401442
# we can downcast out if appropriate.
14411443
out[i, j] = 0
1442-
elif iu_64_floating_t is int64_t:
1444+
elif numeric_t is float32_t or numeric_t is float64_t:
1445+
out[i, j] = nan_val
1446+
elif numeric_t is int64_t:
14431447
# Per above, this is a placeholder in
14441448
# non-is_datetimelike cases.
14451449
out[i, j] = nan_val
1446-
elif iu_64_floating_t is uint64_t:
1450+
else:
14471451
# placeholder, see above
14481452
out[i, j] = 0
1449-
else:
1450-
out[i, j] = nan_val
14511453
else:
14521454
out[i, j] = group_min_or_max[i, j]
14531455

14541456

14551457
@cython.wraparound(False)
14561458
@cython.boundscheck(False)
14571459
def group_max(
1458-
iu_64_floating_t[:, ::1] out,
1460+
numeric_t[:, ::1] out,
14591461
int64_t[::1] counts,
1460-
ndarray[iu_64_floating_t, ndim=2] values,
1462+
ndarray[numeric_t, ndim=2] values,
14611463
const intp_t[::1] labels,
14621464
Py_ssize_t min_count=-1,
14631465
bint is_datetimelike=False,
@@ -1481,9 +1483,9 @@ def group_max(
14811483
@cython.wraparound(False)
14821484
@cython.boundscheck(False)
14831485
def group_min(
1484-
iu_64_floating_t[:, ::1] out,
1486+
numeric_t[:, ::1] out,
14851487
int64_t[::1] counts,
1486-
ndarray[iu_64_floating_t, ndim=2] values,
1488+
ndarray[numeric_t, ndim=2] values,
14871489
const intp_t[::1] labels,
14881490
Py_ssize_t min_count=-1,
14891491
bint is_datetimelike=False,
@@ -1507,8 +1509,8 @@ def group_min(
15071509
@cython.boundscheck(False)
15081510
@cython.wraparound(False)
15091511
cdef group_cummin_max(
1510-
iu_64_floating_t[:, ::1] out,
1511-
ndarray[iu_64_floating_t, ndim=2] values,
1512+
numeric_t[:, ::1] out,
1513+
ndarray[numeric_t, ndim=2] values,
15121514
const uint8_t[:, ::1] mask,
15131515
uint8_t[:, ::1] result_mask,
15141516
const intp_t[::1] labels,
@@ -1522,9 +1524,9 @@ cdef group_cummin_max(
15221524
15231525
Parameters
15241526
----------
1525-
out : np.ndarray[iu_64_floating_t, ndim=2]
1527+
out : np.ndarray[numeric_t, ndim=2]
15261528
Array to store cummin/max in.
1527-
values : np.ndarray[iu_64_floating_t, ndim=2]
1529+
values : np.ndarray[numeric_t, ndim=2]
15281530
Values to take cummin/max of.
15291531
mask : np.ndarray[bool] or None
15301532
If not None, indices represent missing values,
@@ -1549,25 +1551,25 @@ cdef group_cummin_max(
15491551
This method modifies the `out` parameter, rather than returning an object.
15501552
"""
15511553
cdef:
1552-
iu_64_floating_t[:, ::1] accum
1554+
numeric_t[:, ::1] accum
15531555
Py_ssize_t i, j, N, K
1554-
iu_64_floating_t val, mval, na_val
1556+
numeric_t val, mval, na_val
15551557
uint8_t[:, ::1] seen_na
15561558
intp_t lab
15571559
bint na_possible
15581560
bint uses_mask = mask is not None
15591561
bint isna_entry
15601562

15611563
accum = np.empty((ngroups, (<object>values).shape[1]), dtype=values.dtype)
1562-
accum[:] = _get_min_or_max(<iu_64_floating_t>0, compute_max, is_datetimelike)
1564+
accum[:] = _get_min_or_max(<numeric_t>0, compute_max, is_datetimelike)
15631565

1564-
na_val = _get_na_val(<iu_64_floating_t>0, is_datetimelike)
1566+
na_val = _get_na_val(<numeric_t>0, is_datetimelike)
15651567

15661568
if uses_mask:
15671569
na_possible = True
15681570
# Will never be used, just to avoid uninitialized warning
15691571
na_val = 0
1570-
elif iu_64_floating_t is float64_t or iu_64_floating_t is float32_t:
1572+
elif numeric_t is float64_t or numeric_t is float32_t:
15711573
na_possible = True
15721574
elif is_datetimelike:
15731575
na_possible = True
@@ -1620,8 +1622,8 @@ cdef group_cummin_max(
16201622
@cython.boundscheck(False)
16211623
@cython.wraparound(False)
16221624
def group_cummin(
1623-
iu_64_floating_t[:, ::1] out,
1624-
ndarray[iu_64_floating_t, ndim=2] values,
1625+
numeric_t[:, ::1] out,
1626+
ndarray[numeric_t, ndim=2] values,
16251627
const intp_t[::1] labels,
16261628
int ngroups,
16271629
bint is_datetimelike,
@@ -1646,8 +1648,8 @@ def group_cummin(
16461648
@cython.boundscheck(False)
16471649
@cython.wraparound(False)
16481650
def group_cummax(
1649-
iu_64_floating_t[:, ::1] out,
1650-
ndarray[iu_64_floating_t, ndim=2] values,
1651+
numeric_t[:, ::1] out,
1652+
ndarray[numeric_t, ndim=2] values,
16511653
const intp_t[::1] labels,
16521654
int ngroups,
16531655
bint is_datetimelike,

pandas/core/groupby/ops.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -507,15 +507,9 @@ def _call_cython_op(
507507
values = values.view("int64")
508508
is_numeric = True
509509
elif is_bool_dtype(dtype):
510-
values = values.astype("int64")
511-
elif is_integer_dtype(dtype):
512-
# GH#43329 If the dtype is explicitly of type uint64 the type is not
513-
# changed to prevent overflow.
514-
if dtype != np.uint64:
515-
values = values.astype(np.int64, copy=False)
516-
elif is_numeric:
517-
if not is_complex_dtype(dtype):
518-
values = ensure_float64(values)
510+
values = values.view("uint8")
511+
if values.dtype == "float16":
512+
values = values.astype(np.float32)
519513

520514
values = values.T
521515
if mask is not None:
@@ -601,6 +595,8 @@ def _call_cython_op(
601595
if self.how not in self.cast_blocklist:
602596
# e.g. if we are int64 and need to restore to datetime64/timedelta64
603597
# "rank" is the only member of cast_blocklist we get here
598+
# Casting only needed for float16, bool, datetimelike,
599+
# and self.how in ["add", "prod", "ohlc", "cumprod"]
604600
res_dtype = self._get_result_dtype(orig_values.dtype)
605601
op_result = maybe_downcast_to_dtype(result, res_dtype)
606602
else:

0 commit comments

Comments
 (0)