Skip to content

Commit 49334c7

Browse files
committed
REF: add keepdims parameter to ExtensionArray._reduce + remove ExtensionArray._reduce_and_wrap
1 parent dd0bfe8 commit 49334c7

File tree

13 files changed

+127
-92
lines changed

13 files changed

+127
-92
lines changed

doc/source/reference/extensions.rst

-1
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ objects.
4040
api.extensions.ExtensionArray._from_sequence_of_strings
4141
api.extensions.ExtensionArray._hash_pandas_object
4242
api.extensions.ExtensionArray._reduce
43-
api.extensions.ExtensionArray._reduce_and_wrap
4443
api.extensions.ExtensionArray._values_for_argsort
4544
api.extensions.ExtensionArray._values_for_factorize
4645
api.extensions.ExtensionArray.argsort

doc/source/whatsnew/v2.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ columns with a common dtype (:issue:`52788`).
5252
5353
Notice that the dtype is now a masked dtype and pyarrow dtype, respectively, while previously it was a numpy integer dtype.
5454

55+
To allow Dataframe reductions to preserve extension dtypes, :ref:`ExtensionArray._reduce` has gotten a new keyword parameter ``keepdims``. Calling :ref:`ExtensionArray._reduce` with ``keepdims=True`` should return an array of length 1 along the reduction axis. In order to maintain backward compatibility, the parameter is not required, but will it become required in the future. If the parameter is not found in the signature, DataFrame reductions can not preserve extension dtypes. Also, if the parameter is not found, a ``FutureWarning`` will be emitted and type checkers like mypy may complain about the signature not being compatible with :ref:`ExtensionArray._reduce`.
56+
5557
.. _whatsnew_210.enhancements.cow:
5658

5759
Copy-on-Write improvements

pandas/core/arrays/arrow/array.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -1512,7 +1512,9 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
15121512

15131513
return result
15141514

1515-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1515+
def _reduce(
1516+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1517+
):
15161518
"""
15171519
Return a scalar result of performing the reduction operation.
15181520
@@ -1536,18 +1538,16 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
15361538
------
15371539
TypeError : subclass does not define reductions
15381540
"""
1539-
result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
1540-
1541-
if pc.is_null(result).as_py():
1542-
return self.dtype.na_value
1541+
pa_result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
15431542

1544-
return result.as_py()
1543+
if keepdims:
1544+
result = pa.array([pa_result.as_py()], type=pa_result.type)
1545+
return type(self)(result)
15451546

1546-
def _reduce_and_wrap(self, name: str, *, skipna: bool = True, kwargs):
1547-
"""Takes the result of ``_reduce`` and wraps it an a ndarray/extensionArray."""
1548-
result = self._reduce_pyarrow(name, skipna=skipna, **kwargs)
1549-
result = pa.array([result.as_py()], type=result.type)
1550-
return type(self)(result)
1547+
if pc.is_null(pa_result).as_py():
1548+
return self.dtype.na_value
1549+
else:
1550+
return pa_result.as_py()
15511551

15521552
def _explode(self):
15531553
"""

pandas/core/arrays/base.py

+16-30
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,6 @@ class ExtensionArray:
140140
_from_sequence_of_strings
141141
_hash_pandas_object
142142
_reduce
143-
_reduce_and_wrap
144143
_values_for_argsort
145144
_values_for_factorize
146145
@@ -190,7 +189,6 @@ class ExtensionArray:
190189
191190
* _accumulate
192191
* _reduce
193-
* _reduce_and_wrap
194192
195193
One can implement methods to handle parsing from strings that will be used
196194
in methods such as ``pandas.io.parsers.read_csv``.
@@ -1437,7 +1435,9 @@ def _accumulate(
14371435
"""
14381436
raise NotImplementedError(f"cannot perform {name} with type {self.dtype}")
14391437

