Skip to content

PERF: groupby.quantile #43469

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/source/whatsnew/v1.4.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,8 @@ Performance improvements
- Performance improvement in :meth:`to_datetime` with ``uint`` dtypes (:issue:`42606`)
- Performance improvement in :meth:`Series.sparse.to_coo` (:issue:`42880`)
- Performance improvement in indexing with a :class:`MultiIndex` indexer on another :class:`MultiIndex` (:issue:`43370`)
- Performance improvement in :meth:`GroupBy.quantile` (:issue:`43469`)
-

.. ---------------------------------------------------------------------------

Expand Down
4 changes: 2 additions & 2 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,11 @@ def group_ohlc(
min_count: int = ...,
) -> None: ...
def group_quantile(
out: np.ndarray, # ndarray[float64_t]
out: np.ndarray, # ndarray[float64_t, ndim=2]
values: np.ndarray, # ndarray[numeric, ndim=1]
labels: np.ndarray, # ndarray[int64_t]
mask: np.ndarray, # ndarray[uint8_t]
q: float, # float64_t
qs: np.ndarray, # const float64_t[:]
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
) -> None: ...
def group_last(
Expand Down
84 changes: 46 additions & 38 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -770,25 +770,25 @@ def group_ohlc(floating[:, ::1] out,

@cython.boundscheck(False)
@cython.wraparound(False)
def group_quantile(ndarray[float64_t] out,
def group_quantile(ndarray[float64_t, ndim=2] out,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i think we name these _2d (or similar), thinking of rank of these are 2d based?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not in this file

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kk, prob want to make a pass over the cython code to unify the names (e.g. append 1d or 2d) where appropriate

ndarray[numeric, ndim=1] values,
ndarray[intp_t] labels,
ndarray[uint8_t] mask,
float64_t q,
const float64_t[:] qs,
str interpolation) -> None:
"""
Calculate the quantile per group.

Parameters
----------
out : np.ndarray[np.float64]
out : np.ndarray[np.float64, ndim=2]
Array of aggregated values that will be written to.
values : np.ndarray
Array containing the values to apply the function against.
labels : ndarray[np.intp]
Array containing the unique group labels.
q : float
The quantile value to search for.
qs : ndarray[float64_t]
The quantile values to search for.
interpolation : {'linear', 'lower', 'higher', 'nearest', 'midpoint'}

Notes
Expand All @@ -797,17 +797,20 @@ def group_quantile(ndarray[float64_t] out,
provided `out` parameter.
"""
cdef:
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz
Py_ssize_t i, N=len(labels), ngroups, grp_sz, non_na_sz, k, nqs
Py_ssize_t grp_start=0, idx=0
intp_t lab
uint8_t interp
float64_t q_idx, frac, val, next_val
float64_t q_val, q_idx, frac, val, next_val
ndarray[int64_t] counts, non_na_counts, sort_arr

assert values.shape[0] == N

if not (0 <= q <= 1):
raise ValueError(f"'q' must be between 0 and 1. Got '{q}' instead")
if any(not (0 <= q <= 1) for q in qs):
wrong = [x for x in qs if not (0 <= x <= 1)][0]
raise ValueError(
f"Each 'q' must be between 0 and 1. Got '{wrong}' instead"
)

inter_methods = {
'linear': INTERPOLATION_LINEAR,
Expand All @@ -818,9 +821,10 @@ def group_quantile(ndarray[float64_t] out,
}
interp = inter_methods[interpolation]

counts = np.zeros_like(out, dtype=np.int64)
non_na_counts = np.zeros_like(out, dtype=np.int64)
ngroups = len(counts)
nqs = len(qs)
ngroups = len(out)
counts = np.zeros(ngroups, dtype=np.int64)
non_na_counts = np.zeros(ngroups, dtype=np.int64)

# First figure out the size of every group
with nogil:
Expand Down Expand Up @@ -850,33 +854,37 @@ def group_quantile(ndarray[float64_t] out,
non_na_sz = non_na_counts[i]

if non_na_sz == 0:
out[i] = NaN
for k in range(nqs):
out[i, k] = NaN
else:
# Calculate where to retrieve the desired value
# Casting to int will intentionally truncate result
idx = grp_start + <int64_t>(q * <float64_t>(non_na_sz - 1))

val = values[sort_arr[idx]]
# If requested quantile falls evenly on a particular index
# then write that index's value out. Otherwise interpolate
q_idx = q * (non_na_sz - 1)
frac = q_idx % 1

if frac == 0.0 or interp == INTERPOLATION_LOWER:
out[i] = val
else:
next_val = values[sort_arr[idx + 1]]
if interp == INTERPOLATION_LINEAR:
out[i] = val + (next_val - val) * frac
elif interp == INTERPOLATION_HIGHER:
out[i] = next_val
elif interp == INTERPOLATION_MIDPOINT:
out[i] = (val + next_val) / 2.0
elif interp == INTERPOLATION_NEAREST:
if frac > .5 or (frac == .5 and q > .5): # Always OK?
out[i] = next_val
else:
out[i] = val
for k in range(nqs):
q_val = qs[k]

# Calculate where to retrieve the desired value
# Casting to int will intentionally truncate result
idx = grp_start + <int64_t>(q_val * <float64_t>(non_na_sz - 1))

val = values[sort_arr[idx]]
# If requested quantile falls evenly on a particular index
# then write that index's value out. Otherwise interpolate
q_idx = q_val * (non_na_sz - 1)
frac = q_idx % 1

if frac == 0.0 or interp == INTERPOLATION_LOWER:
out[i, k] = val
else:
next_val = values[sort_arr[idx + 1]]
if interp == INTERPOLATION_LINEAR:
out[i, k] = val + (next_val - val) * frac
elif interp == INTERPOLATION_HIGHER:
out[i, k] = next_val
elif interp == INTERPOLATION_MIDPOINT:
out[i, k] = (val + next_val) / 2.0
elif interp == INTERPOLATION_NEAREST:
if frac > .5 or (frac == .5 and q_val > .5): # Always OK?
out[i, k] = next_val
else:
out[i, k] = val

# Increment the index reference in sorted_arr for the next group
grp_start += grp_sz
Expand Down
Loading