Skip to content

Commit 5e97e67

Browse files
lithomas1, mroeschke, pre-commit-ci[bot]
authored
BUG: Fix metadata propagation in reductions (#53542)
* BUG: Fix metadata propagation in reductions
* fix tests
* actually fix tests
* fix typing
* Update pandas/core/reshape/encoding.py
* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: Matthew Roeschke <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 90d9e25 commit 5e97e67

File tree

5 files changed

+76
-35
lines changed

5 files changed

+76
-35
lines changed

doc/source/whatsnew/v2.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -772,8 +772,10 @@ Styler
772772

773773
Metadata
774774
^^^^^^^^
775+
- Fixed metadata propagation in :meth:`DataFrame.max`, :meth:`DataFrame.min`, :meth:`DataFrame.prod`, :meth:`DataFrame.mean`, :meth:`Series.mode`, :meth:`DataFrame.median`, :meth:`DataFrame.sem`, :meth:`DataFrame.skew`, and :meth:`DataFrame.kurt` (:issue:`28283`)
775776
- Fixed metadata propagation in :meth:`DataFrame.squeeze`, and :meth:`DataFrame.describe` (:issue:`28283`)
776777
- Fixed metadata propagation in :meth:`DataFrame.std` (:issue:`28283`)
778+
-
777779

778780
Other
779781
^^^^^

pandas/core/frame.py

+48-19
Original file line numberDiff line numberDiff line change
@@ -11137,12 +11137,13 @@ def any( # type: ignore[override]
1113711137
bool_only: bool = False,
1113811138
skipna: bool = True,
1113911139
**kwargs,
11140-
) -> Series:
11141-
# error: Incompatible return value type (got "Union[Series, bool]",
11142-
# expected "Series")
11143-
return self._logical_func( # type: ignore[return-value]
11140+
) -> Series | bool:
11141+
result = self._logical_func(
1114411142
"any", nanops.nanany, axis, bool_only, skipna, **kwargs
1114511143
)
11144+
if isinstance(result, Series):
11145+
result = result.__finalize__(self, method="any")
11146+
return result
1114611147

1114711148
@doc(make_doc("all", ndim=2))
1114811149
def all(
@@ -11151,12 +11152,13 @@ def all(
1115111152
bool_only: bool = False,
1115211153
skipna: bool = True,
1115311154
**kwargs,
11154-
) -> Series:
11155-
# error: Incompatible return value type (got "Union[Series, bool]",
11156-
# expected "Series")
11157-
return self._logical_func( # type: ignore[return-value]
11155+
) -> Series | bool:
11156+
result = self._logical_func(
1115811157
"all", nanops.nanall, axis, bool_only, skipna, **kwargs
1115911158
)
11159+
if isinstance(result, Series):
11160+
result = result.__finalize__(self, method="all")
11161+
return result
1116011162

1116111163
@doc(make_doc("min", ndim=2))
1116211164
def min(
@@ -11166,7 +11168,10 @@ def min(
1116611168
numeric_only: bool = False,
1116711169
**kwargs,
1116811170
):
11169-
return super().min(axis, skipna, numeric_only, **kwargs)
11171+
result = super().min(axis, skipna, numeric_only, **kwargs)
11172+
if isinstance(result, Series):
11173+
result = result.__finalize__(self, method="min")
11174+
return result
1117011175

1117111176
@doc(make_doc("max", ndim=2))
1117211177
def max(
@@ -11176,7 +11181,10 @@ def max(
1117611181
numeric_only: bool = False,
1117711182
**kwargs,
1117811183
):
11179-
return super().max(axis, skipna, numeric_only, **kwargs)
11184+
result = super().max(axis, skipna, numeric_only, **kwargs)
11185+
if isinstance(result, Series):
11186+
result = result.__finalize__(self, method="max")
11187+
return result
1118011188

1118111189
@doc(make_doc("sum", ndim=2))
1118211190
def sum(
@@ -11199,7 +11207,8 @@ def prod(
1119911207
min_count: int = 0,
1120011208
**kwargs,
1120111209
):
11202-
return super().prod(axis, skipna, numeric_only, min_count, **kwargs)
11210+
result = super().prod(axis, skipna, numeric_only, min_count, **kwargs)
11211+
return result.__finalize__(self, method="prod")
1120311212

1120411213
@doc(make_doc("mean", ndim=2))
1120511214
def mean(
@@ -11209,7 +11218,10 @@ def mean(
1120911218
numeric_only: bool = False,
1121011219
**kwargs,
1121111220
):
11212-
return super().mean(axis, skipna, numeric_only, **kwargs)
11221+
result = super().mean(axis, skipna, numeric_only, **kwargs)
11222+
if isinstance(result, Series):
11223+
result = result.__finalize__(self, method="mean")
11224+
return result
1121311225

1121411226
@doc(make_doc("median", ndim=2))
1121511227
def median(
@@ -11219,7 +11231,10 @@ def median(
1121911231
numeric_only: bool = False,
1122011232
**kwargs,
1122111233
):
11222-
return super().median(axis, skipna, numeric_only, **kwargs)
11234+
result = super().median(axis, skipna, numeric_only, **kwargs)
11235+
if isinstance(result, Series):
11236+
result = result.__finalize__(self, method="median")
11237+
return result
1122311238

1122411239
@doc(make_doc("sem", ndim=2))
1122511240
def sem(
@@ -11230,7 +11245,10 @@ def sem(
1123011245
numeric_only: bool = False,
1123111246
**kwargs,
1123211247
):
11233-
return super().sem(axis, skipna, ddof, numeric_only, **kwargs)
11248+
result = super().sem(axis, skipna, ddof, numeric_only, **kwargs)
11249+
if isinstance(result, Series):
11250+
result = result.__finalize__(self, method="sem")
11251+
return result
1123411252

1123511253
@doc(make_doc("var", ndim=2))
1123611254
def var(
@@ -11241,7 +11259,10 @@ def var(
1124111259
numeric_only: bool = False,
1124211260
**kwargs,
1124311261
):
11244-
return super().var(axis, skipna, ddof, numeric_only, **kwargs)
11262+
result = super().var(axis, skipna, ddof, numeric_only, **kwargs)
11263+
if isinstance(result, Series):
11264+
result = result.__finalize__(self, method="var")
11265+
return result
1124511266

1124611267
@doc(make_doc("std", ndim=2))
1124711268
def std(
@@ -11252,8 +11273,10 @@ def std(
1125211273
numeric_only: bool = False,
1125311274
**kwargs,
1125411275
):
11255-
result = cast(Series, super().std(axis, skipna, ddof, numeric_only, **kwargs))
11256-
return result.__finalize__(self, method="std")
11276+
result = super().std(axis, skipna, ddof, numeric_only, **kwargs)
11277+
if isinstance(result, Series):
11278+
result = result.__finalize__(self, method="std")
11279+
return result
1125711280

1125811281
@doc(make_doc("skew", ndim=2))
1125911282
def skew(
@@ -11263,7 +11286,10 @@ def skew(
1126311286
numeric_only: bool = False,
1126411287
**kwargs,
1126511288
):
11266-
return super().skew(axis, skipna, numeric_only, **kwargs)
11289+
result = super().skew(axis, skipna, numeric_only, **kwargs)
11290+
if isinstance(result, Series):
11291+
result = result.__finalize__(self, method="skew")
11292+
return result
1126711293

1126811294
@doc(make_doc("kurt", ndim=2))
1126911295
def kurt(
@@ -11273,7 +11299,10 @@ def kurt(
1127311299
numeric_only: bool = False,
1127411300
**kwargs,
1127511301
):
11276-
return super().kurt(axis, skipna, numeric_only, **kwargs)
11302+
result = super().kurt(axis, skipna, numeric_only, **kwargs)
11303+
if isinstance(result, Series):
11304+
result = result.__finalize__(self, method="kurt")
11305+
return result
1127711306

1127811307
kurtosis = kurt
1127911308
product = prod

pandas/core/reshape/encoding.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
Iterable,
77
)
88
import itertools
9-
from typing import TYPE_CHECKING
9+
from typing import (
10+
TYPE_CHECKING,
11+
cast,
12+
)
1013

1114
import numpy as np
1215

@@ -455,10 +458,12 @@ def from_dummies(
455458
f"Received 'data' of type: {type(data).__name__}"
456459
)
457460

458-
if data.isna().any().any():
461+
col_isna_mask = cast(Series, data.isna().any())
462+
463+
if col_isna_mask.any():
459464
raise ValueError(
460465
"Dummy DataFrame contains NA value in column: "
461-
f"'{data.isna().any().idxmax()}'"
466+
f"'{col_isna_mask.idxmax()}'"
462467
)
463468

464469
# index data with a list of all columns that are dummies

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -2202,7 +2202,7 @@ def mode(self, dropna: bool = True) -> Series:
22022202
# Ensure index is type stable (should always use int index)
22032203
return self._constructor(
22042204
res_values, index=range(len(res_values)), name=self.name, copy=False
2205-
)
2205+
).__finalize__(self, method="mode")
22062206

22072207
def unique(self) -> ArrayLike: # pylint: disable=useless-parent-delegation
22082208
"""

pandas/tests/generic/test_finalize.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,8 @@
180180
(pd.DataFrame, frame_data, operator.methodcaller("idxmin")),
181181
(pd.DataFrame, frame_data, operator.methodcaller("idxmax")),
182182
(pd.DataFrame, frame_data, operator.methodcaller("mode")),
183-
pytest.param(
184-
(pd.Series, [0], operator.methodcaller("mode")),
185-
marks=not_implemented_mark,
186-
),
183+
(pd.Series, [0], operator.methodcaller("mode")),
184+
(pd.DataFrame, frame_data, operator.methodcaller("median")),
187185
(
188186
pd.DataFrame,
189187
frame_data,
@@ -363,17 +361,24 @@
363361
# Cumulative reductions
364362
(pd.Series, ([1],), operator.methodcaller("cumsum")),
365363
(pd.DataFrame, frame_data, operator.methodcaller("cumsum")),
364+
(pd.Series, ([1],), operator.methodcaller("cummin")),
365+
(pd.DataFrame, frame_data, operator.methodcaller("cummin")),
366+
(pd.Series, ([1],), operator.methodcaller("cummax")),
367+
(pd.DataFrame, frame_data, operator.methodcaller("cummax")),
368+
(pd.Series, ([1],), operator.methodcaller("cumprod")),
369+
(pd.DataFrame, frame_data, operator.methodcaller("cumprod")),
366370
# Reductions
367-
pytest.param(
368-
(pd.DataFrame, frame_data, operator.methodcaller("any")),
369-
marks=not_implemented_mark,
370-
),
371+
(pd.DataFrame, frame_data, operator.methodcaller("any")),
372+
(pd.DataFrame, frame_data, operator.methodcaller("all")),
373+
(pd.DataFrame, frame_data, operator.methodcaller("min")),
374+
(pd.DataFrame, frame_data, operator.methodcaller("max")),
371375
(pd.DataFrame, frame_data, operator.methodcaller("sum")),
372376
(pd.DataFrame, frame_data, operator.methodcaller("std")),
373-
pytest.param(
374-
(pd.DataFrame, frame_data, operator.methodcaller("mean")),
375-
marks=not_implemented_mark,
376-
),
377+
(pd.DataFrame, frame_data, operator.methodcaller("mean")),
378+
(pd.DataFrame, frame_data, operator.methodcaller("prod")),
379+
(pd.DataFrame, frame_data, operator.methodcaller("sem")),
380+
(pd.DataFrame, frame_data, operator.methodcaller("skew")),
381+
(pd.DataFrame, frame_data, operator.methodcaller("kurt")),
377382
]
378383

379384

0 commit comments

Comments
 (0)