1440-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1438+
def _reduce(
1439+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1440+
):
14411441
"""
14421442
Return a scalar result of performing the reduction operation.
14431443
@@ -1449,6 +1449,15 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
14491449
std, var, sem, kurt, skew }.
14501450
skipna : bool, default True
14511451
If True, skip NaN values.
1452+
keepdims : bool, default False
1453+
If False, a scalar is returned.
1454+
If True, the result has dimension with size one along the reduced axis.
1455+
1456+
.. versionadded:: 2.1
1457+
1458+
This parameter is not required in the _reduce signature to keep backward
1459+
compatibility, but will become required in the future. If the parameter
1460+
is not found in the method signature, a FutureWarning will be emitted.
14521461
**kwargs
14531462
Additional keyword arguments passed to the reduction function.
14541463
Currently, `ddof` is the only supported kwarg.
@@ -1460,41 +1469,18 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
14601469
Raises
14611470
------
14621471
TypeError : subclass does not define reductions
1463-
1464-
See Also
1465-
--------
1466-
ExtensionArray._reduce_and_wrap
1467-
Calls ``_reduce`` and wraps the result in a ndarray/ExtensionArray.
14681472
"""
14691473
meth = getattr(self, name, None)
14701474
if meth is None:
14711475
raise TypeError(
14721476
f"'{type(self).__name__}' with dtype {self.dtype} "
14731477
f"does not support reduction '{name}'"
14741478
)
1475-
return meth(skipna=skipna, **kwargs)
1476-
1477-
def _reduce_and_wrap(self, name: str, *, skipna: bool = True, kwargs):
1478-
"""
1479-
Call ``_reduce`` and wrap the result in a ndarray/ExtensionArray.
1479+
result = meth(skipna=skipna, **kwargs)
1480+
if keepdims:
1481+
result = np.array([result])
14801482

1481-
This is used to control the returned dtype when doing reductions in DataFrames,
1482-
and ensures the correct dtype for e.g. ``DataFrame({"a": extr_arr2}).sum()``.
1483-
1484-
Returns
1485-
-------
1486-
ndarray or ExtensionArray
1487-
1488-
Examples
1489-
--------
1490-
>>> arr = pd.array([1, 2, pd.NA])
1491-
>>> arr._reduce_and_wrap("sum", kwargs={})
1492-
<IntegerArray>
1493-
[3]
1494-
Length: 1, dtype: Int64
1495-
"""
1496-
result = self._reduce(name, skipna=skipna, **kwargs)
1497-
return np.array([result])
1483+
return result
14981484

14991485
# https://github.com/python/typeshed/issues/2148#issuecomment-520783318
15001486
# Incompatible types in assignment (expression has type "None", base class

pandas/core/arrays/categorical.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -2229,9 +2229,14 @@ def _reverse_indexer(self) -> dict[Hashable, npt.NDArray[np.intp]]:
22292229
# ------------------------------------------------------------------
22302230
# Reductions
22312231

2232-
def _reduce_and_wrap(self, name: str, *, skipna: bool = True, kwargs):
2233-
result = self._reduce(name, skipna=skipna, **kwargs)
2234-
return type(self)([result], dtype=self.dtype)
2232+
def _reduce(
2233+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
2234+
):
2235+
result = super()._reduce(name, skipna=skipna, keepdims=keepdims, **kwargs)
2236+
if keepdims:
2237+
return type(self)(result, dtype=self.dtype)
2238+
else:
2239+
return result
22352240

