Skip to content

Commit 384f414

Browse files
authored
REF: split out grouped masked cummin/max algo (#41853)
1 parent 1f3e646 commit 384f414

File tree

1 file changed

+61
-28
lines changed

1 file changed

+61
-28
lines changed

pandas/_libs/groupby.pyx

+61-28
Original file line numberDiff line numberDiff line change
@@ -1345,53 +1345,86 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
13451345
This method modifies the `out` parameter, rather than returning an object.
13461346
"""
13471347
cdef:
1348-
Py_ssize_t i, j, N, K, size
1349-
groupby_t val, mval
13501348
groupby_t[:, ::1] accum
1351-
intp_t lab
1352-
bint val_is_nan, use_mask
13531349

1354-
use_mask = mask is not None
1355-
1356-
N, K = (<object>values).shape
1357-
accum = np.empty((ngroups, K), dtype=values.dtype)
1350+
accum = np.empty((ngroups, (<object>values).shape[1]), dtype=values.dtype)
13581351
if groupby_t is int64_t:
13591352
accum[:] = -_int64_max if compute_max else _int64_max
13601353
elif groupby_t is uint64_t:
13611354
accum[:] = 0 if compute_max else np.iinfo(np.uint64).max
13621355
else:
13631356
accum[:] = -np.inf if compute_max else np.inf
13641357

1358+
if mask is not None:
1359+
masked_cummin_max(out, values, mask, labels, accum, compute_max)
1360+
else:
1361+
cummin_max(out, values, labels, accum, is_datetimelike, compute_max)
1362+
1363+
1364+
@cython.boundscheck(False)
1365+
@cython.wraparound(False)
1366+
cdef cummin_max(groupby_t[:, ::1] out,
1367+
ndarray[groupby_t, ndim=2] values,
1368+
const intp_t[:] labels,
1369+
groupby_t[:, ::1] accum,
1370+
bint is_datetimelike,
1371+
bint compute_max):
1372+
"""
1373+
Compute the cumulative minimum/maximum of columns of `values`, in row groups
1374+
`labels`.
1375+
"""
1376+
cdef:
1377+
Py_ssize_t i, j, N, K
1378+
groupby_t val, mval
1379+
intp_t lab
1380+
1381+
N, K = (<object>values).shape
13651382
with nogil:
13661383
for i in range(N):
13671384
lab = labels[i]
1368-
13691385
if lab < 0:
13701386
continue
13711387
for j in range(K):
1372-
val_is_nan = False
1373-
1374-
if use_mask:
1375-
if mask[i, j]:
1376-
1377-
# `out` does not need to be set since it
1378-
# will be masked anyway
1379-
val_is_nan = True
1388+
val = values[i, j]
1389+
if not _treat_as_na(val, is_datetimelike):
1390+
mval = accum[lab, j]
1391+
if compute_max:
1392+
if val > mval:
1393+
accum[lab, j] = mval = val
13801394
else:
1395+
if val < mval:
1396+
accum[lab, j] = mval = val
1397+
out[i, j] = mval
1398+
else:
1399+
out[i, j] = val
13811400

1382-
# If using the mask, we can avoid grabbing the
1383-
# value unless necessary
1384-
val = values[i, j]
13851401

1386-
# Otherwise, `out` must be set accordingly if the
1387-
# value is missing
1388-
else:
1389-
val = values[i, j]
1390-
if _treat_as_na(val, is_datetimelike):
1391-
val_is_nan = True
1392-
out[i, j] = val
1402+
@cython.boundscheck(False)
1403+
@cython.wraparound(False)
1404+
cdef masked_cummin_max(groupby_t[:, ::1] out,
1405+
ndarray[groupby_t, ndim=2] values,
1406+
uint8_t[:, ::1] mask,
1407+
const intp_t[:] labels,
1408+
groupby_t[:, ::1] accum,
1409+
bint compute_max):
1410+
"""
1411+
Compute the cumulative minimum/maximum of columns of `values`, in row groups
1412+
`labels` with a masked algorithm.
1413+
"""
1414+
cdef:
1415+
Py_ssize_t i, j, N, K
1416+
groupby_t val, mval
1417+
intp_t lab
13931418

1394-
if not val_is_nan:
1419+
N, K = (<object>values).shape
1420+
with nogil:
1421+
for i in range(N):
1422+
lab = labels[i]
1423+
if lab < 0:
1424+
continue
1425+
for j in range(K):
1426+
if not mask[i, j]:
1427+
val = values[i, j]
13951428
mval = accum[lab, j]
13961429
if compute_max:
13971430
if val > mval:

0 commit comments

Comments
 (0)