-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
PERF: axis=1 reductions with EA dtypes #54341
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
0364feb
114c764
8335a18
e51fb7e
bb37fe1
3f7553b
66ac02b
6088110
f0eaa29
266b2af
279766d
826711e
3b31fbb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11083,6 +11083,25 @@ def _get_data() -> DataFrame: | |
).iloc[:0] | ||
result.index = df.index | ||
return result | ||
|
||
if df.shape[1] and name != "kurt": | ||
dtype = find_common_type([arr.dtype for arr in df._mgr.arrays]) | ||
if isinstance(dtype, ExtensionDtype): | ||
name = {"argmax": "idxmax", "argmin": "idxmin"}.get(name, name) | ||
df = df.astype(dtype, copy=False) | ||
arr = concat_compat(list(df._iter_column_arrays())) | ||
nrows, ncols = df.shape | ||
row_index = np.tile(np.arange(nrows), ncols) | ||
col_index = np.repeat(np.arange(ncols), nrows) | ||
ser = Series(arr, index=col_index) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you add comments on what is going on here. i think i get it, but it could be non-obvious to someone who doesnt know e.g. groupby.agg There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated with comments There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you do copy=False here for CoW? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. updated. thanks |
||
result = ser.groupby(row_index).agg(name, **kwds) | ||
result.index = df.index | ||
if not skipna and name not in ("any", "all"): | ||
mask = df.isna().to_numpy(dtype=np.bool_).any(axis=1) | ||
other = -1 if name in ("idxmax", "idxmin") else lib.no_default | ||
result = result.mask(mask, other) | ||
return result | ||
|
||
df = df.T | ||
|
||
# After possibly _get_data and transposing, we are now in the | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -105,6 +105,7 @@ class providing the base-class of operations. | |
ExtensionArray, | ||
FloatingArray, | ||
IntegerArray, | ||
SparseArray, | ||
) | ||
from pandas.core.base import ( | ||
PandasObject, | ||
|
@@ -1905,7 +1906,10 @@ def array_func(values: ArrayLike) -> ArrayLike: | |
# and non-applicable functions | ||
# try to python agg | ||
# TODO: shouldn't min_count matter? | ||
if how in ["any", "all", "std", "sem"]: | ||
# TODO: avoid special casing SparseArray here | ||
if how in ["any", "all"] and isinstance(values, SparseArray): | ||
pass | ||
Comment on lines
+1914
to
+1915
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this is a general issue with SparseArray any/all so really independent of this PR, is that right? I'm thinking this should be fixed in SparseArray itself rather than in groupby code. Would it be okay to xfail any relevant tests? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, I see the issue with my request above; this would make axis=1 fail for SparseArray whereas it didn't before. I would be okay opening up an issue. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sure - by opening an issue, do you mean xfail for now as part of this PR? or open an issue and address that first before this? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave this PR as-is; open an issue to cleanup after this is merged (before is okay too) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah - got it. thanks |
||
elif how in ["any", "all", "std", "sem"]: | ||
raise # TODO: re-raise as TypeError? should not be reached | ||
else: | ||
return result | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
IIUC kurt is excluded here bc GroupBy doesnt support it. Can you comment to that effect
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a comment here