Skip to content

Commit 19c8a4a

Browse files
authored
PERF: axis=1 reductions with EA dtypes (#54341)
1 parent 582a1be commit 19c8a4a

File tree

4 files changed

+109
-1
lines changed

4 files changed

+109
-1
lines changed

doc/source/whatsnew/v2.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -614,6 +614,7 @@ Performance improvements
614614
- :class:`Period`'s default formatter (``period_format``) is now significantly faster (roughly twice as fast). This improves the performance of ``str(Period)``, ``repr(Period)``, and :meth:`Period.strftime(fmt=None)`, as well as ``PeriodArray.strftime(fmt=None)``, ``PeriodIndex.strftime(fmt=None)`` and ``PeriodIndex.format(fmt=None)``. Finally, ``to_csv`` operations involving :class:`PeriodArray` or :class:`PeriodIndex` with the default ``date_format`` are also significantly faster. (:issue:`51459`)
615615
- Performance improvement accessing :attr:`arrays.IntegerArray.dtype` & :attr:`arrays.FloatingArray.dtype` (:issue:`52998`)
616616
- Performance improvement for :class:`DataFrameGroupBy`/:class:`SeriesGroupBy` aggregations (e.g. :meth:`DataFrameGroupBy.sum`) with ``engine="numba"`` (:issue:`53731`)
617+
- Performance improvement in :class:`DataFrame` reductions with ``axis=1`` and extension dtypes (:issue:`54341`)
617618
- Performance improvement in :class:`DataFrame` reductions with ``axis=None`` and extension dtypes (:issue:`54308`)
618619
- Performance improvement in :class:`MultiIndex` and multi-column operations (e.g. :meth:`DataFrame.sort_values`, :meth:`DataFrame.groupby`, :meth:`Series.unstack`) when index/column values are already sorted (:issue:`53806`)
619620
- Performance improvement in :class:`Series` reductions (:issue:`52341`)

pandas/core/frame.py

+26
Original file line numberDiff line numberDiff line change
@@ -11172,6 +11172,32 @@ def _get_data() -> DataFrame:
1117211172
).iloc[:0]
1117311173
result.index = df.index
1117411174
return result
11175+
11176+
# kurtosis excluded since groupby does not implement it
11177+
if df.shape[1] and name != "kurt":
11178+
dtype = find_common_type([arr.dtype for arr in df._mgr.arrays])
11179+
if isinstance(dtype, ExtensionDtype):
11180+
# GH 54341: fastpath for EA-backed axis=1 reductions
11181+
# This flattens the frame into a single 1D array while keeping
11182+
# track of the row and column indices of the original frame. Once
11183+
# flattened, grouping by the row indices and aggregating should
11184+
# be equivalent to transposing the original frame and aggregating
11185+
# with axis=0.
11186+
name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name)
11187+
df = df.astype(dtype, copy=False)
11188+
arr = concat_compat(list(df._iter_column_arrays()))
11189+
nrows, ncols = df.shape
11190+
row_index = np.tile(np.arange(nrows), ncols)
11191+
col_index = np.repeat(np.arange(ncols), nrows)
11192+
ser = Series(arr, index=col_index, copy=False)
11193+
result = ser.groupby(row_index).agg(name, **kwds)
11194+
result.index = df.index
11195+
if not skipna and name not in ("any", "all"):
11196+
mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1)
11197+
other = -1 if name in ("idxmax", "idxmin") else lib.no_default
11198+
result = result.mask(mask, other)
11199+
return result
11200+
1117511201
df = df.T
1117611202

1117711203
# After possibly _get_data and transposing, we are now in the

pandas/core/groupby/groupby.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ class providing the base-class of operations.
105105
ExtensionArray,
106106
FloatingArray,
107107
IntegerArray,
108+
SparseArray,
108109
)
109110
from pandas.core.base import (
110111
PandasObject,
@@ -1909,7 +1910,10 @@ def array_func(values: ArrayLike) -> ArrayLike:
19091910
# and non-applicable functions
19101911
# try to python agg
19111912
# TODO: shouldn't min_count matter?
1912-
if how in ["any", "all", "std", "sem"]:
1913+
# TODO: avoid special casing SparseArray here
1914+
if how in ["any", "all"] and isinstance(values, SparseArray):
1915+
pass
1916+
elif how in ["any", "all", "std", "sem"]:
19131917
raise # TODO: re-raise as TypeError? should not be reached
19141918
else:
19151919
return result

pandas/tests/frame/test_reductions.py

+77
Original file line numberDiff line numberDiff line change
@@ -1938,3 +1938,80 @@ def test_fails_on_non_numeric(kernel):
19381938
msg = "|".join([msg1, msg2])
19391939
with pytest.raises(TypeError, match=msg):
19401940
getattr(df, kernel)(*args)
1941+
1942+
1943+
@pytest.mark.parametrize(
1944+
"method",
1945+
[
1946+
"all",
1947+
"any",
1948+
"count",
1949+
"idxmax",
1950+
"idxmin",
1951+
"kurt",
1952+
"kurtosis",
1953+
"max",
1954+
"mean",
1955+
"median",
1956+
"min",
1957+
"nunique",
1958+
"prod",
1959+
"product",
1960+
"sem",
1961+
"skew",
1962+
"std",
1963+
"sum",
1964+
"var",
1965+
],
1966+
)
1967+
@pytest.mark.parametrize("min_count", [0, 2])
1968+
def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype):
1969+
# GH 54341
1970+
df = DataFrame(
1971+
{
1972+
"a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype),
1973+
"b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype),
1974+
},
1975+
)
1976+
expected_df = DataFrame(
1977+
{
1978+
"a": [0.0, 1.0, 2.0, 3.0],
1979+
"b": [0.0, 1.0, np.nan, 3.0],
1980+
},
1981+
)
1982+
if method in ("count", "nunique"):
1983+
expected_dtype = "int64"
1984+
elif method in ("all", "any"):
1985+
expected_dtype = "boolean"
1986+
elif method in (
1987+
"kurt",
1988+
"kurtosis",
1989+
"mean",
1990+
"median",
1991+
"sem",
1992+
"skew",
1993+
"std",
1994+
"var",
1995+
) and not any_numeric_ea_dtype.startswith("Float"):
1996+
expected_dtype = "Float64"
1997+
else:
1998+
expected_dtype = any_numeric_ea_dtype
1999+
2000+
kwargs = {}
2001+
if method not in ("count", "nunique", "quantile"):
2002+
kwargs["skipna"] = skipna
2003+
if method in ("prod", "product", "sum"):
2004+
kwargs["min_count"] = min_count
2005+
2006+
warn = None
2007+
msg = None
2008+
if not skipna and method in ("idxmax", "idxmin"):
2009+
warn = FutureWarning
2010+
msg = f"The behavior of DataFrame.{method} with all-NA values"
2011+
with tm.assert_produces_warning(warn, match=msg):
2012+
result = getattr(df, method)(axis=1, **kwargs)
2013+
with tm.assert_produces_warning(warn, match=msg):
2014+
expected = getattr(expected_df, method)(axis=1, **kwargs)
2015+
if method not in ("idxmax", "idxmin"):
2016+
expected = expected.astype(expected_dtype)
2017+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)