Skip to content

REF: deduplicate group min/max #40584

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 23, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 57 additions & 84 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1156,18 +1156,40 @@ ctypedef fused groupby_t:

@cython.wraparound(False)
@cython.boundscheck(False)
def group_max(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const int64_t[:] labels,
Py_ssize_t min_count=-1):
cdef group_min_max(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const int64_t[:] labels,
Py_ssize_t min_count=-1,
bint compute_max=True):
"""
Only aggregates on axis=0
Compute minimum/maximum of columns of `values`, in row groups `labels`.

Parameters
----------
out : array
Array to store result in.
counts : int64 array
Input as a zeroed array, populated by group sizes during algorithm
values : array
Values to find column-wise min/max of.
labels : int64 array
Labels to group by.
min_count : Py_ssize_t, default -1
The minimum number of non-NA group elements, NA result if threshold
is not met
compute_max : bint, default True
True to compute group-wise max, False to compute min

Notes
-----
This method modifies the `out` parameter, rather than returning an object.
`counts` is modified to hold group sizes
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
groupby_t val, count, nan_val
ndarray[groupby_t, ndim=2] maxx
Py_ssize_t i, j, N, K, lab, ngroups = len(counts)
groupby_t val, nan_val
ndarray[groupby_t, ndim=2] group_min_or_max
bint runtime_error = False
int64_t[:, ::1] nobs

Expand All @@ -1179,18 +1201,17 @@ def group_max(groupby_t[:, ::1] out,
min_count = max(min_count, 1)
nobs = np.zeros((<object>out).shape, dtype=np.int64)

maxx = np.empty_like(out)
group_min_or_max = np.empty_like(out)
if groupby_t is int64_t:
# Note: evaluated at compile-time
maxx[:] = -_int64_max
group_min_or_max[:] = -_int64_max if compute_max else _int64_max
nan_val = NPY_NAT
elif groupby_t is uint64_t:
# NB: We do not define nan_val because there is no such thing
# for uint64_t. We carefully avoid having to reference it in this
# case.
maxx[:] = 0
# for uint64_t. We carefully avoid having to reference it in this
# case.
group_min_or_max[:] = 0 if compute_max else np.iinfo(np.uint64).max
else:
maxx[:] = -np.inf
group_min_or_max[:] = -np.inf if compute_max else np.inf
nan_val = NAN

N, K = (<object>values).shape
Expand All @@ -1208,20 +1229,23 @@ def group_max(groupby_t[:, ::1] out,
if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val > maxx[lab, j]:
maxx[lab, j] = val
if compute_max:
if val > group_min_or_max[lab, j]:
group_min_or_max[lab, j] = val
else:
if val < group_min_or_max[lab, j]:
group_min_or_max[lab, j] = val
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wanted to be able to do something like if PyObject_RichCompareBool(val, ..., cmp_method): ... to simplify the conditional logic, but couldn't find a way to do something like this in Cython without the GIL.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a perf impact here? it's checking compute_max at each step in a nested loop

Copy link
Member Author

@mzeitlin11 mzeitlin11 Mar 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't see any perf impact based on ASVs above. For a loop invariant like compute_max, I'd guess this check is optimized away by the compiler. (It could also just be negligible compared to the cost of the indexing operations)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

makes sense. worth it to de-duplicate a bunch of code.


for i in range(ncounts):
for i in range(ngroups):
for j in range(K):
if nobs[i, j] < min_count:
if groupby_t is uint64_t:
runtime_error = True
break
else:

out[i, j] = nan_val
else:
out[i, j] = maxx[i, j]
out[i, j] = group_min_or_max[i, j]

if runtime_error:
# We cannot raise directly above because that is within a nogil
Expand All @@ -1231,75 +1255,24 @@ def group_max(groupby_t[:, ::1] out,

@cython.wraparound(False)
@cython.boundscheck(False)
def group_min(groupby_t[:, ::1] out,
def group_max(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const int64_t[:] labels,
Py_ssize_t min_count=-1):
"""
Only aggregates on axis=0
"""
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
groupby_t val, count, nan_val
ndarray[groupby_t, ndim=2] minx
bint runtime_error = False
int64_t[:, ::1] nobs

# TODO(cython 3.0):
# Instead of `labels.shape[0]` use `len(labels)`
if not len(values) == labels.shape[0]:
raise AssertionError("len(index) != len(labels)")

min_count = max(min_count, 1)
nobs = np.zeros((<object>out).shape, dtype=np.int64)

minx = np.empty_like(out)
if groupby_t is int64_t:
minx[:] = _int64_max
nan_val = NPY_NAT
elif groupby_t is uint64_t:
# NB: We do not define nan_val because there is no such thing
# for uint64_t. We carefully avoid having to reference it in this
# case.
minx[:] = np.iinfo(np.uint64).max
else:
minx[:] = np.inf
nan_val = NAN
"""See group_min_max.__doc__"""
group_min_max(out, counts, values, labels, min_count=min_count, compute_max=True)

N, K = (<object>values).shape

with nogil:
for i in range(N):
lab = labels[i]
if lab < 0:
continue

counts[lab] += 1
for j in range(K):
val = values[i, j]

if not _treat_as_na(val, True):
# TODO: Sure we always want is_datetimelike=True?
nobs[lab, j] += 1
if val < minx[lab, j]:
minx[lab, j] = val

for i in range(ncounts):
for j in range(K):
if nobs[i, j] < min_count:
if groupby_t is uint64_t:
runtime_error = True
break
else:
out[i, j] = nan_val
else:
out[i, j] = minx[i, j]

if runtime_error:
# We cannot raise directly above because that is within a nogil
# block.
raise RuntimeError("empty group with uint64_t")
@cython.wraparound(False)
@cython.boundscheck(False)
def group_min(groupby_t[:, ::1] out,
int64_t[::1] counts,
ndarray[groupby_t, ndim=2] values,
const int64_t[:] labels,
Py_ssize_t min_count=-1):
"""
Compute group-wise minimum; thin wrapper delegating to group_min_max
with compute_max=False. See group_min_max.__doc__ for parameter details
(`out` and `counts` are modified in place; nothing is returned).
"""
group_min_max(out, counts, values, labels, min_count=min_count, compute_max=False)


@cython.boundscheck(False)
Expand Down