Skip to content

Commit eda43ec

Browse files
mzeitlin11JulianWgs
authored andcommitted
REF: deduplicate group min/max (pandas-dev#40584)
1 parent 9f9f22b commit eda43ec

File tree

1 file changed

+57
-84
lines changed

1 file changed

+57
-84
lines changed

pandas/_libs/groupby.pyx

+57-84
Original file line numberDiff line numberDiff line change
@@ -1127,18 +1127,40 @@ ctypedef fused groupby_t:
11271127

11281128
@cython.wraparound(False)
11291129
@cython.boundscheck(False)
1130-
def group_max(groupby_t[:, ::1] out,
1131-
int64_t[::1] counts,
1132-
ndarray[groupby_t, ndim=2] values,
1133-
const int64_t[:] labels,
1134-
Py_ssize_t min_count=-1):
1130+
cdef group_min_max(groupby_t[:, ::1] out,
1131+
int64_t[::1] counts,
1132+
ndarray[groupby_t, ndim=2] values,
1133+
const int64_t[:] labels,
1134+
Py_ssize_t min_count=-1,
1135+
bint compute_max=True):
11351136
"""
1136-
Only aggregates on axis=0
1137+
Compute minimum/maximum of columns of `values`, in row groups `labels`.
1138+
1139+
Parameters
1140+
----------
1141+
out : array
1142+
Array to store result in.
1143+
counts : int64 array
1144+
Input as a zeroed array, populated by group sizes during algorithm
1145+
values : array
1146+
Values to find column-wise min/max of.
1147+
labels : int64 array
1148+
Labels to group by.
1149+
min_count : Py_ssize_t, default -1
1150+
The minimum number of non-NA group elements, NA result if threshold
1151+
is not met
1152+
compute_max : bint, default True
1153+
True to compute group-wise max, False to compute min
1154+
1155+
Notes
1156+
-----
1157+
This method modifies the `out` parameter, rather than returning an object.
1158+
`counts` is modified to hold group sizes
11371159
"""
11381160
cdef:
1139-
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
1140-
groupby_t val, count, nan_val
1141-
ndarray[groupby_t, ndim=2] maxx
1161+
Py_ssize_t i, j, N, K, lab, ngroups = len(counts)
1162+
groupby_t val, nan_val
1163+
ndarray[groupby_t, ndim=2] group_min_or_max
11421164
bint runtime_error = False
11431165
int64_t[:, ::1] nobs
11441166

@@ -1150,18 +1172,17 @@ def group_max(groupby_t[:, ::1] out,
11501172
min_count = max(min_count, 1)
11511173
nobs = np.zeros((<object>out).shape, dtype=np.int64)
11521174

1153-
maxx = np.empty_like(out)
1175+
group_min_or_max = np.empty_like(out)
11541176
if groupby_t is int64_t:
1155-
# Note: evaluated at compile-time
1156-
maxx[:] = -_int64_max
1177+
group_min_or_max[:] = -_int64_max if compute_max else _int64_max
11571178
nan_val = NPY_NAT
11581179
elif groupby_t is uint64_t:
11591180
# NB: We do not define nan_val because there is no such thing
1160-
# for uint64_t. We carefully avoid having to reference it in this
1161-
# case.
1162-
maxx[:] = 0
1181+
# for uint64_t. We carefully avoid having to reference it in this
1182+
# case.
1183+
group_min_or_max[:] = 0 if compute_max else np.iinfo(np.uint64).max
11631184
else:
1164-
maxx[:] = -np.inf
1185+
group_min_or_max[:] = -np.inf if compute_max else np.inf
11651186
nan_val = NAN
11661187

11671188
N, K = (<object>values).shape
@@ -1179,20 +1200,23 @@ def group_max(groupby_t[:, ::1] out,
11791200
if not _treat_as_na(val, True):
11801201
# TODO: Sure we always want is_datetimelike=True?
11811202
nobs[lab, j] += 1
1182-
if val > maxx[lab, j]:
1183-
maxx[lab, j] = val
1203+
if compute_max:
1204+
if val > group_min_or_max[lab, j]:
1205+
group_min_or_max[lab, j] = val
1206+
else:
1207+
if val < group_min_or_max[lab, j]:
1208+
group_min_or_max[lab, j] = val
11841209

1185-
for i in range(ncounts):
1210+
for i in range(ngroups):
11861211
for j in range(K):
11871212
if nobs[i, j] < min_count:
11881213
if groupby_t is uint64_t:
11891214
runtime_error = True
11901215
break
11911216
else:
1192-
11931217
out[i, j] = nan_val
11941218
else:
1195-
out[i, j] = maxx[i, j]
1219+
out[i, j] = group_min_or_max[i, j]
11961220

11971221
if runtime_error:
11981222
# We cannot raise directly above because that is within a nogil
@@ -1202,75 +1226,24 @@ def group_max(groupby_t[:, ::1] out,
12021226

12031227
@cython.wraparound(False)
12041228
@cython.boundscheck(False)
1205-
def group_min(groupby_t[:, ::1] out,
1229+
def group_max(groupby_t[:, ::1] out,
12061230
int64_t[::1] counts,
12071231
ndarray[groupby_t, ndim=2] values,
12081232
const int64_t[:] labels,
12091233
Py_ssize_t min_count=-1):
1210-
"""
1211-
Only aggregates on axis=0
1212-
"""
1213-
cdef:
1214-
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
1215-
groupby_t val, count, nan_val
1216-
ndarray[groupby_t, ndim=2] minx
1217-
bint runtime_error = False
1218-
int64_t[:, ::1] nobs
1219-
1220-
# TODO(cython 3.0):
1221-
# Instead of `labels.shape[0]` use `len(labels)`
1222-
if not len(values) == labels.shape[0]:
1223-
raise AssertionError("len(index) != len(labels)")
1224-
1225-
min_count = max(min_count, 1)
1226-
nobs = np.zeros((<object>out).shape, dtype=np.int64)
1227-
1228-
minx = np.empty_like(out)
1229-
if groupby_t is int64_t:
1230-
minx[:] = _int64_max
1231-
nan_val = NPY_NAT
1232-
elif groupby_t is uint64_t:
1233-
# NB: We do not define nan_val because there is no such thing
1234-
# for uint64_t. We carefully avoid having to reference it in this
1235-
# case.
1236-
minx[:] = np.iinfo(np.uint64).max
1237-
else:
1238-
minx[:] = np.inf
1239-
nan_val = NAN
1234+
"""See group_min_max.__doc__"""
1235+
group_min_max(out, counts, values, labels, min_count=min_count, compute_max=True)
12401236

1241-
N, K = (<object>values).shape
12421237

1243-
with nogil:
1244-
for i in range(N):
1245-
lab = labels[i]
1246-
if lab < 0:
1247-
continue
1248-
1249-
counts[lab] += 1
1250-
for j in range(K):
1251-
val = values[i, j]
1252-
1253-
if not _treat_as_na(val, True):
1254-
# TODO: Sure we always want is_datetimelike=True?
1255-
nobs[lab, j] += 1
1256-
if val < minx[lab, j]:
1257-
minx[lab, j] = val
1258-
1259-
for i in range(ncounts):
1260-
for j in range(K):
1261-
if nobs[i, j] < min_count:
1262-
if groupby_t is uint64_t:
1263-
runtime_error = True
1264-
break
1265-
else:
1266-
out[i, j] = nan_val
1267-
else:
1268-
out[i, j] = minx[i, j]
1269-
1270-
if runtime_error:
1271-
# We cannot raise directly above because that is within a nogil
1272-
# block.
1273-
raise RuntimeError("empty group with uint64_t")
1238+
@cython.wraparound(False)
1239+
@cython.boundscheck(False)
1240+
def group_min(groupby_t[:, ::1] out,
1241+
int64_t[::1] counts,
1242+
ndarray[groupby_t, ndim=2] values,
1243+
const int64_t[:] labels,
1244+
Py_ssize_t min_count=-1):
1245+
"""See group_min_max.__doc__"""
1246+
group_min_max(out, counts, values, labels, min_count=min_count, compute_max=False)
12741247

12751248

12761249
@cython.boundscheck(False)

0 commit comments

Comments
 (0)