Skip to content

Commit b1688ec

Browse files
rhshadrach authored and yehoshuadimarsky committed
BUG: Fix issues with numeric_only deprecation (pandas-dev#47481)
* BUG: Fix issues with numeric_only deprecation
* improve test
* Change raising to FutureWarning
* revert deprecation
* Test improvement
1 parent 7e4770f commit b1688ec

File tree

3 files changed

+121
-27
lines changed

3 files changed

+121
-27
lines changed

pandas/core/groupby/generic.py

+28-22
Original file line numberDiff line numberDiff line change
@@ -1610,17 +1610,20 @@ def idxmax(
16101610
numeric_only_arg = numeric_only
16111611

16121612
def func(df):
1613-
res = df._reduce(
1614-
nanops.nanargmax,
1615-
"argmax",
1616-
axis=axis,
1617-
skipna=skipna,
1618-
numeric_only=numeric_only_arg,
1619-
)
1620-
indices = res._values
1621-
index = df._get_axis(axis)
1622-
result = [index[i] if i >= 0 else np.nan for i in indices]
1623-
return df._constructor_sliced(result, index=res.index)
1613+
with warnings.catch_warnings():
1614+
# Suppress numeric_only warnings here, will warn below
1615+
warnings.filterwarnings("ignore", ".*numeric_only in DataFrame.argmax")
1616+
res = df._reduce(
1617+
nanops.nanargmax,
1618+
"argmax",
1619+
axis=axis,
1620+
skipna=skipna,
1621+
numeric_only=numeric_only_arg,
1622+
)
1623+
indices = res._values
1624+
index = df._get_axis(axis)
1625+
result = [index[i] if i >= 0 else np.nan for i in indices]
1626+
return df._constructor_sliced(result, index=res.index)
16241627

16251628
func.__name__ = "idxmax"
16261629
result = self._python_apply_general(func, self._obj_with_exclusions)
@@ -1646,17 +1649,20 @@ def idxmin(
16461649
numeric_only_arg = numeric_only
16471650

16481651
def func(df):
1649-
res = df._reduce(
1650-
nanops.nanargmin,
1651-
"argmin",
1652-
axis=axis,
1653-
skipna=skipna,
1654-
numeric_only=numeric_only_arg,
1655-
)
1656-
indices = res._values
1657-
index = df._get_axis(axis)
1658-
result = [index[i] if i >= 0 else np.nan for i in indices]
1659-
return df._constructor_sliced(result, index=res.index)
1652+
with warnings.catch_warnings():
1653+
# Suppress numeric_only warnings here, will warn below
1654+
warnings.filterwarnings("ignore", ".*numeric_only in DataFrame.argmin")
1655+
res = df._reduce(
1656+
nanops.nanargmin,
1657+
"argmin",
1658+
axis=axis,
1659+
skipna=skipna,
1660+
numeric_only=numeric_only_arg,
1661+
)
1662+
indices = res._values
1663+
index = df._get_axis(axis)
1664+
result = [index[i] if i >= 0 else np.nan for i in indices]
1665+
return df._constructor_sliced(result, index=res.index)
16601666

16611667
func.__name__ = "idxmin"
16621668
result = self._python_apply_general(func, self._obj_with_exclusions)

pandas/core/groupby/groupby.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -1716,8 +1716,9 @@ def _cython_agg_general(
17161716
kwd_name = "numeric_only"
17171717
if how in ["any", "all"]:
17181718
kwd_name = "bool_only"
1719+
kernel = "sum" if how == "add" else how
17191720
raise NotImplementedError(
1720-
f"{type(self).__name__}.{how} does not implement {kwd_name}."
1721+
f"{type(self).__name__}.{kernel} does not implement {kwd_name}."
17211722
)
17221723
elif not is_ser:
17231724
data = data.get_numeric_data(copy=False)
@@ -2194,10 +2195,16 @@ def std(
21942195

21952196
return np.sqrt(self._numba_agg_general(sliding_var, engine_kwargs, ddof))
21962197
else:
2198+
# Resolve numeric_only so that var doesn't warn
2199+
numeric_only_bool = self._resolve_numeric_only(numeric_only, axis=0)
2200+
if numeric_only_bool and self.obj.ndim == 1:
2201+
raise NotImplementedError(
2202+
f"{type(self).__name__}.std does not implement numeric_only."
2203+
)
21972204
result = self._get_cythonized_result(
21982205
libgroupby.group_var,
21992206
cython_dtype=np.dtype(np.float64),
2200-
numeric_only=numeric_only,
2207+
numeric_only=numeric_only_bool,
22012208
needs_counts=True,
22022209
post_processing=lambda vals, inference: np.sqrt(vals),
22032210
ddof=ddof,
@@ -2296,7 +2303,13 @@ def sem(self, ddof: int = 1, numeric_only: bool | lib.NoDefault = lib.no_default
22962303
Series or DataFrame
22972304
Standard error of the mean of values within each group.
22982305
"""
2299-
result = self.std(ddof=ddof, numeric_only=numeric_only)
2306+
# Resolve numeric_only so that std doesn't warn
2307+
numeric_only_bool = self._resolve_numeric_only(numeric_only, axis=0)
2308+
if numeric_only_bool and self.obj.ndim == 1:
2309+
raise NotImplementedError(
2310+
f"{type(self).__name__}.sem does not implement numeric_only."
2311+
)
2312+
result = self.std(ddof=ddof, numeric_only=numeric_only_bool)
23002313
self._maybe_warn_numeric_only_depr("sem", result, numeric_only)
23012314

23022315
if result.ndim == 1:
@@ -3167,6 +3180,10 @@ def quantile(
31673180
b 3.0
31683181
"""
31693182
numeric_only_bool = self._resolve_numeric_only(numeric_only, axis=0)
3183+
if numeric_only_bool and self.obj.ndim == 1:
3184+
raise NotImplementedError(
3185+
f"{type(self).__name__}.quantile does not implement numeric_only"
3186+
)
31703187

31713188
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:
31723189
if is_object_dtype(vals):

pandas/tests/groupby/test_function.py

+73-2
Original file line numberDiff line numberDiff line change
@@ -306,7 +306,7 @@ def test_idxmax(self, gb):
306306
# non-cython calls should not include the grouper
307307
expected = DataFrame([[0.0], [np.nan]], columns=["B"], index=[1, 3])
308308
expected.index.name = "A"
309-
msg = "The default value of numeric_only"
309+
msg = "The default value of numeric_only in DataFrameGroupBy.idxmax"
310310
with tm.assert_produces_warning(FutureWarning, match=msg):
311311
result = gb.idxmax()
312312
tm.assert_frame_equal(result, expected)
@@ -317,7 +317,7 @@ def test_idxmin(self, gb):
317317
# non-cython calls should not include the grouper
318318
expected = DataFrame([[0.0], [np.nan]], columns=["B"], index=[1, 3])
319319
expected.index.name = "A"
320-
msg = "The default value of numeric_only"
320+
msg = "The default value of numeric_only in DataFrameGroupBy.idxmin"
321321
with tm.assert_produces_warning(FutureWarning, match=msg):
322322
result = gb.idxmin()
323323
tm.assert_frame_equal(result, expected)
@@ -1356,6 +1356,77 @@ def test_deprecate_numeric_only(
13561356
method(*args, **kwargs)
13571357

13581358

1359+
def test_deprecate_numeric_only_series(groupby_func, request):
1360+
# GH#46560
1361+
if groupby_func in ("backfill", "mad", "pad", "tshift"):
1362+
pytest.skip("method is deprecated")
1363+
elif groupby_func == "corrwith":
1364+
msg = "corrwith is not implemented on SeriesGroupBy"
1365+
request.node.add_marker(pytest.mark.xfail(reason=msg))
1366+
1367+
ser = Series(list("xyz"))
1368+
gb = ser.groupby([0, 0, 1])
1369+
1370+
if groupby_func == "corrwith":
1371+
args = (ser,)
1372+
elif groupby_func == "corr":
1373+
args = (ser,)
1374+
elif groupby_func == "cov":
1375+
args = (ser,)
1376+
elif groupby_func == "nth":
1377+
args = (0,)
1378+
elif groupby_func == "fillna":
1379+
args = (True,)
1380+
elif groupby_func == "take":
1381+
args = ([0],)
1382+
elif groupby_func == "quantile":
1383+
args = (0.5,)
1384+
else:
1385+
args = ()
1386+
method = getattr(gb, groupby_func)
1387+
1388+
try:
1389+
_ = method(*args)
1390+
except (TypeError, ValueError) as err:
1391+
# ops that only work on numeric dtypes
1392+
assert groupby_func in (
1393+
"corr",
1394+
"cov",
1395+
"cummax",
1396+
"cummin",
1397+
"cumprod",
1398+
"cumsum",
1399+
"diff",
1400+
"idxmax",
1401+
"idxmin",
1402+
"mean",
1403+
"median",
1404+
"pct_change",
1405+
"prod",
1406+
"quantile",
1407+
"sem",
1408+
"skew",
1409+
"std",
1410+
"var",
1411+
)
1412+
assert (
1413+
"could not convert" in str(err).lower()
1414+
or "unsupported operand type" in str(err)
1415+
or "not allowed for this dtype" in str(err)
1416+
or "can't multiply sequence by non-int" in str(err)
1417+
or "cannot be performed against 'object' dtypes" in str(err)
1418+
or "is not supported for object dtype" in str(err)
1419+
), str(err)
1420+
1421+
msgs = (
1422+
"got an unexpected keyword argument 'numeric_only'",
1423+
f"{groupby_func} does not implement numeric_only",
1424+
f"{groupby_func} is not supported for object dtype",
1425+
)
1426+
with pytest.raises((NotImplementedError, TypeError), match=f"({'|'.join(msgs)})"):
1427+
_ = method(*args, numeric_only=True)
1428+
1429+
13591430
@pytest.mark.parametrize("dtype", [int, float, object])
13601431
@pytest.mark.parametrize(
13611432
"kwargs",

0 commit comments

Comments
 (0)