Skip to content

Commit ccca5df

Browse files
authored
PERF: Implement groupby idxmax/idxmin in Cython (#54234)
* PERF: Implement groupby idxmax/idxmin in Cython * Update docs * Add ASVs * mypy fixup * Refinements * Revert * Rework * Refinements * fixup * Fixup, show stderr in ASVs * Remove idxmin/idxmax from numba ASVs * WIP * WIP * Rework * Rework * fixup * whatsnew * refinements * cleanup * fixup type-hints in groupby.pyi * Use mask instead of sentinel * fixup * fixup * fixup * seen -> unobserved; add assert * Rework * cleanup * Fixup * fixup * Refinements * fixup * fixup * WIP * Avoid _maybe_mask_result * Add assert
1 parent ff5cae7 commit ccca5df

File tree

14 files changed

+299
-54
lines changed

14 files changed

+299
-54
lines changed

asv_bench/benchmarks/groupby.py

+4
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@
7373
"ffill",
7474
"first",
7575
"head",
76+
"idxmax",
77+
"idxmin",
7678
"last",
7779
"median",
7880
"nunique",
@@ -588,6 +590,8 @@ class GroupByCythonAgg:
588590
"prod",
589591
"min",
590592
"max",
593+
"idxmin",
594+
"idxmax",
591595
"mean",
592596
"median",
593597
"var",

doc/source/user_guide/groupby.rst

+2-2
Original file line numberDiff line numberDiff line change
@@ -517,8 +517,8 @@ listed below, those with a ``*`` do *not* have a Cython-optimized implementation
517517
:meth:`~.DataFrameGroupBy.count`;Compute the number of non-NA values in the groups
518518
:meth:`~.DataFrameGroupBy.cov` * ;Compute the covariance of the groups
519519
:meth:`~.DataFrameGroupBy.first`;Compute the first occurring value in each group
520-
:meth:`~.DataFrameGroupBy.idxmax` *;Compute the index of the maximum value in each group
521-
:meth:`~.DataFrameGroupBy.idxmin` *;Compute the index of the minimum value in each group
520+
:meth:`~.DataFrameGroupBy.idxmax`;Compute the index of the maximum value in each group
521+
:meth:`~.DataFrameGroupBy.idxmin`;Compute the index of the minimum value in each group
522522
:meth:`~.DataFrameGroupBy.last`;Compute the last occurring value in each group
523523
:meth:`~.DataFrameGroupBy.max`;Compute the maximum value in each group
524524
:meth:`~.DataFrameGroupBy.mean`;Compute the mean of each group

doc/source/whatsnew/v2.2.0.rst

+3-1
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,7 @@ Performance improvements
302302
- Performance improvement in :meth:`DataFrame.sort_index` and :meth:`Series.sort_index` when indexed by a :class:`MultiIndex` (:issue:`54835`)
303303
- Performance improvement in :meth:`Index.difference` (:issue:`55108`)
304304
- Performance improvement in :meth:`Series.duplicated` for pyarrow dtypes (:issue:`55255`)
305+
- Performance improvement in :meth:`SeriesGroupBy.idxmax`, :meth:`SeriesGroupBy.idxmin`, :meth:`DataFrameGroupBy.idxmax`, :meth:`DataFrameGroupBy.idxmin` (:issue:`54234`)
305306
- Performance improvement when indexing with more than 4 keys (:issue:`54550`)
306307
- Performance improvement when localizing time to UTC (:issue:`55241`)
307308

@@ -403,10 +404,11 @@ Plotting
403404
Groupby/resample/rolling
404405
^^^^^^^^^^^^^^^^^^^^^^^^
405406
- Bug in :class:`.Rolling` where duplicate datetimelike indexes are treated as consecutive rather than equal with ``closed='left'`` and ``closed='neither'`` (:issue:`20712`)
407+
- Bug in :meth:`.DataFrameGroupBy.idxmin`, :meth:`.DataFrameGroupBy.idxmax`, :meth:`.SeriesGroupBy.idxmin`, and :meth:`.SeriesGroupBy.idxmax` would not retain :class:`.Categorical` dtype when the index was a :class:`.CategoricalIndex` that contained NA values (:issue:`54234`)
408+
- Bug in :meth:`.DataFrameGroupBy.transform` and :meth:`.SeriesGroupBy.transform` when ``observed=False`` and ``f="idxmin"`` or ``f="idxmax"`` would incorrectly raise on unobserved categories (:issue:`54234`)
406409
- Bug in :meth:`DataFrame.resample` not respecting ``closed`` and ``label`` arguments for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55282`)
407410
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.BusinessDay` (:issue:`55281`)
408411
- Bug in :meth:`DataFrame.resample` where bin edges were not correct for :class:`~pandas.tseries.offsets.MonthBegin` (:issue:`55271`)
409-
-
410412

411413
Reshaping
412414
^^^^^^^^^

pandas/_libs/groupby.pyi

+12
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@ def group_min(
181181
mask: np.ndarray | None = ...,
182182
result_mask: np.ndarray | None = ...,
183183
) -> None: ...
184+
def group_idxmin_idxmax(
185+
out: npt.NDArray[np.intp],
186+
counts: npt.NDArray[np.int64],
187+
values: np.ndarray, # ndarray[groupby_t, ndim=2]
188+
labels: npt.NDArray[np.intp],
189+
min_count: int = ...,
190+
is_datetimelike: bool = ...,
191+
mask: np.ndarray | None = ...,
192+
name: str = ...,
193+
skipna: bool = ...,
194+
result_mask: np.ndarray | None = ...,
195+
) -> None: ...
184196
def group_cummin(
185197
out: np.ndarray, # groupby_t[:, ::1]
186198
values: np.ndarray, # ndarray[groupby_t, ndim=2]

pandas/_libs/groupby.pyx

+145
Original file line numberDiff line numberDiff line change
@@ -1794,6 +1794,151 @@ cdef group_min_max(
17941794
)
17951795

17961796

1797+
@cython.wraparound(False)
1798+
@cython.boundscheck(False)
1799+
def group_idxmin_idxmax(
1800+
intp_t[:, ::1] out,
1801+
int64_t[::1] counts,
1802+
ndarray[numeric_object_t, ndim=2] values,
1803+
const intp_t[::1] labels,
1804+
Py_ssize_t min_count=-1,
1805+
bint is_datetimelike=False,
1806+
const uint8_t[:, ::1] mask=None,
1807+
str name="idxmin",
1808+
bint skipna=True,
1809+
uint8_t[:, ::1] result_mask=None,
1810+
):
1811+
"""
1812+
Compute index of minimum/maximum of columns of `values`, in row groups `labels`.
1813+
1814+
This function only computes the row number where the minimum/maximum occurs, we'll
1815+
take the corresponding index value after this function.
1816+
1817+
Parameters
1818+
----------
1819+
out : np.ndarray[intp, ndim=2]
1820+
Array to store result in.
1821+
counts : np.ndarray[int64]
1822+
Input as a zeroed array, populated by group sizes during algorithm
1823+
values : np.ndarray[numeric_object_t, ndim=2]
1824+
Values to find column-wise min/max of.
1825+
labels : np.ndarray[np.intp]
1826+
Labels to group by.
1827+
min_count : Py_ssize_t, default -1
1828+
The minimum number of non-NA group elements, NA result if threshold
1829+
is not met.
1830+
is_datetimelike : bool
1831+
True if `values` contains datetime-like entries.
1832+
name : {"idxmin", "idxmax"}, default "idxmin"
1833+
Whether to compute idxmin or idxmax.
1834+
mask : ndarray[bool, ndim=2], optional
1835+
If not None, indices represent missing values,
1836+
otherwise the mask will not be used
1837+
skipna : bool, default True
1838+
Flag to ignore nan values during truth testing
1839+
result_mask : ndarray[bool, ndim=2], optional
1840+
If not None, these specify locations in the output that are NA.
1841+
Modified in-place.
1842+
1843+
Notes
1844+
-----
1845+
This method modifies the `out` parameter, rather than returning an object.
1846+
`counts` is modified to hold group sizes
1847+
"""
1848+
cdef:
1849+
Py_ssize_t i, j, N, K, lab
1850+
numeric_object_t val
1851+
numeric_object_t[:, ::1] group_min_or_max
1852+
bint uses_mask = mask is not None
1853+
bint isna_entry
1854+
bint compute_max = name == "idxmax"
1855+
1856+
assert name == "idxmin" or name == "idxmax"
1857+
1858+
# TODO(cython3):
1859+
# Instead of `labels.shape[0]` use `len(labels)`
1860+
if not len(values) == labels.shape[0]:
1861+
raise AssertionError("len(index) != len(labels)")
1862+
1863+
N, K = (<object>values).shape
1864+
1865+
if numeric_object_t is object:
1866+
group_min_or_max = np.empty((<object>out).shape, dtype=object)
1867+
else:
1868+
group_min_or_max = np.empty_like(out, dtype=values.dtype)
1869+
if N > 0 and K > 0:
1870+
# When N or K is zero, we never use group_min_or_max
1871+
group_min_or_max[:] = _get_min_or_max(
1872+
values[0, 0], compute_max, is_datetimelike
1873+
)
1874+
1875+
# When using transform, we need a valid value for take in the case
1876+
# a category is not observed; these values will be dropped
1877+
out[:] = 0
1878+
1879+
# TODO(cython3): De-duplicate once conditional-nogil is available
1880+
if numeric_object_t is object:
1881+
for i in range(N):
1882+
lab = labels[i]
1883+
if lab < 0:
1884+
continue
1885+
1886+
for j in range(K):
1887+
if not skipna and out[lab, j] == -1:
1888+
# Once we've hit NA there is no going back
1889+
continue
1890+
val = values[i, j]
1891+
1892+
if uses_mask:
1893+
isna_entry = mask[i, j]
1894+
else:
1895+
# TODO(cython3): use _treat_as_na here
1896+
isna_entry = checknull(val)
1897+
1898+
if isna_entry:
1899+
if not skipna:
1900+
out[lab, j] = -1
1901+
else:
1902+
if compute_max:
1903+
if val > group_min_or_max[lab, j]:
1904+
group_min_or_max[lab, j] = val
1905+
out[lab, j] = i
1906+
else:
1907+
if val < group_min_or_max[lab, j]:
1908+
group_min_or_max[lab, j] = val
1909+
out[lab, j] = i
1910+
else:
1911+
with nogil:
1912+
for i in range(N):
1913+
lab = labels[i]
1914+
if lab < 0:
1915+
continue
1916+
1917+
for j in range(K):
1918+
if not skipna and out[lab, j] == -1:
1919+
# Once we've hit NA there is no going back
1920+
continue
1921+
val = values[i, j]
1922+
1923+
if uses_mask:
1924+
isna_entry = mask[i, j]
1925+
else:
1926+
isna_entry = _treat_as_na(val, is_datetimelike)
1927+
1928+
if isna_entry:
1929+
if not skipna:
1930+
out[lab, j] = -1
1931+
else:
1932+
if compute_max:
1933+
if val > group_min_or_max[lab, j]:
1934+
group_min_or_max[lab, j] = val
1935+
out[lab, j] = i
1936+
else:
1937+
if val < group_min_or_max[lab, j]:
1938+
group_min_or_max[lab, j] = val
1939+
out[lab, j] = i
1940+
1941+
17971942
@cython.wraparound(False)
17981943
@cython.boundscheck(False)
17991944
def group_max(

pandas/core/arrays/categorical.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -2701,12 +2701,22 @@ def _groupby_op(
27012701
dtype = self.dtype
27022702
if how in ["sum", "prod", "cumsum", "cumprod", "skew"]:
27032703
raise TypeError(f"{dtype} type does not support {how} operations")
2704-
if how in ["min", "max", "rank"] and not dtype.ordered:
2704+
if how in ["min", "max", "rank", "idxmin", "idxmax"] and not dtype.ordered:
27052705
# raise TypeError instead of NotImplementedError to ensure we
27062706
# don't go down a group-by-group path, since in the empty-groups
27072707
# case that would fail to raise
27082708
raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
2709-
if how not in ["rank", "any", "all", "first", "last", "min", "max"]:
2709+
if how not in [
2710+
"rank",
2711+
"any",
2712+
"all",
2713+
"first",
2714+
"last",
2715+
"min",
2716+
"max",
2717+
"idxmin",
2718+
"idxmax",
2719+
]:
27102720
if kind == "transform":
27112721
raise TypeError(f"{dtype} type does not support {how} operations")
27122722
raise TypeError(f"{dtype} dtype does not support aggregation '{how}'")
@@ -2716,7 +2726,7 @@ def _groupby_op(
27162726
if how == "rank":
27172727
assert self.ordered # checked earlier
27182728
npvalues = self._ndarray
2719-
elif how in ["first", "last", "min", "max"]:
2729+
elif how in ["first", "last", "min", "max", "idxmin", "idxmax"]:
27202730
npvalues = self._ndarray
27212731
result_mask = np.zeros(ngroups, dtype=bool)
27222732
else:

pandas/core/arrays/masked.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -1506,9 +1506,13 @@ def _groupby_op(
15061506
arity = op._cython_arity.get(op.how, 1)
15071507
result_mask = np.tile(result_mask, (arity, 1)).T
15081508

1509-
# res_values should already have the correct dtype, we just need to
1510-
# wrap in a MaskedArray
1511-
return self._maybe_mask_result(res_values, result_mask)
1509+
if op.how in ["idxmin", "idxmax"]:
1510+
# Result values are indexes to take, keep as ndarray
1511+
return res_values
1512+
else:
1513+
# res_values should already have the correct dtype, we just need to
1514+
# wrap in a MaskedArray
1515+
return self._maybe_mask_result(res_values, result_mask)
15121516

15131517

15141518
def transpose_homogeneous_masked_arrays(

pandas/core/frame.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@
6868
deprecate_nonkeyword_arguments,
6969
doc,
7070
)
71-
from pandas.util._exceptions import find_stack_level
71+
from pandas.util._exceptions import (
72+
find_stack_level,
73+
rewrite_warning,
74+
)
7275
from pandas.util._validators import (
7376
validate_ascending,
7477
validate_bool_kwarg,
@@ -11369,7 +11372,20 @@ def _get_data() -> DataFrame:
1136911372
row_index = np.tile(np.arange(nrows), ncols)
1137011373
col_index = np.repeat(np.arange(ncols), nrows)
1137111374
ser = Series(arr, index=col_index, copy=False)
11372-
result = ser.groupby(row_index).agg(name, **kwds)
11375+
# GroupBy will raise a warning with SeriesGroupBy as the object,
11376+
# likely confusing users
11377+
with rewrite_warning(
11378+
target_message=(
11379+
f"The behavior of SeriesGroupBy.{name} with all-NA values"
11380+
),
11381+
target_category=FutureWarning,
11382+
new_message=(
11383+
f"The behavior of {type(self).__name__}.{name} with all-NA "
11384+
"values, or any-NA and skipna=False, is deprecated. In "
11385+
"a future version this will raise ValueError"
11386+
),
11387+
):
11388+
result = ser.groupby(row_index).agg(name, **kwds)
1137311389
result.index = df.index
1137411390
if not skipna and name not in ("any", "all"):
1137511391
mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1)

0 commit comments

Comments (0)