diff --git a/doc/source/whatsnew/v2.1.0.rst b/doc/source/whatsnew/v2.1.0.rst index 998efdedb1b57..1c6269d50295d 100644 --- a/doc/source/whatsnew/v2.1.0.rst +++ b/doc/source/whatsnew/v2.1.0.rst @@ -614,6 +614,7 @@ Performance improvements - :class:`Period`'s default formatter (`period_format`) is now significantly (~twice) faster. This improves performance of ``str(Period)``, ``repr(Period)``, and :meth:`Period.strftime(fmt=None)`, as well as ``PeriodArray.strftime(fmt=None)``, ``PeriodIndex.strftime(fmt=None)`` and ``PeriodIndex.format(fmt=None)``. Finally, ``to_csv`` operations involving :class:`PeriodArray` or :class:`PeriodIndex` with default ``date_format`` are also significantly accelerated. (:issue:`51459`) - Performance improvement accessing :attr:`arrays.IntegerArrays.dtype` & :attr:`arrays.FloatingArray.dtype` (:issue:`52998`) - Performance improvement for :class:`DataFrameGroupBy`/:class:`SeriesGroupBy` aggregations (e.g. :meth:`DataFrameGroupBy.sum`) with ``engine="numba"`` (:issue:`53731`) +- Performance improvement in :class:`DataFrame` reductions with ``axis=1`` and extension dtypes (:issue:`54341`) - Performance improvement in :class:`DataFrame` reductions with ``axis=None`` and extension dtypes (:issue:`54308`) - Performance improvement in :class:`MultiIndex` and multi-column operations (e.g. :meth:`DataFrame.sort_values`, :meth:`DataFrame.groupby`, :meth:`Series.unstack`) when index/column values are already sorted (:issue:`53806`) - Performance improvement in :class:`Series` reductions (:issue:`52341`) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 4908b535bcb1c..d9746c3b46c9c 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -11166,6 +11166,32 @@ def _get_data() -> DataFrame: ).iloc[:0] result.index = df.index return result + + # kurtosis excluded since groupby does not implement it + if df.shape[1] and name != "kurt": + dtype = find_common_type([arr.dtype for arr in df._mgr.arrays]) + if isinstance(dtype, ExtensionDtype): + # GH 54341: fastpath for EA-backed axis=1 reductions + # This flattens the frame into a single 1D array while keeping + # track of the row and column indices of the original frame. Once + # flattened, grouping by the row indices and aggregating should + # be equivalent to transposing the original frame and aggregating + # with axis=0. + name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name) + df = df.astype(dtype, copy=False) + arr = concat_compat(list(df._iter_column_arrays())) + nrows, ncols = df.shape + row_index = np.tile(np.arange(nrows), ncols) + col_index = np.repeat(np.arange(ncols), nrows) + ser = Series(arr, index=col_index, copy=False) + result = ser.groupby(row_index).agg(name, **kwds) + result.index = df.index + if not skipna and name not in ("any", "all"): + mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1) + other = -1 if name in ("idxmax", "idxmin") else lib.no_default + result = result.mask(mask, other) + return result + df = df.T # After possibly _get_data and transposing, we are now in the diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index 69a99c1bc867c..e327dd9d6c5ff 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -105,6 +105,7 @@ class providing the base-class of operations. ExtensionArray, FloatingArray, IntegerArray, + SparseArray, ) from pandas.core.base import ( PandasObject, @@ -1909,7 +1910,10 @@ def array_func(values: ArrayLike) -> ArrayLike: # and non-applicable functions # try to python agg # TODO: shouldn't min_count matter? - if how in ["any", "all", "std", "sem"]: + # TODO: avoid special casing SparseArray here + if how in ["any", "all"] and isinstance(values, SparseArray): + pass + elif how in ["any", "all", "std", "sem"]: raise # TODO: re-raise as TypeError? should not be reached else: return result diff --git a/pandas/tests/frame/test_reductions.py b/pandas/tests/frame/test_reductions.py index 3768298156550..ab36934533beb 100644 --- a/pandas/tests/frame/test_reductions.py +++ b/pandas/tests/frame/test_reductions.py @@ -1938,3 +1938,80 @@ def test_fails_on_non_numeric(kernel): msg = "|".join([msg1, msg2]) with pytest.raises(TypeError, match=msg): getattr(df, kernel)(*args) + + +@pytest.mark.parametrize( + "method", + [ + "all", + "any", + "count", + "idxmax", + "idxmin", + "kurt", + "kurtosis", + "max", + "mean", + "median", + "min", + "nunique", + "prod", + "product", + "sem", + "skew", + "std", + "sum", + "var", + ], +) +@pytest.mark.parametrize("min_count", [0, 2]) +def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype): + # GH 54341 + df = DataFrame( + { + "a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype), + "b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype), + }, + ) + expected_df = DataFrame( + { + "a": [0.0, 1.0, 2.0, 3.0], + "b": [0.0, 1.0, np.nan, 3.0], + }, + ) + if method in ("count", "nunique"): + expected_dtype = "int64" + elif method in ("all", "any"): + expected_dtype = "boolean" + elif method in ( + "kurt", + "kurtosis", + "mean", + "median", + "sem", + "skew", + "std", + "var", + ) and not any_numeric_ea_dtype.startswith("Float"): + expected_dtype = "Float64" + else: + expected_dtype = any_numeric_ea_dtype + + kwargs = {} + if method not in ("count", "nunique", "quantile"): + kwargs["skipna"] = skipna + if method in ("prod", "product", "sum"): + kwargs["min_count"] = min_count + + warn = None + msg = None + if not skipna and method in ("idxmax", "idxmin"): + warn = FutureWarning + msg = f"The behavior of DataFrame.{method} with all-NA values" + with tm.assert_produces_warning(warn, match=msg): + result = getattr(df, method)(axis=1, **kwargs) + with tm.assert_produces_warning(warn, match=msg): + expected = getattr(expected_df, method)(axis=1, **kwargs) + if method not in ("idxmax", "idxmin"): + expected = expected.astype(expected_dtype) + tm.assert_series_equal(result, expected)