Skip to content

Commit 4a072fa

Browse files
authored
ENH: Add numeric_only to certain groupby ops (#46728)
1 parent d5fcb40 commit 4a072fa

File tree

9 files changed

+182
-46
lines changed

9 files changed

+182
-46
lines changed

doc/source/whatsnew/v1.5.0.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ Other enhancements
120120
- :meth:`DataFrame.reset_index` now accepts a ``names`` argument which renames the index names (:issue:`6878`)
121121
- :meth:`pd.concat` now raises when ``levels`` is given but ``keys`` is None (:issue:`46653`)
122122
- :meth:`pd.concat` now raises when ``levels`` contains duplicate values (:issue:`46653`)
123-
- Added ``numeric_only`` argument to :meth:`DataFrame.corr`, :meth:`DataFrame.corrwith`, and :meth:`DataFrame.cov` (:issue:`46560`)
123+
- Added ``numeric_only`` argument to :meth:`DataFrame.corr`, :meth:`DataFrame.corrwith`, :meth:`DataFrame.cov`, :meth:`DataFrame.idxmin`, :meth:`DataFrame.idxmax`, :meth:`.GroupBy.idxmin`, :meth:`.GroupBy.idxmax`, :meth:`.GroupBy.var`, :meth:`.GroupBy.std`, :meth:`.GroupBy.sem`, and :meth:`.GroupBy.quantile` (:issue:`46560`)
124124
- A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`, :issue:`46725`)
125125
- Added ``validate`` argument to :meth:`DataFrame.join` (:issue:`46622`)
126126
- A :class:`errors.PerformanceWarning` is now thrown when using ``string[pyarrow]`` dtype with methods that don't dispatch to ``pyarrow.compute`` methods (:issue:`42613`)

pandas/core/frame.py

+23-10
Original file line numberDiff line numberDiff line change
@@ -10605,11 +10605,17 @@ def nunique(self, axis: Axis = 0, dropna: bool = True) -> Series:
1060510605
"""
1060610606
return self.apply(Series.nunique, axis=axis, dropna=dropna)
1060710607

10608-
@doc(_shared_docs["idxmin"])
10609-
def idxmin(self, axis: Axis = 0, skipna: bool = True) -> Series:
10608+
@doc(_shared_docs["idxmin"], numeric_only_default="False")
10609+
def idxmin(
10610+
self, axis: Axis = 0, skipna: bool = True, numeric_only: bool = False
10611+
) -> Series:
1061010612
axis = self._get_axis_number(axis)
10613+
if numeric_only:
10614+
data = self._get_numeric_data()
10615+
else:
10616+
data = self
1061110617

10612-
res = self._reduce(
10618+
res = data._reduce(
1061310619
nanops.nanargmin, "argmin", axis=axis, skipna=skipna, numeric_only=False
1061410620
)
1061510621
indices = res._values
@@ -10619,15 +10625,22 @@ def idxmin(self, axis: Axis = 0, skipna: bool = True) -> Series:
1061910625
# error: Item "int" of "Union[int, Any]" has no attribute "__iter__"
1062010626
assert isinstance(indices, np.ndarray) # for mypy
1062110627

10622-
index = self._get_axis(axis)
10628+
index = data._get_axis(axis)
1062310629
result = [index[i] if i >= 0 else np.nan for i in indices]
10624-
return self._constructor_sliced(result, index=self._get_agg_axis(axis))
10630+
return data._constructor_sliced(result, index=data._get_agg_axis(axis))
10631+
10632+
@doc(_shared_docs["idxmax"], numeric_only_default="False")
10633+
def idxmax(
10634+
self, axis: Axis = 0, skipna: bool = True, numeric_only: bool = False
10635+
) -> Series:
1062510636

10626-
@doc(_shared_docs["idxmax"])
10627-
def idxmax(self, axis: Axis = 0, skipna: bool = True) -> Series:
1062810637
axis = self._get_axis_number(axis)
10638+
if numeric_only:
10639+
data = self._get_numeric_data()
10640+
else:
10641+
data = self
1062910642

10630-
res = self._reduce(
10643+
res = data._reduce(
1063110644
nanops.nanargmax, "argmax", axis=axis, skipna=skipna, numeric_only=False
1063210645
)
1063310646
indices = res._values
@@ -10637,9 +10650,9 @@ def idxmax(self, axis: Axis = 0, skipna: bool = True) -> Series:
1063710650
# error: Item "int" of "Union[int, Any]" has no attribute "__iter__"
1063810651
assert isinstance(indices, np.ndarray) # for mypy
1063910652

10640-
index = self._get_axis(axis)
10653+
index = data._get_axis(axis)
1064110654
result = [index[i] if i >= 0 else np.nan for i in indices]
10642-
return self._constructor_sliced(result, index=self._get_agg_axis(axis))
10655+
return data._constructor_sliced(result, index=data._get_agg_axis(axis))
1064310656

1064410657
def _get_agg_axis(self, axis_num: int) -> Index:
1064510658
"""

pandas/core/groupby/generic.py

+15-7
Original file line numberDiff line numberDiff line change
@@ -1555,10 +1555,14 @@ def nunique(self, dropna: bool = True) -> DataFrame:
15551555

15561556
return results
15571557

1558-
@doc(_shared_docs["idxmax"])
1559-
def idxmax(self, axis=0, skipna: bool = True):
1558+
@doc(
1559+
_shared_docs["idxmax"],
1560+
numeric_only_default="True for axis=0, False for axis=1",
1561+
)
1562+
def idxmax(self, axis=0, skipna: bool = True, numeric_only: bool | None = None):
15601563
axis = DataFrame._get_axis_number(axis)
1561-
numeric_only = None if axis == 0 else False
1564+
if numeric_only is None:
1565+
numeric_only = None if axis == 0 else False
15621566

15631567
def func(df):
15641568
# NB: here we use numeric_only=None, in DataFrame it is False GH#38217
@@ -1577,13 +1581,17 @@ def func(df):
15771581
func.__name__ = "idxmax"
15781582
return self._python_apply_general(func, self._obj_with_exclusions)
15791583

1580-
@doc(_shared_docs["idxmin"])
1581-
def idxmin(self, axis=0, skipna: bool = True):
1584+
@doc(
1585+
_shared_docs["idxmin"],
1586+
numeric_only_default="True for axis=0, False for axis=1",
1587+
)
1588+
def idxmin(self, axis=0, skipna: bool = True, numeric_only: bool | None = None):
15821589
axis = DataFrame._get_axis_number(axis)
1583-
numeric_only = None if axis == 0 else False
1590+
if numeric_only is None:
1591+
numeric_only = None if axis == 0 else False
15841592

15851593
def func(df):
1586-
# NB: here we use numeric_only=None, in DataFrame it is False GH#38217
1594+
# NB: here we use numeric_only=None, in DataFrame it is False GH#46560
15871595
res = df._reduce(
15881596
nanops.nanargmin,
15891597
"argmin",

pandas/core/groupby/groupby.py

+57-13
Original file line numberDiff line numberDiff line change
@@ -1502,7 +1502,7 @@ def _python_apply_general(
15021502
)
15031503

15041504
@final
1505-
def _python_agg_general(self, func, *args, **kwargs):
1505+
def _python_agg_general(self, func, *args, raise_on_typeerror=False, **kwargs):
15061506
func = com.is_builtin_func(func)
15071507
f = lambda x: func(x, *args, **kwargs)
15081508

@@ -1520,6 +1520,8 @@ def _python_agg_general(self, func, *args, **kwargs):
15201520
# if this function is invalid for this dtype, we will ignore it, unless raise_on_typeerror is set, in which case the TypeError is re-raised.
15211521
result = self.grouper.agg_series(obj, f)
15221522
except TypeError:
1523+
if raise_on_typeerror:
1524+
raise
15231525
warn_dropping_nuisance_columns_deprecated(type(self), "agg")
15241526
continue
15251527

@@ -1593,7 +1595,12 @@ def _agg_py_fallback(
15931595

15941596
@final
15951597
def _cython_agg_general(
1596-
self, how: str, alt: Callable, numeric_only: bool, min_count: int = -1
1598+
self,
1599+
how: str,
1600+
alt: Callable,
1601+
numeric_only: bool,
1602+
min_count: int = -1,
1603+
ignore_failures: bool = True,
15971604
):
15981605
# Note: we never get here with how="ohlc" for DataFrameGroupBy;
15991606
# that goes through SeriesGroupBy
@@ -1629,7 +1636,7 @@ def array_func(values: ArrayLike) -> ArrayLike:
16291636

16301637
# TypeError -> we may have an exception in trying to aggregate
16311638
# continue and exclude the block, if ignore_failures is True
1632-
new_mgr = data.grouped_reduce(array_func, ignore_failures=True)
1639+
new_mgr = data.grouped_reduce(array_func, ignore_failures=ignore_failures)
16331640

16341641
if not is_ser and len(new_mgr) < len(data):
16351642
warn_dropping_nuisance_columns_deprecated(type(self), how)
@@ -2041,6 +2048,7 @@ def std(
20412048
ddof: int = 1,
20422049
engine: str | None = None,
20432050
engine_kwargs: dict[str, bool] | None = None,
2051+
numeric_only: bool | lib.NoDefault = lib.no_default,
20442052
):
20452053
"""
20462054
Compute standard deviation of groups, excluding missing values.
@@ -2069,6 +2077,11 @@ def std(
20692077
20702078
.. versionadded:: 1.4.0
20712079
2080+
numeric_only : bool, default True
2081+
Include only `float`, `int` or `boolean` data.
2082+
2083+
.. versionadded:: 1.5.0
2084+
20722085
Returns
20732086
-------
20742087
Series or DataFrame
@@ -2081,8 +2094,9 @@ def std(
20812094
else:
20822095
return self._get_cythonized_result(
20832096
libgroupby.group_var,
2084-
needs_counts=True,
20852097
cython_dtype=np.dtype(np.float64),
2098+
numeric_only=numeric_only,
2099+
needs_counts=True,
20862100
post_processing=lambda vals, inference: np.sqrt(vals),
20872101
ddof=ddof,
20882102
)
@@ -2095,6 +2109,7 @@ def var(
20952109
ddof: int = 1,
20962110
engine: str | None = None,
20972111
engine_kwargs: dict[str, bool] | None = None,
2112+
numeric_only: bool | lib.NoDefault = lib.no_default,
20982113
):
20992114
"""
21002115
Compute variance of groups, excluding missing values.
@@ -2123,6 +2138,11 @@ def var(
21232138
21242139
.. versionadded:: 1.4.0
21252140
2141+
numeric_only : bool, default True
2142+
Include only `float`, `int` or `boolean` data.
2143+
2144+
.. versionadded:: 1.5.0
2145+
21262146
Returns
21272147
-------
21282148
Series or DataFrame
@@ -2133,22 +2153,25 @@ def var(
21332153

21342154
return self._numba_agg_general(sliding_var, engine_kwargs, ddof)
21352155
else:
2156+
numeric_only_bool = self._resolve_numeric_only(numeric_only)
21362157
if ddof == 1:
2137-
numeric_only = self._resolve_numeric_only(lib.no_default)
21382158
return self._cython_agg_general(
21392159
"var",
21402160
alt=lambda x: Series(x).var(ddof=ddof),
2141-
numeric_only=numeric_only,
2161+
numeric_only=numeric_only_bool,
2162+
ignore_failures=numeric_only is lib.no_default,
21422163
)
21432164
else:
21442165
func = lambda x: x.var(ddof=ddof)
21452166
with self._group_selection_context():
2146-
return self._python_agg_general(func)
2167+
return self._python_agg_general(
2168+
func, raise_on_typeerror=not numeric_only_bool
2169+
)
21472170

21482171
@final
21492172
@Substitution(name="groupby")
21502173
@Appender(_common_see_also)
2151-
def sem(self, ddof: int = 1):
2174+
def sem(self, ddof: int = 1, numeric_only: bool | lib.NoDefault = lib.no_default):
21522175
"""
21532176
Compute standard error of the mean of groups, excluding missing values.
21542177
@@ -2159,12 +2182,17 @@ def sem(self, ddof: int = 1):
21592182
ddof : int, default 1
21602183
Degrees of freedom.
21612184
2185+
numeric_only : bool, default True
2186+
Include only `float`, `int` or `boolean` data.
2187+
2188+
.. versionadded:: 1.5.0
2189+
21622190
Returns
21632191
-------
21642192
Series or DataFrame
21652193
Standard error of the mean of values within each group.
21662194
"""
2167-
result = self.std(ddof=ddof)
2195+
result = self.std(ddof=ddof, numeric_only=numeric_only)
21682196
if result.ndim == 1:
21692197
result /= np.sqrt(self.count())
21702198
else:
@@ -2979,7 +3007,12 @@ def nth(
29793007
return result
29803008

29813009
@final
2982-
def quantile(self, q=0.5, interpolation: str = "linear"):
3010+
def quantile(
3011+
self,
3012+
q=0.5,
3013+
interpolation: str = "linear",
3014+
numeric_only: bool | lib.NoDefault = lib.no_default,
3015+
):
29833016
"""
29843017
Return group values at the given quantile, a la numpy.percentile.
29853018
@@ -2989,6 +3022,10 @@ def quantile(self, q=0.5, interpolation: str = "linear"):
29893022
Value(s) between 0 and 1 providing the quantile(s) to compute.
29903023
interpolation : {'linear', 'lower', 'higher', 'midpoint', 'nearest'}
29913024
Method to use when the desired quantile falls between two points.
3025+
numeric_only : bool, default True
3026+
Include only `float`, `int` or `boolean` data.
3027+
3028+
.. versionadded:: 1.5.0
29923029
29933030
Returns
29943031
-------
@@ -3013,6 +3050,7 @@ def quantile(self, q=0.5, interpolation: str = "linear"):
30133050
a 2.0
30143051
b 3.0
30153052
"""
3053+
numeric_only_bool = self._resolve_numeric_only(numeric_only)
30163054

30173055
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:
30183056
if is_object_dtype(vals):
@@ -3106,9 +3144,15 @@ def blk_func(values: ArrayLike) -> ArrayLike:
31063144
obj = self._obj_with_exclusions
31073145
is_ser = obj.ndim == 1
31083146
mgr = self._get_data_to_aggregate()
3109-
3110-
res_mgr = mgr.grouped_reduce(blk_func, ignore_failures=True)
3111-
if not is_ser and len(res_mgr.items) != len(mgr.items):
3147+
data = mgr.get_numeric_data() if numeric_only_bool else mgr
3148+
ignore_failures = numeric_only_bool
3149+
res_mgr = data.grouped_reduce(blk_func, ignore_failures=ignore_failures)
3150+
3151+
if (
3152+
numeric_only is lib.no_default
3153+
and not is_ser
3154+
and len(res_mgr.items) != len(mgr.items)
3155+
):
31123156
warn_dropping_nuisance_columns_deprecated(type(self), "quantile")
31133157

31143158
if len(res_mgr.items) == 0:

pandas/core/shared_docs.py

+8
Original file line numberDiff line numberDiff line change
@@ -749,6 +749,10 @@
749749
skipna : bool, default True
750750
Exclude NA/null values. If an entire row/column is NA, the result
751751
will be NA.
752+
numeric_only : bool, default {numeric_only_default}
753+
Include only `float`, `int` or `boolean` data.
754+
755+
.. versionadded:: 1.5.0
752756
753757
Returns
754758
-------
@@ -812,6 +816,10 @@
812816
skipna : bool, default True
813817
Exclude NA/null values. If an entire row/column is NA, the result
814818
will be NA.
819+
numeric_only : bool, default {numeric_only_default}
820+
Include only `float`, `int` or `boolean` data.
821+
822+
.. versionadded:: 1.5.0
815823
816824
Returns
817825
-------

pandas/tests/frame/test_reductions.py

+22
Original file line numberDiff line numberDiff line change
@@ -897,6 +897,17 @@ def test_idxmin(self, float_frame, int_frame, skipna, axis):
897897
expected = df.apply(Series.idxmin, axis=axis, skipna=skipna)
898898
tm.assert_series_equal(result, expected)
899899

900+
@pytest.mark.parametrize("numeric_only", [True, False])
901+
def test_idxmin_numeric_only(self, numeric_only):
902+
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
903+
if numeric_only:
904+
result = df.idxmin(numeric_only=numeric_only)
905+
expected = Series([2, 1], index=["a", "b"])
906+
tm.assert_series_equal(result, expected)
907+
else:
908+
with pytest.raises(TypeError, match="not allowed for this dtype"):
909+
df.idxmin(numeric_only=numeric_only)
910+
900911
def test_idxmin_axis_2(self, float_frame):
901912
frame = float_frame
902913
msg = "No axis named 2 for object type DataFrame"
@@ -914,6 +925,17 @@ def test_idxmax(self, float_frame, int_frame, skipna, axis):
914925
expected = df.apply(Series.idxmax, axis=axis, skipna=skipna)
915926
tm.assert_series_equal(result, expected)
916927

928+
@pytest.mark.parametrize("numeric_only", [True, False])
929+
def test_idxmax_numeric_only(self, numeric_only):
930+
df = DataFrame({"a": [2, 3, 1], "b": [2, 1, 1], "c": list("xyx")})
931+
if numeric_only:
932+
result = df.idxmax(numeric_only=numeric_only)
933+
expected = Series([1, 0], index=["a", "b"])
934+
tm.assert_series_equal(result, expected)
935+
else:
936+
with pytest.raises(TypeError, match="not allowed for this dtype"):
937+
df.idxmax(numeric_only=numeric_only)
938+
917939
def test_idxmax_axis_2(self, float_frame):
918940
frame = float_frame
919941
msg = "No axis named 2 for object type DataFrame"

pandas/tests/groupby/test_function.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -495,8 +495,9 @@ def test_groupby_non_arithmetic_agg_int_like_precision(i):
495495
("idxmax", {"c_int": [1, 3], "c_float": [0, 2], "c_date": [0, 3]}),
496496
],
497497
)
498+
@pytest.mark.parametrize("numeric_only", [True, False])
498499
@pytest.mark.filterwarnings("ignore:.*Select only valid:FutureWarning")
499-
def test_idxmin_idxmax_returns_int_types(func, values):
500+
def test_idxmin_idxmax_returns_int_types(func, values, numeric_only):
500501
# GH 25444
501502
df = DataFrame(
502503
{
@@ -513,12 +514,15 @@ def test_idxmin_idxmax_returns_int_types(func, values):
513514
df["c_Integer"] = df["c_int"].astype("Int64")
514515
df["c_Floating"] = df["c_float"].astype("Float64")
515516

516-
result = getattr(df.groupby("name"), func)()
517+
result = getattr(df.groupby("name"), func)(numeric_only=numeric_only)
517518

518519
expected = DataFrame(values, index=Index(["A", "B"], name="name"))
519-
expected["c_date_tz"] = expected["c_date"]
520-
expected["c_timedelta"] = expected["c_date"]
521-
expected["c_period"] = expected["c_date"]
520+
if numeric_only:
521+
expected = expected.drop(columns=["c_date"])
522+
else:
523+
expected["c_date_tz"] = expected["c_date"]
524+
expected["c_timedelta"] = expected["c_date"]
525+
expected["c_period"] = expected["c_date"]
522526
expected["c_Integer"] = expected["c_int"]
523527
expected["c_Floating"] = expected["c_float"]
524528

0 commit comments

Comments
 (0)