
PERF: Groupby.idxmax #52339


Closed
64 changes: 64 additions & 0 deletions pandas/_libs/groupby.pyx
@@ -1819,6 +1819,70 @@ def group_min(
)


@cython.wraparound(False)
@cython.boundscheck(False)
def group_argmin_argmax(
int64_t[:, ::1] out,
int64_t[::1] counts,
ndarray[numeric_t, ndim=2] values,
const intp_t[::1] labels,
bint is_datetimelike=False,
const uint8_t[:, ::1] mask=None,
str name="argmin",
Py_ssize_t min_count=-1,
uint8_t[:, ::1] result_mask=None # ignored for now!
):
cdef:
Py_ssize_t i, j, N, K, lab
numeric_t val
numeric_t[:, ::1] group_min_or_max
int64_t[:, ::1] nobs
bint uses_mask = mask is not None
bint isna_entry
bint compute_max = name == "argmax"

assert min_count == -1, "'min_count' only used in sum and prod"

# TODO(cython3):
# Instead of `labels.shape[0]` use `len(labels)`
if not len(values) == labels.shape[0]:
raise AssertionError("len(index) != len(labels)")

nobs = np.zeros((<object>out).shape, dtype=np.int64)

group_min_or_max = np.empty((<object>out).shape, dtype=(<object>values).dtype)
group_min_or_max[:] = _get_min_or_max(<numeric_t>0, compute_max, is_datetimelike)
out[:] = -1

N, K = (<object>values).shape

with nogil:
for i in range(N):
lab = labels[i]
if lab < 0:
continue

counts[lab] += 1
for j in range(K):
val = values[i, j]

if uses_mask:
isna_entry = mask[i, j]
else:
isna_entry = _treat_as_na(val, is_datetimelike)

if not isna_entry:
nobs[lab, j] += 1
if compute_max:
if val > group_min_or_max[lab, j]:
group_min_or_max[lab, j] = val
out[lab, j] = i
else:
if val < group_min_or_max[lab, j]:
group_min_or_max[lab, j] = val
out[lab, j] = i


@cython.boundscheck(False)
@cython.wraparound(False)
cdef group_cummin_max(
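For review, a rough pure-Python reference of what the new kernel computes may help; the helper name `group_argmax_ref` and the NaN-based missing-value handling are illustrative assumptions, not part of the patch. For every (group, column) pair it tracks the running maximum over non-NA entries and records the global row position of the winner, leaving -1 for groups with no valid rows; ties keep the first occurrence because the comparison is strict, matching the Cython loop above.

```python
import numpy as np

def group_argmax_ref(values: np.ndarray, labels: np.ndarray, ngroups: int) -> np.ndarray:
    """Slow reference equivalent of group_argmin_argmax with name="argmax".

    values: (N, K) float array, with NaN marking missing entries.
    labels: (N,) integer group ids; negative ids are skipped.
    Returns an (ngroups, K) int64 array of row positions, -1 for empty groups.
    """
    N, K = values.shape
    out = np.full((ngroups, K), -1, dtype=np.int64)
    best = np.full((ngroups, K), -np.inf)  # mirrors group_min_or_max
    for i in range(N):
        lab = labels[i]
        if lab < 0:
            continue
        for j in range(K):
            val = values[i, j]
            if not np.isnan(val) and val > best[lab, j]:
                best[lab, j] = val
                out[lab, j] = i
    return out
```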
71 changes: 71 additions & 0 deletions pandas/core/groupby/generic.py
@@ -1138,13 +1138,51 @@ def nsmallest(
def idxmin(
self, axis: Axis | lib.NoDefault = lib.no_default, skipna: bool = True
) -> Series:
if axis is lib.no_default:
alt = lambda x: Series(x).argmin(skipna=skipna)
with com.temp_setattr(self, "observed", True):
argmin = self._cython_agg_general("argmin", alt=alt, skipna=skipna)
Comment on lines +1143 to +1144

@rhshadrach (Member) commented on Jul 23, 2023:
There is a bit of an oddity in how missing observations are handled. When there are multiple groupings, we do not include the unobserved categories in e.g. grouper.result_index, and we fill in any unobserved ones in _wrap_aggregated_output. However, with a single grouping we do include the unobserved categories in grouper.result_index, so they are never filled in later. That makes this approach fail for some categorical cases.

I plan to look into never including the unobserved categories until _wrap_aggregated_output in the single-grouping case as well. If that can be made to work, this approach would too.
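As an aside, a small script illustrating the asymmetry described in the comment above (illustrative only; exact internals depend on the pandas version):

```python
import pandas as pd

cat = pd.Categorical(["a", "a", "b"], categories=["a", "b", "c"])
ser = pd.Series([1, 2, 3])

# Single grouping: the unobserved category "c" is part of the result index
# from the start, so _wrap_aggregated_output has nothing left to fill in.
print(ser.groupby(cat, observed=False).max())

# Two groupings: unobserved combinations are absent from the intermediate
# result and are only filled in when the output is wrapped.
print(ser.groupby([cat, cat], observed=False).max())
```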


obj = self._obj_with_exclusions
res_col = obj.index.take(argmin._values)
if argmin.ndim == 1:
argmin = argmin._constructor(
res_col, index=argmin.index, name=argmin.name
)
else:
argmin = argmin._constructor(
res_col, index=argmin.index, columns=argmin.columns
)
res = self._wrap_agged_manager(argmin._mgr)
out = self._wrap_aggregated_output(res)
return out

result = self._op_via_apply("idxmin", axis=axis, skipna=skipna)
return result.astype(self.obj.index.dtype) if result.empty else result

@doc(Series.idxmax.__doc__)
def idxmax(
self, axis: Axis | lib.NoDefault = lib.no_default, skipna: bool = True
) -> Series:
if axis is lib.no_default:
alt = lambda x: Series(x).argmax(skipna=skipna)
with com.temp_setattr(self, "observed", True):
argmax = self._cython_agg_general("argmax", alt=alt, skipna=skipna)

obj = self._obj_with_exclusions
res_col = obj.index.take(argmax._values)
if argmax.ndim == 1:
argmax = argmax._constructor(
res_col, index=argmax.index, name=argmax.name
)
else:
argmax = argmax._constructor(
res_col, index=argmax.index, columns=argmax.columns
)
res = self._wrap_agged_manager(argmax._mgr)
out = self._wrap_aggregated_output(res)
return out

result = self._op_via_apply("idxmax", axis=axis, skipna=skipna)
return result.astype(self.obj.index.dtype) if result.empty else result

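Both methods share the same fast-path shape: the Cython kernel returns positional argmin/argmax per group, and a single `Index.take` translates positions into labels. A toy equivalent, with the kernel output hard-coded as an assumption:

```python
import numpy as np
import pandas as pd

ser = pd.Series([3.0, 1.0, 2.0, 5.0], index=list("wxyz"))
key = np.array([0, 0, 1, 1])

# Hypothetical kernel output: global row positions of each group's maximum.
positions = np.array([0, 3])

# Positions become labels with one take on the original index.
print(ser.index.take(positions))         # Index(['w', 'z'], dtype='object')
print(ser.groupby(key).idxmax().values)  # ['w' 'z'], the same labels
```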
@@ -2042,6 +2080,22 @@ def idxmax(
Beef co2_emissions
dtype: object
"""
if self.axis == 0 and axis is lib.no_default:
alt = lambda x: Series(x).argmax(skipna=skipna)
with com.temp_setattr(self, "observed", True):
argmax = self._cython_agg_general(
"argmax", alt=alt, skipna=skipna, numeric_only=numeric_only
)

obj = self._obj_with_exclusions
for i in range(argmax.shape[1]):
res_col = obj.index.take(argmax.iloc[:, i]._values)
argmax.isetitem(i, res_col)

res = self._wrap_agged_manager(argmax._mgr)
out = self._wrap_aggregated_output(res)
return out

if axis is not lib.no_default:
if axis is None:
axis = self.axis
@@ -2057,6 +2111,7 @@ def func(df):
result = self._python_apply_general(
func, self._obj_with_exclusions, not_indexed_same=True
)

return result.astype(self.obj.index.dtype) if result.empty else result

def idxmin(
@@ -2137,6 +2192,22 @@ def idxmin(
Beef consumption
dtype: object
"""
if self.axis == 0 and axis is lib.no_default:
alt = lambda x: Series(x).argmin(skipna=skipna)
with com.temp_setattr(self, "observed", True):
argmin = self._cython_agg_general(
"argmin", alt=alt, skipna=skipna, numeric_only=numeric_only
)

obj = self._obj_with_exclusions
for i in range(argmin.shape[1]):
res_col = obj.index.take(argmin.iloc[:, i]._values)
argmin.isetitem(i, res_col)

res = self._wrap_agged_manager(argmin._mgr)
out = self._wrap_aggregated_output(res)
return out

if axis is not lib.no_default:
if axis is None:
axis = self.axis
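For DataFrames the same translation runs column by column: each integer column of the intermediate result is replaced in place via `isetitem` with labels taken from the original index. A sketch with a hypothetical kernel result:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [1, 4, 2], "y": [9, 3, 7]}, index=list("abc"))

# Hypothetical kernel output: row positions of each group's maximum,
# for the groups {a, b} and {c}.
arg = pd.DataFrame({"x": [1, 2], "y": [0, 2]})

# Per-column translation, mirroring the isetitem loop above.
for i in range(arg.shape[1]):
    arg.isetitem(i, df.index.take(arg.iloc[:, i].to_numpy()))

print(arg)                                       # x: b, c   y: a, c
print(df.groupby(np.array([0, 0, 1])).idxmax())  # same labels
```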
54 changes: 49 additions & 5 deletions pandas/core/groupby/ops.py
@@ -118,7 +118,7 @@ class WrappedCythonOp:
# Functions for which we do _not_ attempt to cast the cython result
# back to the original dtype.
cast_blocklist = frozenset(
["any", "all", "rank", "count", "size", "idxmin", "idxmax"]
["any", "all", "rank", "count", "size", "idxmin", "idxmax", "argmin", "argmax"]
)

def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
@@ -143,6 +143,8 @@ def __init__(self, kind: str, how: str, has_dropped_na: bool) -> None:
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
"argmin": functools.partial(libgroupby.group_argmin_argmax, name="argmin"),
"argmax": functools.partial(libgroupby.group_argmin_argmax, name="argmax"),
},
"transform": {
"cumprod": "group_cumprod",
@@ -180,6 +182,13 @@ def _get_cython_function(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
elif how in ["argmin", "argmax"]:
# We have a partial object that does not have __signatures__
# raise NotImplementedError here rather than TypeError later
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
elif how in ["std", "sem"]:
# We have a partial object that does not have __signatures__
return f
@@ -263,7 +272,17 @@ def _disallow_invalid_ops(self, dtype: DtypeObj):
# don't go down a group-by-group path, since in the empty-groups
# case that would fail to raise
raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
if how not in ["rank", "any", "all", "first", "last", "min", "max"]:
if how not in [
"rank",
"any",
"all",
"first",
"last",
"min",
"max",
"argmin",
"argmax",
]:
if self.kind == "transform":
raise TypeError(f"{dtype} type does not support {how} operations")
raise TypeError(f"{dtype} dtype does not support aggregation '{how}'")
@@ -323,7 +342,9 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
how = self.how

if how == "rank":
if how in ["argmax", "argmin"]:
out_dtype = "int64"
elif how == "rank":
out_dtype = "float64"
else:
if is_numeric_dtype(dtype):
@@ -381,7 +402,17 @@ def _ea_wrap_cython_operation(
)

elif isinstance(values, Categorical):
assert self.how in ["rank", "any", "all", "first", "last", "min", "max"]
assert self.how in [
"rank",
"any",
"all",
"first",
"last",
"min",
"max",
"argmin",
"argmax",
]
mask = values.isna()
if self.how == "rank":
assert values.ordered # checked earlier
@@ -613,7 +644,16 @@ def _call_cython_op(
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
if self.kind == "aggregate":
counts = np.zeros(ngroups, dtype=np.int64)
if self.how in ["min", "max", "mean", "last", "first", "sum"]:
if self.how in [
"min",
"max",
"mean",
"last",
"first",
"sum",
"argmin",
"argmax",
]:
func(
out=result,
counts=counts,
@@ -686,6 +726,10 @@ def _call_cython_op(
cutoff = max(0 if self.how in ["sum", "prod"] else 1, min_count)
empty_groups = counts < cutoff
if empty_groups.any():
if self.how in ["argmin", "argmax"]:
raise ValueError(
f"attempt to get {self.how} of an empty sequence"
)
if result_mask is not None:
assert result_mask[empty_groups].all()
else:
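The ValueError added for empty groups in `_call_cython_op` reuses NumPy's wording for an empty reduction, so the groupby path fails the same way a plain array does:

```python
import numpy as np

try:
    np.array([]).argmax()
except ValueError as err:
    print(err)  # attempt to get argmax of an empty sequence
```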
6 changes: 6 additions & 0 deletions pandas/tests/groupby/test_categorical.py
@@ -1860,6 +1860,12 @@ def test_category_order_reducer(
op_result = getattr(gb, reduction_func)(*args)
if as_index:
result = op_result.index.get_level_values("a").categories
elif reduction_func in ["idxmax", "idxmin"]:
# We don't expect to get Categorical back
exp = {
key: getattr(gb.get_group(key), reduction_func)(*args) for key in gb.groups
}
result = pd.concat(exp, axis=1)
else:
result = op_result["a"].cat.categories
expected = Index([1, 4, 3, 2])
1 change: 1 addition & 0 deletions pandas/tests/groupby/test_groupby_dropna.py
@@ -565,6 +565,7 @@ def test_categorical_reducers(
else:
values = [(np.nan, np.nan) if e == (4, 4) else e for e in values]
expected["y"] = values
expected["y"] = expected["y"].astype(df.index.dtype)
if reduction_func == "size":
# size, unlike other methods, has the desired behavior in GH#49519
expected = expected.rename(columns={0: "size"})
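The added astype keeps the expected frame in line with what idxmin/idxmax now return: labels that preserve the dtype of the original index. An illustrative check, not taken from the test itself:

```python
import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0]}, index=pd.Index([10, 20], dtype="int64"))
out = df.groupby([0, 0]).idxmax()
assert out["x"].dtype == df.index.dtype  # labels keep the index dtype
```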
10 changes: 8 additions & 2 deletions pandas/tests/groupby/test_raises.py
@@ -143,8 +143,14 @@ def test_groupby_raises_string(
"ffill": (None, ""),
"fillna": (None, ""),
"first": (None, ""),
"idxmax": (TypeError, "'argmax' not allowed for this dtype"),
"idxmin": (TypeError, "'argmin' not allowed for this dtype"),
"idxmax": (
TypeError,
"reduction operation 'argmax' not allowed for this dtype",
),
"idxmin": (
TypeError,
"reduction operation 'argmin' not allowed for this dtype",
),
"last": (None, ""),
"max": (None, ""),
"mean": (TypeError, "Could not convert xy?z?w?t?y?u?i?o? to numeric"),
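The longer messages match what the underlying argmin/argmax reduction raises for object dtype. A minimal reproducer, assuming the behavior the test above asserts:

```python
import pandas as pd

df = pd.DataFrame({"key": [1, 1], "val": ["x", "y"]})

# Expected to raise
# TypeError: reduction operation 'argmax' not allowed for this dtype
df.groupby("key")["val"].idxmax()
```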