Skip to content

Commit 088cb58

Browse files
mzeitlin11vladu
authored andcommitted
REF: deduplicate group cummin/max (pandas-dev#40599)
1 parent 55be1e5 commit 088cb58

File tree

1 file changed

+33
-63
lines changed

1 file changed

+33
-63
lines changed

pandas/_libs/groupby.pyx

+33-63
Original file line numberDiff line numberDiff line change
@@ -1249,26 +1249,30 @@ def group_min(groupby_t[:, ::1] out,
12491249

12501250
@cython.boundscheck(False)
12511251
@cython.wraparound(False)
1252-
def group_cummin(groupby_t[:, ::1] out,
1253-
ndarray[groupby_t, ndim=2] values,
1254-
const int64_t[:] labels,
1255-
int ngroups,
1256-
bint is_datetimelike):
1252+
def group_cummin_max(groupby_t[:, ::1] out,
1253+
ndarray[groupby_t, ndim=2] values,
1254+
const int64_t[:] labels,
1255+
int ngroups,
1256+
bint is_datetimelike,
1257+
bint compute_max):
12571258
"""
1258-
Cumulative minimum of columns of `values`, in row groups `labels`.
1259+
Cumulative minimum/maximum of columns of `values`, in row groups `labels`.
12591260
12601261
Parameters
12611262
----------
12621263
out : array
1263-
Array to store cummin in.
1264+
Array to store cummin/max in.
12641265
values : array
1265-
Values to take cummin of.
1266+
Values to take cummin/max of.
12661267
labels : int64 array
12671268
Labels to group by.
12681269
ngroups : int
12691270
Number of groups, larger than all entries of `labels`.
12701271
is_datetimelike : bool
12711272
True if `values` contains datetime-like entries.
1273+
compute_max : bool
1274+
True if cumulative maximum should be computed, False
1275+
if cumulative minimum should be computed
12721276
12731277
Notes
12741278
-----
@@ -1283,11 +1287,11 @@ def group_cummin(groupby_t[:, ::1] out,
12831287
N, K = (<object>values).shape
12841288
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
12851289
if groupby_t is int64_t:
1286-
accum[:] = _int64_max
1290+
accum[:] = -_int64_max if compute_max else _int64_max
12871291
elif groupby_t is uint64_t:
1288-
accum[:] = np.iinfo(np.uint64).max
1292+
accum[:] = 0 if compute_max else np.iinfo(np.uint64).max
12891293
else:
1290-
accum[:] = np.inf
1294+
accum[:] = -np.inf if compute_max else np.inf
12911295

12921296
with nogil:
12931297
for i in range(N):
@@ -1302,66 +1306,32 @@ def group_cummin(groupby_t[:, ::1] out,
13021306
out[i, j] = val
13031307
else:
13041308
mval = accum[lab, j]
1305-
if val < mval:
1306-
accum[lab, j] = mval = val
1309+
if compute_max:
1310+
if val > mval:
1311+
accum[lab, j] = mval = val
1312+
else:
1313+
if val < mval:
1314+
accum[lab, j] = mval = val
13071315
out[i, j] = mval
13081316

13091317

13101318
@cython.boundscheck(False)
13111319
@cython.wraparound(False)
1312-
def group_cummax(groupby_t[:, ::1] out,
1320+
def group_cummin(groupby_t[:, ::1] out,
13131321
ndarray[groupby_t, ndim=2] values,
13141322
const int64_t[:] labels,
13151323
int ngroups,
13161324
bint is_datetimelike):
1317-
"""
1318-
Cumulative maximum of columns of `values`, in row groups `labels`.
1325+
"""See group_cummin_max.__doc__"""
1326+
group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=False)
13191327

1320-
Parameters
1321-
----------
1322-
out : array
1323-
Array to store cummax in.
1324-
values : array
1325-
Values to take cummax of.
1326-
labels : int64 array
1327-
Labels to group by.
1328-
ngroups : int
1329-
Number of groups, larger than all entries of `labels`.
1330-
is_datetimelike : bool
1331-
True if `values` contains datetime-like entries.
13321328

1333-
Notes
1334-
-----
1335-
This method modifies the `out` parameter, rather than returning an object.
1336-
"""
1337-
cdef:
1338-
Py_ssize_t i, j, N, K, size
1339-
groupby_t val, mval
1340-
ndarray[groupby_t, ndim=2] accum
1341-
int64_t lab
1342-
1343-
N, K = (<object>values).shape
1344-
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
1345-
if groupby_t is int64_t:
1346-
accum[:] = -_int64_max
1347-
elif groupby_t is uint64_t:
1348-
accum[:] = 0
1349-
else:
1350-
accum[:] = -np.inf
1351-
1352-
with nogil:
1353-
for i in range(N):
1354-
lab = labels[i]
1355-
1356-
if lab < 0:
1357-
continue
1358-
for j in range(K):
1359-
val = values[i, j]
1360-
1361-
if _treat_as_na(val, is_datetimelike):
1362-
out[i, j] = val
1363-
else:
1364-
mval = accum[lab, j]
1365-
if val > mval:
1366-
accum[lab, j] = mval = val
1367-
out[i, j] = mval
1329+
@cython.boundscheck(False)
1330+
@cython.wraparound(False)
1331+
def group_cummax(groupby_t[:, ::1] out,
1332+
ndarray[groupby_t, ndim=2] values,
1333+
const int64_t[:] labels,
1334+
int ngroups,
1335+
bint is_datetimelike):
1336+
"""See group_cummin_max.__doc__"""
1337+
group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=True)

0 commit comments

Comments
 (0)