Skip to content

Commit ef89a57

Browse files
committed
ENH: better dtype inference when doing DataFrame reductions
1 parent 7ab653d commit ef89a57

File tree

6 files changed

+168
-13
lines changed

6 files changed

+168
-13
lines changed

pandas/core/arrays/arrow/array.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -1213,7 +1213,9 @@ def _accumulate(
12131213

12141214
return type(self)(result)
12151215

1216-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1216+
def _reduce(
1217+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1218+
):
12171219
"""
12181220
Return a scalar result of performing the reduction operation.
12191221
@@ -1310,6 +1312,12 @@ def pyarrow_meth(data, skip_nulls, **kwargs):
13101312
if name == "median":
13111313
# GH 52679: Use quantile instead of approximate_median; returns array
13121314
result = result[0]
1315+
1316+
if keepdims:
1317+
# TODO: is there a way to do this without .as_py()
1318+
result = pa.array([result.as_py()], type=result.type)
1319+
return type(self)(result)
1320+
13131321
if pc.is_null(result).as_py():
13141322
return self.dtype.na_value
13151323

pandas/core/arrays/base.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1403,7 +1403,9 @@ def _accumulate(
14031403
"""
14041404
raise NotImplementedError(f"cannot perform {name} with type {self.dtype}")
14051405

1406-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1406+
def _reduce(
1407+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1408+
):
14071409
"""
14081410
Return a scalar result of performing the reduction operation.
14091411
@@ -1433,7 +1435,14 @@ def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
14331435
f"'{type(self).__name__}' with dtype {self.dtype} "
14341436
f"does not support reduction '{name}'"
14351437
)
1436-
return meth(skipna=skipna, **kwargs)
1438+
result = meth(skipna=skipna, **kwargs)
1439+
1440+
if keepdims:
1441+
# if subclasses want to avoid wrapping in np.array, do:
1442+
# super()._reduce(..., keepdims=False) and wrap that.
1443+
return np.array([[result]])
1444+
else:
1445+
return result
14371446

14381447
# https://github.com/python/typeshed/issues/2148#issuecomment-520783318
14391448
# Incompatible types in assignment (expression has type "None", base class

pandas/core/arrays/masked.py

+40-5
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,13 @@
4242
from pandas.core.dtypes.common import (
4343
is_bool,
4444
is_dtype_equal,
45+
is_float_dtype,
4546
is_integer_dtype,
4647
is_list_like,
4748
is_scalar,
4849
is_string_dtype,
50+
is_unsigned_integer_dtype,
51+
is_signed_integer_dtype,
4952
pandas_dtype,
5053
)
5154
from pandas.core.dtypes.dtypes import BaseMaskedDtype
@@ -1069,7 +1072,15 @@ def _quantile(
10691072
# ------------------------------------------------------------------
10701073
# Reductions
10711074

1072-
def _reduce(self, name: str, *, skipna: bool = True, **kwargs):
1075+
def _reduce(
1076+
self, name: str, *, skipna: bool = True, keepdims: bool = False, **kwargs
1077+
):
1078+
if keepdims:
1079+
res = self.reshape(-1, 1)._reduce(name=name, skipna=skipna, **kwargs)
1080+
if res is libmissing.NA:
1081+
res = self._wrap_na_result(name)
1082+
return res
1083+
10731084
if name in {"any", "all", "min", "max", "sum", "prod", "mean", "var", "std"}:
10741085
return getattr(self, name)(skipna=skipna, **kwargs)
10751086

@@ -1097,6 +1108,30 @@ def _wrap_reduction_result(self, name: str, result, skipna, **kwargs):
10971108
return self._maybe_mask_result(result, mask)
10981109
return result
10991110

1111+
def _wrap_min_count_reduction_result(
1112+
self, name: str, result, skipna, min_count, **kwargs
1113+
):
1114+
if min_count == 0 and isinstance(result, np.ndarray):
1115+
return self._maybe_mask_result(result, np.zeros(1, dtype=bool))
1116+
return self._wrap_reduction_result(name, result, skipna, **kwargs)
1117+
1118+
def _wrap_na_result(self, name):
1119+
mask = np.ones(1, dtype=bool)
1120+
1121+
if is_float_dtype(self.dtype):
1122+
np_dtype = np.float64
1123+
elif name in ["mean", "median", "var", "std", "skew"]:
1124+
np_dtype = np.float64
1125+
elif is_signed_integer_dtype(self.dtype):
1126+
np_dtype = np.int64
1127+
elif is_unsigned_integer_dtype(self.dtype):
1128+
np_dtype = np.uint64
1129+
else:
1130+
raise TypeError(self.dtype)
1131+
1132+
value = np.array([1], dtype=np_dtype)
1133+
return self._maybe_mask_result(value, mask=mask)
1134+
11001135
def sum(
11011136
self,
11021137
*,
@@ -1114,8 +1149,8 @@ def sum(
11141149
min_count=min_count,
11151150
axis=axis,
11161151
)
1117-
return self._wrap_reduction_result(
1118-
"sum", result, skipna=skipna, axis=axis, **kwargs
1152+
return self._wrap_min_count_reduction_result(
1153+
"sum", result, skipna=skipna, min_count=min_count, axis=axis, **kwargs
11191154
)
11201155

11211156
def prod(
@@ -1134,8 +1169,8 @@ def prod(
11341169
min_count=min_count,
11351170
axis=axis,
11361171
)
1137-
return self._wrap_reduction_result(
1138-
"prod", result, skipna=skipna, axis=axis, **kwargs
1172+
return self._wrap_min_count_reduction_result(
1173+
"prod", result, skipna=skipna, min_count=min_count, axis=axis, **kwargs
11391174
)
11401175

11411176
def mean(self, *, skipna: bool = True, axis: AxisInt | None = 0, **kwargs):

pandas/core/frame.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -10858,7 +10858,7 @@ def blk_func(values, axis: Axis = 1):
1085810858
self._mgr, ArrayManager
1085910859
):
1086010860
return values._reduce(name, axis=1, skipna=skipna, **kwds)
10861-
return values._reduce(name, skipna=skipna, **kwds)
10861+
return values._reduce(name, skipna=skipna, keepdims=True, **kwds)
1086210862
else:
1086310863
return op(values, axis=axis, skipna=skipna, **kwds)
1086410864

@@ -10903,7 +10903,7 @@ def _get_data() -> DataFrame:
1090310903
out = out.astype(out_dtype)
1090410904
elif (df._mgr.get_dtypes() == object).any():
1090510905
out = out.astype(object)
10906-
elif len(self) == 0 and name in ("sum", "prod"):
10906+
elif len(self) == 0 and out.dtype == object and name in ("sum", "prod"):
1090710907
# Even if we are object dtype, follow numpy and return
1090810908
# float64, see test_apply_funcs_over_empty
1090910909
out = out.astype(np.float64)

pandas/core/internals/blocks.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,10 @@ def reduce(self, func) -> list[Block]:
340340

341341
if self.values.ndim == 1:
342342
# TODO(EA2D): special case not needed with 2D EAs
343-
res_values = np.array([[result]])
343+
if isinstance(result, (np.ndarray, ExtensionArray)):
344+
res_values = result
345+
else:
346+
res_values = np.array([[result]])
344347
else:
345348
res_values = result.reshape(-1, 1)
346349

pandas/tests/frame/test_reductions.py

+102-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,10 @@
66
import numpy as np
77
import pytest
88

9-
from pandas.compat import is_platform_windows
9+
from pandas.compat import (
10+
IS64,
11+
is_platform_windows,
12+
)
1013
import pandas.util._test_decorators as td
1114

1215
import pandas as pd
@@ -29,6 +32,8 @@
2932
nanops,
3033
)
3134

35+
is_windows_or_is32 = is_platform_windows() or not IS64
36+
3237

3338
def assert_stat_op_calc(
3439
opname,
@@ -917,7 +922,7 @@ def test_mean_extensionarray_numeric_only_true(self):
917922
arr = np.random.randint(1000, size=(10, 5))
918923
df = DataFrame(arr, dtype="Int64")
919924
result = df.mean(numeric_only=True)
920-
expected = DataFrame(arr).mean()
925+
expected = DataFrame(arr).mean().astype("Float64")
921926
tm.assert_series_equal(result, expected)
922927

923928
def test_stats_mixed_type(self, float_string_frame):
@@ -1626,6 +1631,101 @@ def test_min_max_categorical_dtype_non_ordered_nuisance_column(self, method):
16261631
getattr(np, method)(df, axis=0)
16271632

16281633

1634+
class TestEmptyDataFrameReductions:
1635+
@pytest.mark.parametrize(
1636+
"opname, dtype, exp_value, exp_dtype",
1637+
[
1638+
("sum", np.int8, 0, np.int64),
1639+
("prod", np.int8, 1, np.int_),
1640+
("sum", np.int64, 0, np.int64),
1641+
("prod", np.int64, 1, np.int64),
1642+
("sum", np.uint8, 0, np.int64),
1643+
("prod", np.uint8, 1, np.uint),
1644+
("sum", np.uint64, 0, np.int64),
1645+
("prod", np.uint64, 1, np.uint64),
1646+
("sum", np.float32, 0, np.float32),
1647+
("prod", np.float32, 1, np.float32),
1648+
("sum", np.float64, 0, np.float64),
1649+
],
1650+
)
1651+
def test_df_empty_min_count_0(self, opname, dtype, exp_value, exp_dtype):
1652+
df = DataFrame({0: [], 1: []}, dtype=dtype)
1653+
result = getattr(df, opname)(min_count=0)
1654+
1655+
expected = Series([exp_value, exp_value], dtype=exp_dtype)
1656+
tm.assert_series_equal(result, expected)
1657+
1658+
@pytest.mark.parametrize(
1659+
"opname, dtype, exp_dtype",
1660+
[
1661+
("sum", np.int8, np.float64),
1662+
("prod", np.int8, np.float64),
1663+
("sum", np.int64, np.float64),
1664+
("prod", np.int64, np.float64),
1665+
("sum", np.uint8, np.float64),
1666+
("prod", np.uint8, np.float64),
1667+
("sum", np.uint64, np.float64),
1668+
("prod", np.uint64, np.float64),
1669+
("sum", np.float32, np.float32),
1670+
("prod", np.float32, np.float32),
1671+
("sum", np.float64, np.float64),
1672+
],
1673+
)
1674+
def test_df_empty_min_count_1(self, opname, dtype, exp_dtype):
1675+
df = DataFrame({0: [], 1: []}, dtype=dtype)
1676+
result = getattr(df, opname)(min_count=1)
1677+
1678+
expected = Series([np.nan, np.nan], dtype=exp_dtype)
1679+
tm.assert_series_equal(result, expected)
1680+
1681+
@pytest.mark.parametrize(
1682+
"opname, dtype, exp_value, exp_dtype",
1683+
[
1684+
("sum", "Int8", 0, ("Int32" if is_windows_or_is32 else "Int64")),
1685+
("prod", "Int8", 1, ("Int32" if is_windows_or_is32 else "Int64")),
1686+
("prod", "Int8", 1, ("Int32" if is_windows_or_is32 else "Int64")),
1687+
("sum", "Int64", 0, "Int64"),
1688+
("prod", "Int64", 1, "Int64"),
1689+
("sum", "UInt8", 0, ("UInt32" if is_windows_or_is32 else "UInt64")),
1690+
("prod", "UInt8", 1, ("UInt32" if is_windows_or_is32 else "UInt64")),
1691+
("sum", "UInt64", 0, "UInt64"),
1692+
("prod", "UInt64", 1, "UInt64"),
1693+
("sum", "Float32", 0, "Float32"),
1694+
("prod", "Float32", 1, "Float32"),
1695+
("sum", "Float64", 0, "Float64"),
1696+
],
1697+
)
1698+
def test_df_empty_nullable_min_count_0(self, opname, dtype, exp_value, exp_dtype):
1699+
df = DataFrame({0: [], 1: []}, dtype=dtype)
1700+
result = getattr(df, opname)(min_count=0)
1701+
1702+
expected = Series([exp_value, exp_value], dtype=exp_dtype)
1703+
tm.assert_series_equal(result, expected)
1704+
1705+
@pytest.mark.parametrize(
1706+
"opname, dtype, exp_dtype",
1707+
[
1708+
("sum", "Int8", "Int64"),
1709+
("prod", "Int8", "Int64"),
1710+
("sum", "Int64", "Int64"),
1711+
("prod", "Int64", "Int64"),
1712+
("sum", "UInt8", "UInt64"),
1713+
("prod", "UInt8", "UInt64"),
1714+
("sum", "UInt64", "UInt64"),
1715+
("prod", "UInt64", "UInt64"),
1716+
("sum", "Float32", "Float32"),
1717+
("prod", "Float32", "Float32"),
1718+
("sum", "Float64", "Float64"),
1719+
],
1720+
)
1721+
def test_df_empty_nullable_min_count_1(self, opname, dtype, exp_dtype):
1722+
df = DataFrame({0: [], 1: []}, dtype=dtype)
1723+
result = getattr(df, opname)(min_count=1)
1724+
1725+
expected = Series([pd.NA, pd.NA], dtype=exp_dtype)
1726+
tm.assert_series_equal(result, expected)
1727+
1728+
16291729
def test_sum_timedelta64_skipna_false(using_array_manager, request):
16301730
# GH#17235
16311731
if using_array_manager:

0 commit comments

Comments
 (0)