Commit dc830ea

topper-123, phofl, lithomas1, glemaitre and Charlie-XIAO authored
ENH: better dtype inference when doing DataFrame reductions (#52788)
* ENH: better dtype inference when doing DataFrame reductions
* precommit issues
* fix failures
* fix failures
* mypy + some docs
* doc linting linting
* refactor to use _reduce_with_wrap
* docstring linting
* pyarrow failure + linting
* pyarrow failure + linting
* linting
* doc stuff
* linting fixes
* fix fix doc string
* remove _wrap_na_result
* doc string example
* pyarrow + categorical
* silence bugs
* silence errors
* silence errors II
* fix errors III
* various fixups
* various fixups
* delay fixing windows and 32bit failures
* BUG: Adding a columns to a Frame with RangeIndex columns using a non-scalar key (#52877)
* DOC: Update whatsnew (#52882)
* CI: Change development python version to 3.10 (#51133)
* CI: Change development python version to 3.10
* Update checks
* Remove strict
* Remove strict
* Fixes
* Add dt
* Switch python to 3.9
* Remove
* Fix
* Try attribute
* Adjust
* Fix mypy
* Try fixing doc build
* Fix mypy
* Fix stubtest
* Remove workflow file
* Rename back
* Update
* Rename
* Rename
* Change python version
* Remove
* Fix doc errors
* Remove pypy
* Update ci/deps/actions-pypy-39.yaml
  Co-authored-by: Thomas Li <[email protected]>
* Revert pypy removal
* Remove again
* Fix
* Change to 3.9
* Address
---------
Co-authored-by: Thomas Li <[email protected]>
* update
* update
* add docs
* fix windows tests
* fix windows tests
* remove guards for 32bit linux
* add bool tests + fix 32-bit failures
* fix pre-commit failures
* fix mypy failures
* rename _reduce_with -> _reduce_and_wrap
* assert missing attributes
* reduction dtypes on windows and 32bit systems
* add tests for min_count=0
* PERF:median with axis=1
* median with axis=1 fix
* streamline Block.reduce
* fix comments
* FIX preserve dtype with datetime columns of different resolution when merging (#53213)
* BUG Merge not behaving correctly when having `MultiIndex` with a single level (#53215)
* fix merge when MultiIndex with single level
* resolved conversations
* fixed code style
* BUG: preserve dtype for right/outer merge of datetime with different resolutions (#53233)
* remove special BooleanArray.sum method
* remove BooleanArray.prod
* fixes
* Update doc/source/whatsnew/v2.1.0.rst
  Co-authored-by: Joris Van den Bossche <[email protected]>
* Update pandas/core/array_algos/masked_reductions.py
  Co-authored-by: Joris Van den Bossche <[email protected]>
* small cleanup
* small cleanup
* only reduce 1d
* fix after #53418
* update according to comments
* revome note
* update _minmax
* REF: add keepdims parameter to ExtensionArray._reduce + remove ExtensionArray._reduce_and_wrap
* REF: add keepdims parameter to ExtensionArray._reduce + remove ExtensionArray._reduce_and_wrap
* fix whatsnew
* fix _reduce call
* simplify test
* add tests for any/all
---------
Co-authored-by: Patrick Hoefler <[email protected]>
Co-authored-by: Thomas Li <[email protected]>
Co-authored-by: Guillaume Lemaitre <[email protected]>
Co-authored-by: Yao Xiao <[email protected]>
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 527bfc5 commit dc830ea

24 files changed, +619 -78 lines changed

doc/source/user_guide/integer_na.rst

+2 -1
@@ -126,10 +126,11 @@ These dtypes can be merged, reshaped & casted.
    pd.concat([df[["A"]], df[["B", "C"]]], axis=1).dtypes
    df["A"].astype(float)

-Reduction and groupby operations such as 'sum' work as well.
+Reduction and groupby operations such as :meth:`~DataFrame.sum` work as well.

 .. ipython:: python

+   df.sum(numeric_only=True)
    df.sum()
    df.groupby("B").A.sum()

doc/source/whatsnew/v2.1.0.rst

+40
@@ -14,6 +14,46 @@ including other versions of pandas.
 Enhancements
 ~~~~~~~~~~~~

+.. _whatsnew_210.enhancements.reduction_extension_dtypes:
+
+DataFrame reductions preserve extension dtypes
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+In previous versions of pandas, the results of DataFrame reductions
+(:meth:`DataFrame.sum` :meth:`DataFrame.mean` etc.) had numpy dtypes, even when the DataFrames
+were of extension dtypes. Pandas can now keep the dtypes when doing reductions over Dataframe
+columns with a common dtype (:issue:`52788`).
+
+*Old Behavior*
+
+.. code-block:: ipython
+
+    In [1]: df = pd.DataFrame({"a": [1, 1, 2, 1], "b": [np.nan, 2.0, 3.0, 4.0]}, dtype="Int64")
+    In [2]: df.sum()
+    Out[2]:
+    a    5
+    b    9
+    dtype: int64
+    In [3]: df = df.astype("int64[pyarrow]")
+    In [4]: df.sum()
+    Out[4]:
+    a    5
+    b    9
+    dtype: int64
+
+*New Behavior*
+
+.. ipython:: python
+
+    df = pd.DataFrame({"a": [1, 1, 2, 1], "b": [np.nan, 2.0, 3.0, 4.0]}, dtype="Int64")
+    df.sum()
+    df = df.astype("int64[pyarrow]")
+    df.sum()
+
+Notice that the dtype is now a masked dtype and pyarrow dtype, respectively, while previously it was a numpy integer dtype.
+
+To allow Dataframe reductions to preserve extension dtypes, :meth:`ExtensionArray._reduce` has gotten a new keyword parameter ``keepdims``. Calling :meth:`ExtensionArray._reduce` with ``keepdims=True`` should return an array of length 1 along the reduction axis. In order to maintain backward compatibility, the parameter is not required, but will it become required in the future. If the parameter is not found in the signature, DataFrame reductions can not preserve extension dtypes. Also, if the parameter is not found, a ``FutureWarning`` will be emitted and type checkers like mypy may complain about the signature not being compatible with :meth:`ExtensionArray._reduce`.
+
 .. _whatsnew_210.enhancements.cow:

 Copy-on-Write improvements
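
The whatsnew entry above says that a missing ``keepdims`` parameter in a third-party ``_reduce`` signature leads to a ``FutureWarning`` and a loss of the extension dtype. The following is a minimal pure-Python sketch of that kind of signature check. It is not the code pandas uses: ``LegacyArray``, ``UpdatedArray`` and ``reduce_keeping_dtype`` are hypothetical names invented for illustration.

```python
import inspect
import warnings


class LegacyArray:
    """Hypothetical third-party array whose _reduce predates ``keepdims``."""

    def __init__(self, values):
        self.values = list(values)

    def _reduce(self, name, *, skipna=True, **kwargs):
        # Toy reduction: only a plain sum is modelled here.
        return sum(self.values)


class UpdatedArray(LegacyArray):
    """Same toy array, with the new keyword in its signature."""

    def _reduce(self, name, *, skipna=True, keepdims=False, **kwargs):
        result = sum(self.values)
        # With keepdims=True, return a length-1 array of the same type.
        return type(self)([result]) if keepdims else result


def reduce_keeping_dtype(arr, name):
    """Hypothetical caller: request keepdims=True only when it is supported."""
    params = inspect.signature(arr._reduce).parameters
    if "keepdims" in params:
        return arr._reduce(name, keepdims=True)
    warnings.warn(
        f"{type(arr).__name__}._reduce has no 'keepdims' parameter",
        FutureWarning,
        stacklevel=2,
    )
    # Fall back to the scalar result; the array type is lost here.
    return arr._reduce(name)
```

In this sketch the updated class returns an instance of its own type, so a caller could reassemble a result with the original dtype, while the legacy class only ever hands back a bare scalar.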

pandas/core/array_algos/masked_reductions.py

+3 -3
@@ -52,7 +52,7 @@ def _reductions(
     axis : int, optional, default None
     """
     if not skipna:
-        if mask.any(axis=axis) or check_below_min_count(values.shape, None, min_count):
+        if mask.any() or check_below_min_count(values.shape, None, min_count):
             return libmissing.NA
         else:
             return func(values, axis=axis, **kwargs)
@@ -119,11 +119,11 @@ def _minmax(
             # min/max with empty array raise in numpy, pandas returns NA
             return libmissing.NA
         else:
-            return func(values)
+            return func(values, axis=axis)
     else:
         subset = values[~mask]
         if subset.size:
-            return func(subset)
+            return func(subset, axis=axis)
         else:
             # min/max with empty array raise in numpy, pandas returns NA
             return libmissing.NA
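
The control flow of ``_minmax`` in the hunk above can be modelled on plain Python lists. This is a simplified sketch, not the pandas implementation: ``NA`` is a stand-in sentinel, ``masked_min`` is a hypothetical name, and the exact condition guarding the ``skipna=False`` branch is an assumption, since the hunk only shows its comment and return value.

```python
NA = None  # stand-in for the pandas NA scalar in this toy model


def masked_min(values, mask, *, skipna=True):
    """Minimum of ``values`` where ``mask[i]`` is True for missing entries."""
    if len(values) != len(mask):
        raise ValueError("values and mask must have the same length")
    if not skipna:
        # Assumed condition: any missing entry, or no entries at all, gives NA.
        if any(mask) or not values:
            return NA
        return min(values)
    # skipna=True: reduce only over the entries that are present.
    subset = [v for v, missing in zip(values, mask) if not missing]
    # min() of an empty sequence raises, so return NA instead.
    return min(subset) if subset else NA
```

The empty-subset guard mirrors the comment in the diff: the underlying reduction raises on an empty input, so the wrapper returns NA instead.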

pandas/core/arrays/arrow/array.py

+11 -5
@@ -1508,7 +1508,9 @@ def pyarrow_meth(data, skip_nulls, **kwargs):

         return result

-    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         """
         Return a scalar result of performing the reduction operation.

@@ -1532,12 +1534,16 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
         ------
         TypeError : subclass does not define reductions
         """
-        result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
+        pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)

-        if pc.is_null(result).as_py():
-            return self.dtype.na_value
+        if keepdims:
+            result = pa.array([pa_result.as_py()], type=pa_result.type)
+            return type(self)(result)

-        return result.as_py()
+        if pc.is_null(pa_result).as_py():
+            return self.dtype.na_value
+        else:
+            return pa_result.as_py()

     def _explode(self):
         """
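
In the hunk above the ``keepdims`` branch runs before the null check, so a null reduction result stays inside a length-1 array instead of being converted to the dtype's NA scalar. The toy class below sketches that ordering without pyarrow. ``ToyNullableArray`` is a hypothetical name, ``None`` plays the role of a null, and only ``"sum"`` is modelled.

```python
class ToyNullableArray:
    """Toy nullable array; ``None`` marks a missing value."""

    def __init__(self, values):
        self.values = list(values)

    def _reduce(self, name, *, skipna=True, keepdims=False):
        if name != "sum":
            raise TypeError(f"reduction '{name}' is not modelled in this sketch")
        present = [v for v in self.values if v is not None]
        has_null = len(present) != len(self.values)
        # Without skipna, a single null makes the whole result null.
        raw = None if (has_null and not skipna) else sum(present)

        if keepdims:
            # Wrap first: a null result is kept as a length-1 array.
            return type(self)([raw])

        # Scalar path: None stands in for the dtype's NA value here.
        return raw
```

Keeping the null inside the array means the caller still receives an object of the original array type, whichever value the reduction produced.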

pandas/core/arrays/base.py

+17 -2
@@ -1535,7 +1535,9 @@ def _accumulate(
         """
         raise NotImplementedError(f"cannot perform {name} with type {self.dtype}")

-    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         """
         Return a scalar result of performing the reduction operation.

@@ -1547,6 +1549,15 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
             std, var, sem, kurt, skew }.
         skipna : bool, default True
             If True, skip NaN values.
+        keepdims : bool, default False
+            If False, a scalar is returned.
+            If True, the result has dimension with size one along the reduced axis.
+
+            .. versionadded:: 2.1
+
+               This parameter is not required in the _reduce signature to keep backward
+               compatibility, but will become required in the future. If the parameter
+               is not found in the method signature, a FutureWarning will be emitted.
         **kwargs
             Additional keyword arguments passed to the reduction function.
             Currently, `ddof` is the only supported kwarg.
@@ -1565,7 +1576,11 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
                 f"'{type(self).__name__}' with dtype {self.dtype} "
                 f"does not support reduction '{name}'"
             )
-        return meth(skipna=skipna, **kwargs)
+        result = meth(skipna=skipna, **kwargs)
+        if keepdims:
+            result = np.array([result])
+
+        return result

     # https://github.com/python/typeshed/issues/2148#issuecomment-520783318
     # Incompatible types in assignment (expression has type "None", base class
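
The base-class fallback above looks up the reduction by name, calls it, and wraps the scalar in a length-1 container when ``keepdims=True``. The sketch below models that with plain lists in place of ``np.array``. ``ToyExtensionArray`` and ``ToyCategorical`` are hypothetical names; the subclass follows the same pattern as the ``Categorical._reduce`` override in this commit by re-wrapping the base result in its own type.

```python
class ToyExtensionArray:
    """Toy stand-in for an extension array; ``None`` marks a missing value."""

    def __init__(self, values):
        self.values = list(values)

    def sum(self, *, skipna=True):
        present = [v for v in self.values if v is not None]
        if not skipna and len(present) != len(self.values):
            return None
        return sum(present)

    def _reduce(self, name, *, skipna=True, keepdims=False, **kwargs):
        # Dispatch to a method named after the reduction, if one exists.
        meth = getattr(self, name, None)
        if meth is None:
            raise TypeError(
                f"'{type(self).__name__}' does not support reduction '{name}'"
            )
        result = meth(skipna=skipna, **kwargs)
        if keepdims:
            # The diff uses np.array([result]); a list keeps this sketch stdlib-only.
            result = [result]
        return result


class ToyCategorical(ToyExtensionArray):
    """Subclass that re-wraps the length-1 result in its own type."""

    def _reduce(self, name, *, skipna=True, keepdims=False, **kwargs):
        result = super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
        return type(self)(result) if keepdims else result
```

The generic fallback returns a plain container, so on its own it cannot carry the extension dtype; a subclass that wants to preserve its dtype has to re-wrap the result, as the second class does.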

pandas/core/arrays/categorical.py

+9
@@ -2319,6 +2319,15 @@ def _reverse_indexer(self) -> dict[Hashable, npt.NDArray[np.intp]]:
     # ------------------------------------------------------------------
     # Reductions

+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
+        result = super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
+        if keepdims:
+            return type(self)(result, dtype=self.dtype)
+        else:
+            return result
+
     def min(self, *, skipna: bool = True, **kwargs):
         """
         The minimum value of the object.

pandas/core/arrays/masked.py

+68 -21
@@ -32,6 +32,10 @@
     Shape,
     npt,
 )
+from pandas.compat import (
+    IS64,
+    is_platform_windows,
+)
 from pandas.errors import AbstractMethodError
 from pandas.util._decorators import doc
 from pandas.util._validators import validate_fillna_kwargs
@@ -1081,21 +1085,31 @@ def _quantile(
     # ------------------------------------------------------------------
     # Reductions

-    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         if name in {"any", "all", "min", "max", "sum", "prod", "mean", "var", "std"}:
-            return getattr(self, name)(skipna=skipna, **kwargs)
-
-        data = self._data
-        mask = self._mask
-
-        # median, skew, kurt, sem
-        op = getattr(nanops, f"nan{name}")
-        result = op(data, axis=0, skipna=skipna, mask=mask, **kwargs)
+            result = getattr(self, name)(skipna=skipna, **kwargs)
+        else:
+            # median, skew, kurt, sem
+            data = self._data
+            mask = self._mask
+            op = getattr(nanops, f"nan{name}")
+            axis = kwargs.pop("axis", None)
+            result = op(data, axis=axis, skipna=skipna, mask=mask, **kwargs)
+
+        if keepdims:
+            if isna(result):
+                return self._wrap_na_result(name=name, axis=0, mask_size=(1,))
+            else:
+                result = result.reshape(1)
+                mask = np.zeros(1, dtype=bool)
+                return self._maybe_mask_result(result, mask)

-        if np.isnan(result):
+        if isna(result):
             return libmissing.NA
-
-        return result
+        else:
+            return result

     def _wrap_reduction_result(self, name: str, result, *, skipna, axis):
         if isinstance(result, np.ndarray):
@@ -1108,6 +1122,32 @@ def _wrap_reduction_result(self, name: str, result, *, skipna, axis):
             return self._maybe_mask_result(result, mask)
         return result

+    def _wrap_na_result(self, *, name, axis, mask_size):
+        mask = np.ones(mask_size, dtype=bool)
+
+        float_dtyp = "float32" if self.dtype == "Float32" else "float64"
+        if name in ["mean", "median", "var", "std", "skew"]:
+            np_dtype = float_dtyp
+        elif name in ["min", "max"] or self.dtype.itemsize == 8:
+            np_dtype = self.dtype.numpy_dtype.name
+        else:
+            is_windows_or_32bit = is_platform_windows() or not IS64
+            int_dtyp = "int32" if is_windows_or_32bit else "int64"
+            uint_dtyp = "uint32" if is_windows_or_32bit else "uint64"
+            np_dtype = {"b": int_dtyp, "i": int_dtyp, "u": uint_dtyp, "f": float_dtyp}[
+                self.dtype.kind
+            ]
+
+        value = np.array([1], dtype=np_dtype)
+        return self._maybe_mask_result(value, mask=mask)
+
+    def _wrap_min_count_reduction_result(
+        self, name: str, result, *, skipna, min_count, axis
+    ):
+        if min_count == 0 and isinstance(result, np.ndarray):
+            return self._maybe_mask_result(result, np.zeros(result.shape, dtype=bool))
+        return self._wrap_reduction_result(name, result, skipna=skipna, axis=axis)
+
     def sum(
         self,
         *,
@@ -1125,7 +1165,9 @@ def sum(
             min_count=min_count,
             axis=axis,
         )
-        return self._wrap_reduction_result("sum", result, skipna=skipna, axis=axis)
+        return self._wrap_min_count_reduction_result(
+            "sum", result, skipna=skipna, min_count=min_count, axis=axis
+        )

     def prod(
         self,
@@ -1136,14 +1178,17 @@ def prod(
         **kwargs,
     ):
         nv.validate_prod((), kwargs)
+
         result = masked_reductions.prod(
             self._data,
             self._mask,
             skipna=skipna,
             min_count=min_count,
             axis=axis,
         )
-        return self._wrap_reduction_result("prod", result, skipna=skipna, axis=axis)
+        return self._wrap_min_count_reduction_result(
+            "prod", result, skipna=skipna, min_count=min_count, axis=axis
+        )

     def mean(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         nv.validate_mean((), kwargs)
@@ -1183,23 +1228,25 @@ def std(

     def min(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         nv.validate_min((), kwargs)
-        return masked_reductions.min(
+        result = masked_reductions.min(
             self._data,
             self._mask,
             skipna=skipna,
             axis=axis,
         )
+        return self._wrap_reduction_result("min", result, skipna=skipna, axis=axis)

     def max(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         nv.validate_max((), kwargs)
-        return masked_reductions.max(
+        result = masked_reductions.max(
             self._data,
             self._mask,
             skipna=skipna,
             axis=axis,
         )
+        return self._wrap_reduction_result("max", result, skipna=skipna, axis=axis)

-    def any(self, *, skipna: bool = True, **kwargs):
+    def any(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         """
         Return whether any element is truthy.

@@ -1218,6 +1265,7 @@ def any(self, *, skipna: bool = True, **kwargs):
             If `skipna` is False, the result will still be True if there is
             at least one element that is truthy, otherwise NA will be returned
             if there are NA's present.
+        axis : int, optional, default 0
         **kwargs : any, default None
             Additional keywords have no effect but might be accepted for
             compatibility with NumPy.
@@ -1261,7 +1309,6 @@ def any(self, *, skipna: bool = True, **kwargs):
         >>> pd.array([0, 0, pd.NA]).any(skipna=False)
         <NA>
         """
-        kwargs.pop("axis", None)
         nv.validate_any((), kwargs)

         values = self._data.copy()
@@ -1280,7 +1327,7 @@ def any(self, *, skipna: bool = True, **kwargs):
             else:
                 return self.dtype.na_value

-    def all(self, *, skipna: bool = True, **kwargs):
+    def all(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):
         """
         Return whether all elements are truthy.

@@ -1299,6 +1346,7 @@ def all(self, *, skipna: bool = True, **kwargs):
             If `skipna` is False, the result will still be False if there is
             at least one element that is falsey, otherwise NA will be returned
             if there are NA's present.
+        axis : int, optional, default 0
         **kwargs : any, default None
             Additional keywords have no effect but might be accepted for
             compatibility with NumPy.
@@ -1342,7 +1390,6 @@ def all(self, *, skipna: bool = True, **kwargs):
         >>> pd.array([1, 0, pd.NA]).all(skipna=False)
         False
         """
-        kwargs.pop("axis", None)
         nv.validate_all((), kwargs)

         values = self._data.copy()
@@ -1352,7 +1399,7 @@ def all(self, *, skipna: bool = True, **kwargs):
         # bool, int, float, complex, str, bytes,
         # _NestedSequence[Union[bool, int, float, complex, str, bytes]]]"
         np.putmask(values, self._mask, self._truthy_value)  # type: ignore[arg-type]
-        result = values.all()
+        result = values.all(axis=axis)

         if skipna:
             return result
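
The dtype selection inside ``_wrap_na_result`` above is easier to follow as a standalone function. The sketch below restates its branches with the array attributes passed in as plain arguments; ``na_result_dtype`` and its parameter names are hypothetical, and the platform check is turned into an explicit flag so the logic can be exercised on any machine.

```python
def na_result_dtype(
    name, *, dtype_name, kind, itemsize, numpy_name, is_windows_or_32bit
):
    """Return the NumPy dtype name used to hold an all-NA reduction result.

    ``dtype_name`` is the masked dtype (e.g. "Int8"), ``kind`` its NumPy kind
    character, ``itemsize`` its size in bytes and ``numpy_name`` the name of
    the underlying NumPy dtype (e.g. "int8").
    """
    float_dtype = "float32" if dtype_name == "Float32" else "float64"
    if name in ("mean", "median", "var", "std", "skew"):
        # These reductions produce floating-point results.
        return float_dtype
    if name in ("min", "max") or itemsize == 8:
        # min/max keep the input dtype; 8-byte dtypes are left unchanged.
        return numpy_name
    # Remaining reductions take the platform-dependent wider dtype.
    int_dtype = "int32" if is_windows_or_32bit else "int64"
    uint_dtype = "uint32" if is_windows_or_32bit else "uint64"
    return {"b": int_dtype, "i": int_dtype, "u": uint_dtype, "f": float_dtype}[kind]
```

For example, a ``"sum"`` over an ``Int8`` column maps to ``"int64"`` on 64-bit Linux but to ``"int32"`` when the flag for Windows or 32-bit systems is set, which matches the commit message entry about reduction dtypes on those platforms.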

pandas/core/arrays/sparse/array.py

+9 -2
@@ -1384,7 +1384,9 @@ def nonzero(self) -> tuple[npt.NDArray[np.int32]]:
     # Reductions
     # ------------------------------------------------------------------------

-    def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
+    def _reduce(
+        self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
+    ):
         method = getattr(self, name, None)

         if method is None:
@@ -1395,7 +1397,12 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
         else:
             arr = self.dropna()

-        return getattr(arr, name)(**kwargs)
+        result = getattr(arr, name)(**kwargs)
+
+        if keepdims:
+            return type(self)([result], dtype=self.dtype)
+        else:
+            return result

     def all(self, axis=None, *args, **kwargs):
         """
