
Commit 88fdc75

mzeitlin11 authored and JulianWgs committed
PERF/BUG: use masked algo in groupby cummin and cummax (pandas-dev#40651)
1 parent c5bb83f commit 88fdc75

File tree

5 files changed: +199 -37 lines changed

asv_bench/benchmarks/groupby.py (+28)

@@ -505,6 +505,34 @@ def time_frame_agg(self, dtype, method):
         self.df.groupby("key").agg(method)
 
 
+class CumminMax:
+    param_names = ["dtype", "method"]
+    params = [
+        ["float64", "int64", "Float64", "Int64"],
+        ["cummin", "cummax"],
+    ]
+
+    def setup(self, dtype, method):
+        N = 500_000
+        vals = np.random.randint(-10, 10, (N, 5))
+        null_vals = vals.astype(float, copy=True)
+        null_vals[::2, :] = np.nan
+        null_vals[::3, :] = np.nan
+        df = DataFrame(vals, columns=list("abcde"), dtype=dtype)
+        null_df = DataFrame(null_vals, columns=list("abcde"), dtype=dtype)
+        keys = np.random.randint(0, 100, size=N)
+        df["key"] = keys
+        null_df["key"] = keys
+        self.df = df
+        self.null_df = null_df
+
+    def time_frame_transform(self, dtype, method):
+        self.df.groupby("key").transform(method)
+
+    def time_frame_transform_many_nulls(self, dtype, method):
+        self.null_df.groupby("key").transform(method)
+
+
 class RankWithTies:
     # GH 21237
     param_names = ["dtype", "tie_method"]
doc/source/whatsnew/v1.3.0.rst (+3)

@@ -630,6 +630,8 @@ Performance improvements
 - Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`)
 - Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`)
 - Performance improvement for concatenation of data with type :class:`CategoricalDtype` (:issue:`40193`)
+- Performance improvement in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` with nullable data types (:issue:`37493`)
+-
 
 .. ---------------------------------------------------------------------------
 
@@ -838,6 +840,7 @@ Groupby/resample/rolling
 - Bug in :meth:`.GroupBy.any` and :meth:`.GroupBy.all` raising ``ValueError`` when using with nullable type columns holding ``NA`` even with ``skipna=True`` (:issue:`40585`)
 - Bug in :meth:`GroupBy.cummin` and :meth:`GroupBy.cummax` incorrectly rounding integer values near the ``int64`` implementations bounds (:issue:`40767`)
 - Bug in :meth:`.GroupBy.rank` with nullable dtypes incorrectly raising ``TypeError`` (:issue:`41010`)
+- Bug in :meth:`.GroupBy.cummin` and :meth:`.GroupBy.cummax` computing wrong result with nullable data types too large to roundtrip when casting to float (:issue:`37493`)
 
 Reshaping
 ^^^^^^^^^
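As an illustration of the second whatsnew entry above (not part of the diff): integers above 2**53 cannot be represented exactly as float64, so the old float-casting path could return a rounded cumulative min/max for nullable integer data. A minimal sketch of the scenario, assuming a pandas build that includes this fix:

import pandas as pd

df = pd.DataFrame(
    {
        "key": [0, 0],
        "val": pd.array([2**53 + 2, 2**53 + 1], dtype="Int64"),
    }
)
result = df.groupby("key")["val"].cummin()
# Expected with the masked algorithm (exact integers, Int64 dtype preserved):
# 0    9007199254740994
# 1    9007199254740993
# Name: val, dtype: Int64
print(result)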

pandas/_libs/groupby.pyx (+52 -9)

@@ -1277,6 +1277,7 @@ def group_min(groupby_t[:, ::1] out,
 @cython.wraparound(False)
 cdef group_cummin_max(groupby_t[:, ::1] out,
                       ndarray[groupby_t, ndim=2] values,
+                      uint8_t[:, ::1] mask,
                       const intp_t[:] labels,
                       int ngroups,
                       bint is_datetimelike,
@@ -1290,6 +1291,9 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
         Array to store cummin/max in.
     values : np.ndarray[groupby_t, ndim=2]
         Values to take cummin/max of.
+    mask : np.ndarray[bool] or None
+        If not None, indices represent missing values,
+        otherwise the mask will not be used
     labels : np.ndarray[np.intp]
         Labels to group by.
     ngroups : int
@@ -1307,11 +1311,14 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
     cdef:
         Py_ssize_t i, j, N, K, size
         groupby_t val, mval
-        ndarray[groupby_t, ndim=2] accum
+        groupby_t[:, ::1] accum
         intp_t lab
+        bint val_is_nan, use_mask
+
+    use_mask = mask is not None
 
     N, K = (<object>values).shape
-    accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
+    accum = np.empty((ngroups, K), dtype=values.dtype)
     if groupby_t is int64_t:
         accum[:] = -_int64_max if compute_max else _int64_max
     elif groupby_t is uint64_t:
@@ -1326,11 +1333,29 @@ cdef group_cummin_max(groupby_t[:, ::1] out,
             if lab < 0:
                 continue
             for j in range(K):
-                val = values[i, j]
+                val_is_nan = False
+
+                if use_mask:
+                    if mask[i, j]:
+
+                        # `out` does not need to be set since it
+                        # will be masked anyway
+                        val_is_nan = True
+                    else:
+
+                        # If using the mask, we can avoid grabbing the
+                        # value unless necessary
+                        val = values[i, j]
 
-                if _treat_as_na(val, is_datetimelike):
-                    out[i, j] = val
+                # Otherwise, `out` must be set accordingly if the
+                # value is missing
                 else:
+                    val = values[i, j]
+                    if _treat_as_na(val, is_datetimelike):
+                        val_is_nan = True
+                        out[i, j] = val
+
+                if not val_is_nan:
                     mval = accum[lab, j]
                     if compute_max:
                         if val > mval:
@@ -1347,9 +1372,18 @@ def group_cummin(groupby_t[:, ::1] out,
                  ndarray[groupby_t, ndim=2] values,
                  const intp_t[:] labels,
                  int ngroups,
-                 bint is_datetimelike) -> None:
+                 bint is_datetimelike,
+                 uint8_t[:, ::1] mask=None) -> None:
     """See group_cummin_max.__doc__"""
-    group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=False)
+    group_cummin_max(
+        out,
+        values,
+        mask,
+        labels,
+        ngroups,
+        is_datetimelike,
+        compute_max=False
+    )
 
 
 @cython.boundscheck(False)
@@ -1358,6 +1392,15 @@ def group_cummax(groupby_t[:, ::1] out,
                  ndarray[groupby_t, ndim=2] values,
                  const intp_t[:] labels,
                  int ngroups,
-                 bint is_datetimelike) -> None:
+                 bint is_datetimelike,
+                 uint8_t[:, ::1] mask=None) -> None:
     """See group_cummin_max.__doc__"""
-    group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=True)
+    group_cummin_max(
+        out,
+        values,
+        mask,
+        labels,
+        ngroups,
+        is_datetimelike,
+        compute_max=True
+    )
pandas/core/groupby/ops.py (+71 -5)

@@ -65,6 +65,10 @@
 )
 
 from pandas.core.arrays import ExtensionArray
+from pandas.core.arrays.masked import (
+    BaseMaskedArray,
+    BaseMaskedDtype,
+)
 import pandas.core.common as com
 from pandas.core.frame import DataFrame
 from pandas.core.generic import NDFrame
@@ -124,6 +128,8 @@ def __init__(self, kind: str, how: str):
         },
     }
 
+    _MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax"}
+
     _cython_arity = {"ohlc": 4}  # OHLC
 
     # Note: we make this a classmethod and pass kind+how so that caching
@@ -256,6 +262,9 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
             out_dtype = "object"
         return np.dtype(out_dtype)
 
+    def uses_mask(self) -> bool:
+        return self.how in self._MASKED_CYTHON_FUNCTIONS
+
 
 class BaseGrouper:
     """
@@ -619,9 +628,45 @@ def _ea_wrap_cython_operation(
                 f"function is not implemented for this dtype: {values.dtype}"
             )
 
+    @final
+    def _masked_ea_wrap_cython_operation(
+        self,
+        kind: str,
+        values: BaseMaskedArray,
+        how: str,
+        axis: int,
+        min_count: int = -1,
+        **kwargs,
+    ) -> BaseMaskedArray:
+        """
+        Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EA's
+        and cython algorithms which accept a mask.
+        """
+        orig_values = values
+
+        # Copy to ensure input and result masks don't end up shared
+        mask = values._mask.copy()
+        arr = values._data
+
+        res_values = self._cython_operation(
+            kind, arr, how, axis, min_count, mask=mask, **kwargs
+        )
+        dtype = maybe_cast_result_dtype(orig_values.dtype, how)
+        assert isinstance(dtype, BaseMaskedDtype)
+        cls = dtype.construct_array_type()
+
+        return cls(res_values.astype(dtype.type, copy=False), mask)
+
     @final
     def _cython_operation(
-        self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
+        self,
+        kind: str,
+        values,
+        how: str,
+        axis: int,
+        min_count: int = -1,
+        mask: np.ndarray | None = None,
+        **kwargs,
     ) -> ArrayLike:
         """
         Returns the values of a cython operation.
@@ -645,10 +690,16 @@ def _cython_operation(
         # if not raise NotImplementedError
         cy_op.disallow_invalid_ops(dtype, is_numeric)
 
+        func_uses_mask = cy_op.uses_mask()
         if is_extension_array_dtype(dtype):
-            return self._ea_wrap_cython_operation(
-                kind, values, how, axis, min_count, **kwargs
-            )
+            if isinstance(values, BaseMaskedArray) and func_uses_mask:
+                return self._masked_ea_wrap_cython_operation(
+                    kind, values, how, axis, min_count, **kwargs
+                )
+            else:
+                return self._ea_wrap_cython_operation(
+                    kind, values, how, axis, min_count, **kwargs
+                )
 
         elif values.ndim == 1:
             # expand to 2d, dispatch, then squeeze if appropriate
@@ -659,6 +710,7 @@ def _cython_operation(
                 how=how,
                 axis=1,
                 min_count=min_count,
+                mask=mask,
                 **kwargs,
            )
            if res.shape[0] == 1:
@@ -688,6 +740,9 @@ def _cython_operation(
             assert axis == 1
             values = values.T
 
+        if mask is not None:
+            mask = mask.reshape(values.shape, order="C")
+
         out_shape = cy_op.get_output_shape(ngroups, values)
         func, values = cy_op.get_cython_func_and_vals(values, is_numeric)
         out_dtype = cy_op.get_out_dtype(values.dtype)
@@ -708,7 +763,18 @@ def _cython_operation(
             func(result, counts, values, comp_ids, min_count)
         elif kind == "transform":
             # TODO: min_count
-            func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
+            if func_uses_mask:
+                func(
+                    result,
+                    values,
+                    comp_ids,
+                    ngroups,
+                    is_datetimelike,
+                    mask=mask,
+                    **kwargs,
+                )
+            else:
+                func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
 
         if kind == "aggregate":
             # i.e. counts is defined. Locations where count<min_count
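Outside the groupby machinery, the data flow of _masked_ea_wrap_cython_operation can be sketched with a nullable array directly. This illustrates only the decomposition and reassembly (using the private _data/_mask attributes the diff relies on), not the actual internal call sequence:

import numpy as np
import pandas as pd

arr = pd.array([1, None, 3], dtype="Int64")

# Decompose: a plain int64 ndarray plus a boolean mask of missing positions.
# The mask is copied so the result does not share it with the input.
data = arr._data
mask = arr._mask.copy()

# ... the Cython kernel would write its result into an int64 array while
# skipping (and never exposing) the positions where mask is True ...
res_values = data  # stand-in for the kernel output

# Reassemble as the same masked extension array type, reusing the mask
result = type(arr)(res_values.astype(np.int64, copy=False), mask)
print(result)  # roughly: <IntegerArray> [1, <NA>, 3], dtype: Int64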
