Skip to content

ENH: Support mask for groupby var and mean #48078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Sep 2, 2022
Merged
39 changes: 39 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,45 @@ def time_frame_agg(self, dtype, method):
self.df.groupby("key").agg(method)


class GroupByCythonAggEaDtypes:
    """
    Benchmarks specifically targeting our cython aggregation algorithms
    (using a big enough dataframe with simple key, so a large part of the
    time is actually spent in the grouped aggregation).

    Each benchmark is parametrized over a dtype name ("Float64", "Int64",
    "Int32") and the name of the aggregation passed to ``.agg``.
    """

    param_names = ["dtype", "method"]
    params = [
        ["Float64", "Int64", "Int32"],
        [
            "sum",
            "prod",
            "min",
            "max",
            "mean",
            "median",
            "var",
            "first",
            "last",
            "any",
            "all",
        ],
    ]

    def setup(self, dtype, method):
        """
        Build the frame that is aggregated by the benchmark.

        Parameters
        ----------
        dtype : str
            Dtype name the ten value columns are created with.
        method : str
            Aggregation name; not used while building the frame.
        """
        N = 1_000_000
        # Ten value columns of random integers in [0, 100), stored as `dtype`.
        df = DataFrame(
            np.random.randint(0, high=100, size=(N, 10)),
            columns=list("abcdefghij"),
            dtype=dtype,
        )
        # Grouping key is added after construction, so it is not created
        # with `dtype`; its values are also drawn from [0, 100).
        df["key"] = np.random.randint(0, 100, size=N)
        self.df = df

    def time_frame_agg(self, dtype, method):
        """
        Time ``groupby("key").agg(method)`` on the frame built in ``setup``.

        Parameters
        ----------
        dtype : str
            Dtype name of the value columns; not used in the timed call.
        method : str
            Aggregation name forwarded to ``.agg``.
        """
        self.df.groupby("key").agg(method)


class Cumulative:
param_names = ["dtype", "method"]
params = [
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ Deprecations

Performance improvements
~~~~~~~~~~~~~~~~~~~~~~~~
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
-

Expand Down
2 changes: 2 additions & 0 deletions pandas/_libs/groupby.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def group_var(
labels: np.ndarray, # const intp_t[:]
min_count: int = ..., # Py_ssize_t
ddof: int = ..., # int64_t
mask: np.ndarray | None = ...,
result_mask: np.ndarray | None = ...,
) -> None: ...
def group_mean(
out: np.ndarray, # floating[:, ::1]
Expand Down
46 changes: 37 additions & 9 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -759,13 +759,16 @@ def group_var(
const intp_t[::1] labels,
Py_ssize_t min_count=-1,
int64_t ddof=1,
const uint8_t[:, ::1] mask=None,
uint8_t[:, ::1] result_mask=None,
) -> None:
cdef:
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
floating val, ct, oldmean
floating[:, ::1] mean
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint isna_entry, uses_mask = not mask is None

assert min_count == -1, "'min_count' only used in sum and prod"

Expand All @@ -790,8 +793,12 @@ def group_var(
for j in range(K):
val = values[i, j]

# not nan
if val == val:
if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = not val == val

if not isna_entry:
nobs[lab, j] += 1
oldmean = mean[lab, j]
mean[lab, j] += (val - oldmean) / nobs[lab, j]
Expand All @@ -801,7 +808,10 @@ def group_var(
for j in range(K):
ct = nobs[i, j]
if ct <= ddof:
out[i, j] = NAN
if uses_mask:
result_mask[i, j] = True
else:
out[i, j] = NAN
else:
out[i, j] /= (ct - ddof)

Expand Down Expand Up @@ -839,9 +849,9 @@ def group_mean(
is_datetimelike : bool
True if `values` contains datetime-like entries.
mask : ndarray[bool, ndim=2], optional
Not used.
Mask of the input values.
result_mask : ndarray[bool, ndim=2], optional
Not used.
Mask of the out array; set to True where a group has no valid observations.

Notes
-----
Expand All @@ -855,6 +865,7 @@ def group_mean(
mean_t[:, ::1] sumx, compensation
int64_t[:, ::1] nobs
Py_ssize_t len_values = len(values), len_labels = len(labels)
bint isna_entry, uses_mask = not mask is None

assert min_count == -1, "'min_count' only used in sum and prod"

Expand All @@ -867,7 +878,12 @@ def group_mean(
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)

N, K = (<object>values).shape
nan_val = NPY_NAT if is_datetimelike else NAN
if uses_mask:
nan_val = 0
elif is_datetimelike:
nan_val = NPY_NAT
else:
nan_val = NAN

with nogil:
for i in range(N):
Expand All @@ -878,8 +894,15 @@ def group_mean(
counts[lab] += 1
for j in range(K):
val = values[i, j]
# not nan
if val == val and not (is_datetimelike and val == NPY_NAT):

if uses_mask:
isna_entry = mask[i, j]
elif is_datetimelike:
isna_entry = val == NPY_NAT
else:
isna_entry = not val == val

if not isna_entry:
nobs[lab, j] += 1
y = val - compensation[lab, j]
t = sumx[lab, j] + y
Expand All @@ -890,7 +913,12 @@ def group_mean(
for j in range(K):
count = nobs[i, j]
if nobs[i, j] == 0:
out[i, j] = nan_val

if uses_mask:
result_mask[i, j] = True
else:
out[i, j] = nan_val

else:
out[i, j] = sumx[i, j] / count

Expand Down
7 changes: 5 additions & 2 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"ohlc",
"cumsum",
"prod",
"mean",
"var",
}

_cython_arity = {"ohlc": 4} # OHLC
Expand Down Expand Up @@ -598,7 +600,7 @@ def _call_cython_op(
min_count=min_count,
is_datetimelike=is_datetimelike,
)
elif self.how in ["ohlc", "prod"]:
elif self.how in ["var", "ohlc", "prod"]:
func(
Copy link
Member Author

@phofl phofl Aug 14, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This part needs some refactoring when every algo supports masks

result,
counts,
Expand All @@ -607,9 +609,10 @@ def _call_cython_op(
min_count=min_count,
mask=mask,
result_mask=result_mask,
**kwargs,
)
else:
func(result, counts, values, comp_ids, min_count, **kwargs)
func(result, counts, values, comp_ids, min_count)
else:
# TODO: min_count
if self.uses_mask():
Expand Down