22362241
def min(self, *, skipna: bool = True, **kwargs):
22372242
"""

pandas/core/arrays/masked.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -1083,30 +1083,31 @@ def _quantile(
10831083
# ------------------------------------------------------------------
10841084
# Reductions
10851085

1086-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1086+
def _reduce(
1087+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1088+
):
10871089
if name in {"any", "all", "min", "max", "sum", "prod", "mean", "var", "std"}:
1088-
return getattr(self, name)(skipna=skipna, **kwargs)
1089-
1090-
data = self._data
1091-
mask = self._mask
1090+
result = getattr(self, name)(skipna=skipna, **kwargs)
1091+
else:
1092+
# median, skew, kurt, sem
1093+
data = self._data
1094+
mask = self._mask
1095+
op = getattr(nanops, f"nan{name}")
1096+
axis = kwargs.pop("axis", None)
1097+
result = op(data, axis=axis, skipna=skipna, mask=mask, **kwargs)
1098+
1099+
if keepdims:
1100+
if isna(result):
1101+
return self._wrap_na_result(name=name, axis=0, mask_size=(1,))
1102+
else:
1103+
result = result.reshape(1)
1104+
mask = np.zeros(1, dtype=bool)
1105+
return self._maybe_mask_result(result, mask)
10921106

1093-
# median, skew, kurt, sem
1094-
op = getattr(nanops, f"nan{name}")
1095-
axis = kwargs.pop("axis", None)
1096-
result = op(data, axis=axis, skipna=skipna, mask=mask, **kwargs)
1097-
if np.isnan(result):
1107+
if isna(result):
10981108
return libmissing.NA
1099-
1100-
return result
1101-
1102-
def _reduce_and_wrap(self, name: str, *, skipna: bool = True, kwargs):
1103-
res = self._reduce(name=name, skipna=skipna, **kwargs)
1104-
if res is libmissing.NA:
1105-
return self._wrap_na_result(name=name, axis=0, mask_size=(1,))
11061109
else:
1107-
res = res.reshape(1)
1108-
mask = np.zeros(1, dtype=bool)
1109-
return self._maybe_mask_result(res, mask)
1110+
return result
11101111

11111112
def _wrap_reduction_result(self, name: str, result, *, skipna, axis):
11121113
if isinstance(result, np.ndarray):

pandas/core/arrays/sparse/array.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -1384,7 +1384,9 @@ def nonzero(self) -> tuple[npt.NDArray[np.int32]]:
13841384
# Reductions
13851385
# ------------------------------------------------------------------------
13861386

1387-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1387+
def _reduce(
1388+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1389+
):
13881390
method = getattr(self, name, None)
13891391

13901392
if method is None:
@@ -1395,7 +1397,12 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
13951397
else:
13961398
arr = self.dropna()
13971399

1398-
return getattr(arr, name)(**kwargs)
1400+
result = getattr(arr, name)(**kwargs)
1401+
1402+
if keepdims:
1403+
return type(self)([result], dtype=self.dtype)
1404+
else:
1405+
return result
13991406

14001407
def all(self, axis=None, *args, **kwargs):
14011408
"""

pandas/core/frame.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import collections
1414
from collections import abc
1515
import functools
16+
from inspect import signature
1617
from io import StringIO
1718
import itertools
1819
import operator
@@ -10877,7 +10878,18 @@ def blk_func(values, axis: Axis = 1):
1087710878
self._mgr, ArrayManager
1087810879
):
1087910880
return values._reduce(name, axis=1, skipna=skipna, **kwds)
10880-
return values._reduce_and_wrap(name, skipna=skipna, kwargs=kwds)
10881+
sign = signature(values._reduce)
10882+
if "keepdims" in sign.parameters:
10883+
return values._reduce(name, skipna=skipna, keepdims=True, **kwds)
10884+
else:
10885+
warnings.warn(
10886+
f"{type(values)}._reduce will require a `keepdims` parameter "
10887+
"in the future",
10888+
FutureWarning,
10889+
stacklevel=find_stack_level(),
10890+
)
10891+
result = values._reduce(name, skipna=skipna, kwargs=kwds)
10892+
return np.array([result])
1088110893
else:
1088210894
return op(values, axis=axis, skipna=skipna, **kwds)
1088310895

pandas/tests/extension/decimal/array.py

+21-20
Original file line numberDiff line numberDiff line change
@@ -235,28 +235,29 @@ def _formatter(self, boxed=False):
235235
def _concat_same_type(cls, to_concat):
236236
return cls(np.concatenate([x._data for x in to_concat]))
237237

