Skip to content

PERF: axis=1 reductions with EA dtypes #54341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Aug 13, 2023
Merged
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.1.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ Performance improvements
- Performance improvement in :func:`concat` (:issue:`52291`, :issue:`52290`)
- :class:`Period`'s default formatter (`period_format`) is now significantly (~twice) faster. This improves performance of ``str(Period)``, ``repr(Period)``, and :meth:`Period.strftime(fmt=None)`, as well as ``PeriodArray.strftime(fmt=None)``, ``PeriodIndex.strftime(fmt=None)`` and ``PeriodIndex.format(fmt=None)``. Finally, ``to_csv`` operations involving :class:`PeriodArray` or :class:`PeriodIndex` with default ``date_format`` are also significantly accelerated. (:issue:`51459`)
- Performance improvement accessing :attr:`arrays.IntegerArrays.dtype` & :attr:`arrays.FloatingArray.dtype` (:issue:`52998`)
- Performance improvement in :class:`DataFrame` reductions with ``axis=1`` and extension dtypes (:issue:`54341`)
- Performance improvement in :class:`DataFrame` reductions with ``axis=None`` and extension dtypes (:issue:`54308`)
- Performance improvement in :class:`MultiIndex` and multi-column operations (e.g. :meth:`DataFrame.sort_values`, :meth:`DataFrame.groupby`, :meth:`Series.unstack`) when index/column values are already sorted (:issue:`53806`)
- Performance improvement in :class:`Series` reductions (:issue:`52341`)
Expand Down
19 changes: 19 additions & 0 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -11083,6 +11083,25 @@ def _get_data() -> DataFrame:
).iloc[:0]
result.index = df.index
return result

if df.shape[1] and name != "kurt":
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC kurt is excluded here bc GroupBy doesnt support it. Can you comment to that effect

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added a comment here

dtype = find_common_type([arr.dtype for arr in df._mgr.arrays])
if isinstance(dtype, ExtensionDtype):
name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name)
df = df.astype(dtype, copy=False)
arr = concat_compat(list(df._iter_column_arrays()))
nrows, ncols = df.shape
row_index = np.tile(np.arange(nrows), ncols)
col_index = np.repeat(np.arange(ncols), nrows)
ser = Series(arr, index=col_index)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add comments on what is going on here. i think i get it, but it could be non-obvious to someone who doesnt know e.g. groupby.agg

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated with comments

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you do copy=False here for CoW?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated. thanks

result = ser.groupby(row_index).agg(name, **kwds)
result.index = df.index
if not skipna and name not in ("any", "all"):
mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1)
other = -1 if name in ("idxmax", "idxmin") else lib.no_default
result = result.mask(mask, other)
return result

df = df.T

# After possibly _get_data and transposing, we are now in the
Expand Down
6 changes: 5 additions & 1 deletion pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class providing the base-class of operations.
ExtensionArray,
FloatingArray,
IntegerArray,
SparseArray,
)
from pandas.core.base import (
PandasObject,
Expand Down Expand Up @@ -1905,7 +1906,10 @@ def array_func(values: ArrayLike) -> ArrayLike:
# and non-applicable functions
# try to python agg
# TODO: shouldn't min_count matter?
if how in ["any", "all", "std", "sem"]:
# TODO: avoid special casing SparseArray here
if how in ["any", "all"] and isinstance(values, SparseArray):
pass
Comment on lines +1914 to +1915
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like this is a general issue with SparseArray any/all so really independent of this PR, is that right? I'm thinking this should be fixed in SparseArray itself rather than in groupby code. Would it be okay to xfail any relevant tests?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see the issue with my request above; this would make axis=1 fail for SparseArray whereas it didn't before. I would be okay opening up an issue.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure - by opening an issue, do you mean xfail for now as part of this PR? or open an issue and address that first before this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leave this PR as-is; open an issue to cleanup after this is merged (before is okay too)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah - got it. thanks

elif how in ["any", "all", "std", "sem"]:
raise # TODO: re-raise as TypeError? should not be reached
else:
return result
Expand Down
77 changes: 77 additions & 0 deletions pandas/tests/frame/test_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1941,3 +1941,80 @@ def test_fails_on_non_numeric(kernel):
msg = "|".join([msg1, msg2])
with pytest.raises(TypeError, match=msg):
getattr(df, kernel)(*args)


@pytest.mark.parametrize(
"method",
[
"all",
"any",
"count",
"idxmax",
"idxmin",
"kurt",
"kurtosis",
"max",
"mean",
"median",
"min",
"nunique",
"prod",
"product",
"sem",
"skew",
"std",
"sum",
"var",
],
)
@pytest.mark.parametrize("min_count", [0, 2])
def test_numeric_ea_axis_1(method, skipna, min_count, any_numeric_ea_dtype):
# GH 54341
df = DataFrame(
{
"a": Series([0, 1, 2, 3], dtype=any_numeric_ea_dtype),
"b": Series([0, 1, pd.NA, 3], dtype=any_numeric_ea_dtype),
},
)
expected_df = DataFrame(
{
"a": [0.0, 1.0, 2.0, 3.0],
"b": [0.0, 1.0, np.nan, 3.0],
},
)
if method in ("count", "nunique"):
expected_dtype = "int64"
elif method in ("all", "any"):
expected_dtype = "boolean"
elif method in (
"kurt",
"kurtosis",
"mean",
"median",
"sem",
"skew",
"std",
"var",
) and not any_numeric_ea_dtype.startswith("Float"):
expected_dtype = "Float64"
else:
expected_dtype = any_numeric_ea_dtype

kwargs = {}
if method not in ("count", "nunique", "quantile"):
kwargs["skipna"] = skipna
if method in ("prod", "product", "sum"):
kwargs["min_count"] = min_count

warn = None
msg = None
if not skipna and method in ("idxmax", "idxmin"):
warn = FutureWarning
msg = f"The behavior of DataFrame.{method} with all-NA values"
with tm.assert_produces_warning(warn, match=msg):
result = getattr(df, method)(axis=1, **kwargs)
with tm.assert_produces_warning(warn, match=msg):
expected = getattr(expected_df, method)(axis=1, **kwargs)
if method not in ("idxmax", "idxmin"):
expected = expected.astype(expected_dtype)
tm.assert_series_equal(result, expected)