Skip to content

Commit ad7dcef

Browse files
authored
BUG: numeric_only with axis=1 in DataFrame.corrwith and DataFrameGroupBy.cummin/max (#47724)
* BUG: DataFrame.corrwith and DataFrameGroupBy.cummin/cummax with numeric_only=True
* test improvements
1 parent e3698dd commit ad7dcef

File tree

3 files changed

+97
-8
lines changed

3 files changed

+97
-8
lines changed

pandas/core/frame.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -10560,7 +10560,8 @@ def corrwith(
1056010560
else:
1056110561
return this.apply(lambda x: other.corr(x, method=method), axis=axis)
1056210562

10563-
other = other._get_numeric_data()
10563+
if numeric_only_bool:
10564+
other = other._get_numeric_data()
1056410565
left, right = this.align(other, join="inner", copy=False)
1056510566

1056610567
if axis == 1:
@@ -10573,11 +10574,15 @@ def corrwith(
1057310574
right = right + left * 0
1057410575

1057510576
# demeaned data
10576-
ldem = left - left.mean()
10577-
rdem = right - right.mean()
10577+
ldem = left - left.mean(numeric_only=numeric_only_bool)
10578+
rdem = right - right.mean(numeric_only=numeric_only_bool)
1057810579

1057910580
num = (ldem * rdem).sum()
10580-
dom = (left.count() - 1) * left.std() * right.std()
10581+
dom = (
10582+
(left.count() - 1)
10583+
* left.std(numeric_only=numeric_only_bool)
10584+
* right.std(numeric_only=numeric_only_bool)
10585+
)
1058110586

1058210587
correl = num / dom
1058310588

pandas/core/groupby/groupby.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -3630,7 +3630,11 @@ def cummin(self, axis=0, numeric_only=False, **kwargs) -> NDFrameT:
36303630
skipna = kwargs.get("skipna", True)
36313631
if axis != 0:
36323632
f = lambda x: np.minimum.accumulate(x, axis)
3633-
return self._python_apply_general(f, self._selected_obj, is_transform=True)
3633+
numeric_only_bool = self._resolve_numeric_only("cummax", numeric_only, axis)
3634+
obj = self._selected_obj
3635+
if numeric_only_bool:
3636+
obj = obj._get_numeric_data()
3637+
return self._python_apply_general(f, obj, is_transform=True)
36343638

36353639
return self._cython_transform(
36363640
"cummin", numeric_only=numeric_only, skipna=skipna
@@ -3650,7 +3654,11 @@ def cummax(self, axis=0, numeric_only=False, **kwargs) -> NDFrameT:
36503654
skipna = kwargs.get("skipna", True)
36513655
if axis != 0:
36523656
f = lambda x: np.maximum.accumulate(x, axis)
3653-
return self._python_apply_general(f, self._selected_obj, is_transform=True)
3657+
numeric_only_bool = self._resolve_numeric_only("cummax", numeric_only, axis)
3658+
obj = self._selected_obj
3659+
if numeric_only_bool:
3660+
obj = obj._get_numeric_data()
3661+
return self._python_apply_general(f, obj, is_transform=True)
36543662

36553663
return self._cython_transform(
36563664
"cummax", numeric_only=numeric_only, skipna=skipna

pandas/tests/groupby/test_function.py

+78-2
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,81 @@ def test_idxmin_idxmax_axis1():
555555
gb2.idxmax(axis=1)
556556

557557

558+
@pytest.mark.parametrize("numeric_only", [True, False, None])
559+
def test_axis1_numeric_only(request, groupby_func, numeric_only):
560+
if groupby_func in ("idxmax", "idxmin"):
561+
pytest.skip("idxmax and idx_min tested in test_idxmin_idxmax_axis1")
562+
if groupby_func in ("mad", "tshift"):
563+
pytest.skip("mad and tshift are deprecated")
564+
if groupby_func in ("corrwith", "skew"):
565+
msg = "GH#47723 groupby.corrwith and skew do not correctly implement axis=1"
566+
request.node.add_marker(pytest.mark.xfail(reason=msg))
567+
568+
df = DataFrame(np.random.randn(10, 4), columns=["A", "B", "C", "D"])
569+
df["E"] = "x"
570+
groups = [1, 2, 3, 1, 2, 3, 1, 2, 3, 4]
571+
gb = df.groupby(groups)
572+
method = getattr(gb, groupby_func)
573+
args = (0,) if groupby_func == "fillna" else ()
574+
kwargs = {"axis": 1}
575+
if numeric_only is not None:
576+
# when numeric_only is None we don't pass any argument
577+
kwargs["numeric_only"] = numeric_only
578+
579+
# Functions that do not accept a numeric_only argument (they do accept axis)
580+
no_args = ("cumprod", "cumsum", "diff", "fillna", "pct_change", "rank", "shift")
581+
# Functions with axis args
582+
has_axis = (
583+
"cumprod",
584+
"cumsum",
585+
"diff",
586+
"pct_change",
587+
"rank",
588+
"shift",
589+
"cummax",
590+
"cummin",
591+
"idxmin",
592+
"idxmax",
593+
"fillna",
594+
)
595+
if numeric_only is not None and groupby_func in no_args:
596+
msg = "got an unexpected keyword argument 'numeric_only'"
597+
with pytest.raises(TypeError, match=msg):
598+
method(*args, **kwargs)
599+
elif groupby_func not in has_axis:
600+
msg = "got an unexpected keyword argument 'axis'"
601+
warn = FutureWarning if groupby_func == "skew" and not numeric_only else None
602+
with tm.assert_produces_warning(warn, match="Dropping of nuisance columns"):
603+
with pytest.raises(TypeError, match=msg):
604+
method(*args, **kwargs)
605+
# fillna and shift are successful even on object dtypes
606+
elif (numeric_only is None or not numeric_only) and groupby_func not in (
607+
"fillna",
608+
"shift",
609+
):
610+
msgs = (
611+
# cummax, cummin, rank
612+
"not supported between instances of",
613+
# cumprod
614+
"can't multiply sequence by non-int of type 'float'",
615+
# cumsum, diff, pct_change
616+
"unsupported operand type",
617+
)
618+
with pytest.raises(TypeError, match=f"({'|'.join(msgs)})"):
619+
method(*args, **kwargs)
620+
else:
621+
result = method(*args, **kwargs)
622+
623+
df_expected = df.drop(columns="E").T if numeric_only else df.T
624+
expected = getattr(df_expected, groupby_func)(*args).T
625+
if groupby_func == "shift" and not numeric_only:
626+
# shift with axis=1 leaves the leftmost column as numeric
627+
# but transposing for expected gives us object dtype
628+
expected = expected.astype(float)
629+
630+
tm.assert_equal(result, expected)
631+
632+
558633
def test_groupby_cumprod():
559634
# GH 4095
560635
df = DataFrame({"key": ["b"] * 10, "value": 2})
@@ -1321,7 +1396,7 @@ def test_deprecate_numeric_only(
13211396
assert "b" not in result.columns
13221397
elif (
13231398
# kernels that work on any dtype and have numeric_only arg
1324-
kernel in ("first", "last", "corrwith")
1399+
kernel in ("first", "last")
13251400
or (
13261401
# kernels that work on any dtype and don't have numeric_only arg
13271402
kernel in ("any", "all", "bfill", "ffill", "fillna", "nth", "nunique")
@@ -1339,7 +1414,8 @@ def test_deprecate_numeric_only(
13391414
"(not allowed for this dtype"
13401415
"|must be a string or a number"
13411416
"|cannot be performed against 'object' dtypes"
1342-
"|must be a string or a real number)"
1417+
"|must be a string or a real number"
1418+
"|unsupported operand type)"
13431419
)
13441420
with pytest.raises(TypeError, match=msg):
13451421
method(*args, **kwargs)

0 commit comments

Comments
 (0)