238-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
239-
if skipna:
238+
def _reduce(
239+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
240+
):
241+
if skipna and self.isna().any():
240242
# If we don't have any NAs, we can ignore skipna
241-
if self.isna().any():
242-
other = self[~self.isna()]
243-
return other._reduce(name, **kwargs)
244-
245-
if name == "sum" and len(self) == 0:
243+
other = self[~self.isna()]
244+
result = other._reduce(name, **kwargs)
245+
elif name == "sum" and len(self) == 0:
246246
# GH#29630 avoid returning int 0 or np.bool_(False) on old numpy
247-
return decimal.Decimal(0)
248-
249-
try:
250-
op = getattr(self.data, name)
251-
except AttributeError as err:
252-
raise NotImplementedError(
253-
f"decimal does not support the {name} operation"
254-
) from err
255-
return op(axis=0)
256-
257-
def _reduce_and_wrap(self, name: str, *, skipna: bool = True, kwargs):
258-
result = self._reduce(name, skipna=skipna, **kwargs)
259-
return type(self)([result])
247+
result = decimal.Decimal(0)
248+
else:
249+
try:
250+
op = getattr(self.data, name)
251+
except AttributeError as err:
252+
raise NotImplementedError(
253+
f"decimal does not support the {name} operation"
254+
) from err
255+
result = op(axis=0)
256+
257+
if keepdims:
258+
return type(self)([result])
259+
else:
260+
return result
260261

261262
def _cmp_method(self, other, op):
262263
# For use with OpsMixin

pandas/tests/extension/decimal/test_decimal.py

+23-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
123123
assert not hasattr(arr, op_name)
124124
pytest.skip(f"{op_name} not an array method")
125125

126-
result1 = arr._reduce_and_wrap(op_name, skipna=skipna, kwargs={})
126+
result1 = arr._reduce(op_name, skipna=skipna, keepdims=True)
127127
result2 = getattr(df, op_name)(skipna=skipna).array
128128

129129
tm.assert_extension_array_equal(result1, result2)
@@ -136,6 +136,28 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
136136

137137
tm.assert_extension_array_equal(result1, expected)
138138

139+
def test_reduction_without_keepdims(self):
140+
# GH52788
141+
# test _reduce without keepdims
142+
143+
class DecimalArray2(DecimalArray):
144+
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
145+
# no keepdims in signature
146+
return super()._reduce(name, skipna=skipna)
147+
148+
arr = DecimalArray2([decimal.Decimal(2) for _ in range(100)])
149+
150+
ser = pd.Series(arr)
151+
result = ser.agg("sum")
152+
expected = decimal.Decimal(200)
153+
assert result == expected
154+
155+
df = pd.DataFrame({"a": arr, "b": arr})
156+
with tm.assert_produces_warning(FutureWarning):
157+
result = df.agg("sum")
158+
expected = pd.Series({"a": 200, "b": 200}, dtype=object)
159+
tm.assert_series_equal(result, expected)
160+
139161

140162
class TestNumericReduce(Reduce, base.BaseNumericReduceTests):
141163
pass

pandas/tests/extension/masked_shared.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
9595
exp_value = getattr(ser.dropna().astype(cmp_dtype), op_name)()
9696
expected = pd.array([exp_value], dtype=cmp_dtype)
9797

98-
result1 = arr._reduce_and_wrap(op_name, skipna=skipna, kwargs={})
98+
result1 = arr._reduce(op_name, skipna=skipna, keepdims=True)
9999
result2 = getattr(df, op_name)(skipna=skipna).array
100100

101101
tm.assert_extension_array_equal(result1, result2)

pandas/tests/extension/test_arrow.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -533,7 +533,7 @@ def check_reduce_frame(self, ser, op_name, skipna):
533533
"u": "uint64[pyarrow]",
534534
"f": "float64[pyarrow]",
535535
}[arr.dtype.kind]
536-
result = arr._reduce_and_wrap(op_name, skipna=skipna, kwargs=kwargs)
536+
result = arr._reduce(op_name, skipna=skipna, keepdims=True, **kwargs)
537537

538538
if not skipna and ser.isna().any():
539539
expected = pd.array([pd.NA], dtype=cmp_dtype)

pandas/tests/extension/test_boolean.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,7 @@ def check_reduce_frame(self, ser: pd.Series, op_name: str, skipna: bool):
387387
else:
388388
raise TypeError("not supposed to reach this")
389389

390-
result = arr._reduce_and_wrap(op_name, skipna=skipna, kwargs={})
390+
result = arr._reduce(op_name, skipna=skipna, keepdims=True)
391391
if not skipna and ser.isna().any():
392392
expected = pd.array([pd.NA], dtype=cmp_dtype)
393393
else:

0 commit comments

Comments
 (0)