Skip to content

Commit 9c509e2

Browse files
authored
ENH: Support mask for groupby var and mean (#48078)
1 parent 854987f commit 9c509e2

File tree

5 files changed

+84
-11
lines changed

5 files changed

+84
-11
lines changed

asv_bench/benchmarks/groupby.py

+39
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,45 @@ def time_frame_agg(self, dtype, method):
560560
self.df.groupby("key").agg(method)
561561

562562

563+
class GroupByCythonAggEaDtypes:
564+
"""
565+
Benchmarks specifically targeting our cython aggregation algorithms
566+
(using a big enough dataframe with simple key, so a large part of the
567+
time is actually spent in the grouped aggregation).
568+
"""
569+
570+
param_names = ["dtype", "method"]
571+
params = [
572+
["Float64", "Int64", "Int32"],
573+
[
574+
"sum",
575+
"prod",
576+
"min",
577+
"max",
578+
"mean",
579+
"median",
580+
"var",
581+
"first",
582+
"last",
583+
"any",
584+
"all",
585+
],
586+
]
587+
588+
def setup(self, dtype, method):
589+
N = 1_000_000
590+
df = DataFrame(
591+
np.random.randint(0, high=100, size=(N, 10)),
592+
columns=list("abcdefghij"),
593+
dtype=dtype,
594+
)
595+
df["key"] = np.random.randint(0, 100, size=N)
596+
self.df = df
597+
598+
def time_frame_agg(self, dtype, method):
599+
self.df.groupby("key").agg(method)
600+
601+
563602
class Cumulative:
564603
param_names = ["dtype", "method"]
565604
params = [

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ Deprecations
100100

101101
Performance improvements
102102
~~~~~~~~~~~~~~~~~~~~~~~~
103+
- Performance improvement in :meth:`.GroupBy.mean` and :meth:`.GroupBy.var` for extension array dtypes (:issue:`37493`)
103104
- Performance improvement for :meth:`MultiIndex.unique` (:issue:`48335`)
104105
-
105106

pandas/_libs/groupby.pyi

+2
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,8 @@ def group_var(
7878
labels: np.ndarray, # const intp_t[:]
7979
min_count: int = ..., # Py_ssize_t
8080
ddof: int = ..., # int64_t
81+
mask: np.ndarray | None = ...,
82+
result_mask: np.ndarray | None = ...,
8183
) -> None: ...
8284
def group_mean(
8385
out: np.ndarray, # floating[:, ::1]

pandas/_libs/groupby.pyx

+37-9
Original file line numberDiff line numberDiff line change
@@ -759,13 +759,16 @@ def group_var(
759759
const intp_t[::1] labels,
760760
Py_ssize_t min_count=-1,
761761
int64_t ddof=1,
762+
const uint8_t[:, ::1] mask=None,
763+
uint8_t[:, ::1] result_mask=None,
762764
) -> None:
763765
cdef:
764766
Py_ssize_t i, j, N, K, lab, ncounts = len(counts)
765767
floating val, ct, oldmean
766768
floating[:, ::1] mean
767769
int64_t[:, ::1] nobs
768770
Py_ssize_t len_values = len(values), len_labels = len(labels)
771+
bint isna_entry, uses_mask = not mask is None
769772

770773
assert min_count == -1, "'min_count' only used in sum and prod"
771774

@@ -790,8 +793,12 @@ def group_var(
790793
for j in range(K):
791794
val = values[i, j]
792795

793-
# not nan
794-
if val == val:
796+
if uses_mask:
797+
isna_entry = mask[i, j]
798+
else:
799+
isna_entry = not val == val
800+
801+
if not isna_entry:
795802
nobs[lab, j] += 1
796803
oldmean = mean[lab, j]
797804
mean[lab, j] += (val - oldmean) / nobs[lab, j]
@@ -801,7 +808,10 @@ def group_var(
801808
for j in range(K):
802809
ct = nobs[i, j]
803810
if ct <= ddof:
804-
out[i, j] = NAN
811+
if uses_mask:
812+
result_mask[i, j] = True
813+
else:
814+
out[i, j] = NAN
805815
else:
806816
out[i, j] /= (ct - ddof)
807817

@@ -839,9 +849,9 @@ def group_mean(
839849
is_datetimelike : bool
840850
True if `values` contains datetime-like entries.
841851
mask : ndarray[bool, ndim=2], optional
842-
Not used.
852+
Mask of the input values.
843853
result_mask : ndarray[bool, ndim=2], optional
844-
Not used.
854+
Mask of the out array
845855

846856
Notes
847857
-----
@@ -855,6 +865,7 @@ def group_mean(
855865
mean_t[:, ::1] sumx, compensation
856866
int64_t[:, ::1] nobs
857867
Py_ssize_t len_values = len(values), len_labels = len(labels)
868+
bint isna_entry, uses_mask = not mask is None
858869

859870
assert min_count == -1, "'min_count' only used in sum and prod"
860871

@@ -867,7 +878,12 @@ def group_mean(
867878
compensation = np.zeros((<object>out).shape, dtype=(<object>out).base.dtype)
868879

869880
N, K = (<object>values).shape
870-
nan_val = NPY_NAT if is_datetimelike else NAN
881+
if uses_mask:
882+
nan_val = 0
883+
elif is_datetimelike:
884+
nan_val = NPY_NAT
885+
else:
886+
nan_val = NAN
871887

872888
with nogil:
873889
for i in range(N):
@@ -878,8 +894,15 @@ def group_mean(
878894
counts[lab] += 1
879895
for j in range(K):
880896
val = values[i, j]
881-
# not nan
882-
if val == val and not (is_datetimelike and val == NPY_NAT):
897+
898+
if uses_mask:
899+
isna_entry = mask[i, j]
900+
elif is_datetimelike:
901+
isna_entry = val == NPY_NAT
902+
else:
903+
isna_entry = not val == val
904+
905+
if not isna_entry:
883906
nobs[lab, j] += 1
884907
y = val - compensation[lab, j]
885908
t = sumx[lab, j] + y
@@ -890,7 +913,12 @@ def group_mean(
890913
for j in range(K):
891914
count = nobs[i, j]
892915
if nobs[i, j] == 0:
893-
out[i, j] = nan_val
916+
917+
if uses_mask:
918+
result_mask[i, j] = True
919+
else:
920+
out[i, j] = nan_val
921+
894922
else:
895923
out[i, j] = sumx[i, j] / count
896924

pandas/core/groupby/ops.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
160160
"ohlc",
161161
"cumsum",
162162
"prod",
163+
"mean",
164+
"var",
163165
}
164166

165167
_cython_arity = {"ohlc": 4} # OHLC
@@ -598,7 +600,7 @@ def _call_cython_op(
598600
min_count=min_count,
599601
is_datetimelike=is_datetimelike,
600602
)
601-
elif self.how in ["ohlc", "prod"]:
603+
elif self.how in ["var", "ohlc", "prod"]:
602604
func(
603605
result,
604606
counts,
@@ -607,9 +609,10 @@ def _call_cython_op(
607609
min_count=min_count,
608610
mask=mask,
609611
result_mask=result_mask,
612+
**kwargs,
610613
)
611614
else:
612-
func(result, counts, values, comp_ids, min_count, **kwargs)
615+
func(result, counts, values, comp_ids, min_count)
613616
else:
614617
# TODO: min_count
615618
if self.uses_mask():

0 commit comments

Comments
 (0)