Skip to content

PERF/BUG: use masked algo in groupby cummin and cummax #40651

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 44 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
44 commits
Select commit Hold shift + click to select a range
7cd4dc3
wip
mzeitlin11 Mar 26, 2021
91984dc
wip
mzeitlin11 Mar 26, 2021
b371cc5
wip
mzeitlin11 Mar 26, 2021
69cce96
wip
mzeitlin11 Mar 26, 2021
dd7f324
wip
mzeitlin11 Mar 26, 2021
64680d4
wip
mzeitlin11 Mar 26, 2021
be16f65
wip
mzeitlin11 Mar 26, 2021
9442846
wip
mzeitlin11 Mar 26, 2021
f089175
wip
mzeitlin11 Mar 26, 2021
31409f8
wip
mzeitlin11 Mar 26, 2021
0c05f74
wip
mzeitlin11 Mar 26, 2021
18dcc94
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Mar 26, 2021
5c60a1f
wip
mzeitlin11 Mar 26, 2021
f0c27ce
PERF: use masked algo in groupby cummin and cummax
mzeitlin11 Mar 27, 2021
2fa80ad
Avoid mask copy
mzeitlin11 Mar 27, 2021
280c7e5
Update whatsnew
mzeitlin11 Mar 27, 2021
dca28cf
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 1, 2021
7e2fbe0
Merge fixup
mzeitlin11 Apr 1, 2021
0ebb97a
Follow transpose
mzeitlin11 Apr 1, 2021
0009dfd
Compute mask usage inside algo
mzeitlin11 Apr 1, 2021
6663832
try optional
mzeitlin11 Apr 1, 2021
8247f82
WIP
mzeitlin11 Apr 1, 2021
71e1c4f
Use more contiguity
mzeitlin11 Apr 1, 2021
c6cf9ee
Shrink benchmark
mzeitlin11 Apr 1, 2021
02768ec
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 1, 2021
836175b
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 2, 2021
293dc6e
Revert unrelated
mzeitlin11 Apr 2, 2021
478c6c9
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 6, 2021
fa45a9a
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 8, 2021
1632b81
Merge remote-tracking branch 'origin/master' into perf/masked_cummin/max
mzeitlin11 Apr 12, 2021
1bb344e
Remove merge conflict relic
mzeitlin11 Apr 12, 2021
97d9eea
Update doc/source/whatsnew/v1.3.0.rst
mzeitlin11 Apr 13, 2021
892a92a
Update doc/source/whatsnew/v1.3.0.rst
mzeitlin11 Apr 13, 2021
a239a68
Update pandas/core/groupby/ops.py
mzeitlin11 Apr 13, 2021
f98ca35
Merge remote-tracking branch 'origin' into perf/masked_cummin/max
mzeitlin11 Apr 13, 2021
e7ed12f
Merge branch 'perf/masked_cummin/max' of github.com:mzeitlin11/panda…
mzeitlin11 Apr 13, 2021
a1422ba
Address comments
mzeitlin11 Apr 13, 2021
482a209
Change random generation style
mzeitlin11 Apr 13, 2021
4e7404d
Merge remote-tracking branch 'origin' into perf/masked_cummin/max
mzeitlin11 Apr 18, 2021
251c02a
Use conditional instead of partial
mzeitlin11 Apr 18, 2021
3de7e5e
Remove ensure_int_or_float
mzeitlin11 Apr 18, 2021
237f86f
Remove unnecessary condition
mzeitlin11 Apr 18, 2021
a1b0c04
Merge remote-tracking branch 'origin' into perf/masked_cummin/max
mzeitlin11 Apr 19, 2021
5e1dac4
Merge remote-tracking branch 'origin' into perf/masked_cummin/max
mzeitlin11 Apr 20, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions asv_bench/benchmarks/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,33 @@ def time_frame_agg(self, dtype, method):
self.df.groupby("key").agg(method)


class CumminMax:
param_names = ["dtype", "method"]
params = [
["float64", "int64", "Float64", "Int64"],
["cummin", "cummax"],
]

def setup(self, dtype, method):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this costly? worth using setup_cache?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks to take ~0.25s. So might be worth caching, but appears setup_cache can't be parameterized, so would have to ugly up the benchmark a bit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we just make N // 10

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep have shrunk benchmark

N = 1_000_000
vals = np.random.randint(0, 1000, (N, 10))
null_vals = vals.astype(float, copy=True)
null_vals[::2, :] = np.nan
null_vals[::3, :] = np.nan
df = DataFrame(vals, columns=list("abcdefghij"), dtype=dtype)
null_df = DataFrame(null_vals, columns=list("abcdefghij"), dtype=dtype)
df["key"] = np.random.randint(0, 100, size=N)
null_df["key"] = np.random.randint(0, 100, size=N)
self.df = df
self.null_df = null_df

def time_frame_transform(self, dtype, method):
self.df.groupby("key").transform(method)

def time_frame_transform_many_nulls(self, dtype, method):
self.null_df.groupby("key").transform(method)


class RankWithTies:
# GH 21237
param_names = ["dtype", "tie_method"]
Expand Down
4 changes: 4 additions & 0 deletions doc/source/whatsnew/v1.3.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ Performance improvements
- Performance improvement in :class:`Styler` where render times are more than 50% reduced (:issue:`39972` :issue:`39952`)
- Performance improvement in :meth:`core.window.ewm.ExponentialMovingWindow.mean` with ``times`` (:issue:`39784`)
- Performance improvement in :meth:`.GroupBy.apply` when requiring the python fallback implementation (:issue:`40176`)
- Performance improvement in :class:`SeriesGroupBy` and :class:`DataFrameGroupBy` when using methods ``cummin`` and ``cummax`` with nullable data types (:issue:`37493`)
-

.. ---------------------------------------------------------------------------

Expand Down Expand Up @@ -633,6 +635,8 @@ Groupby/resample/rolling
- Bug in :class:`core.window.ewm.ExponentialMovingWindow` when calling ``__getitem__`` would not retain ``com``, ``span``, ``alpha`` or ``halflife`` attributes (:issue:`40164`)
- :class:`core.window.ewm.ExponentialMovingWindow` now raises a ``NotImplementedError`` when specifying ``times`` with ``adjust=False`` due to an incorrect calculation (:issue:`40098`)
- Bug in :meth:`Series.asfreq` and :meth:`DataFrame.asfreq` dropping rows when the index is not sorted (:issue:`39805`)
- Bug in :class:`SeriesGroupBy` and :class:`DataFrameGroupBy` computing wrong result for methods ``cummin`` and ``cummax`` with nullable data types too large to roundtrip when casting to float (:issue:`37493`)
-

Reshaping
^^^^^^^^^
Expand Down
64 changes: 57 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1255,9 +1255,11 @@ def group_min(groupby_t[:, ::1] out,
@cython.wraparound(False)
def group_cummin_max(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
uint8_t[:, ::1] mask,
const intp_t[:] labels,
int ngroups,
bint is_datetimelike,
bint use_mask,
bint compute_max):
"""
Cumulative minimum/maximum of columns of `values`, in row groups `labels`.
Expand All @@ -1268,12 +1270,19 @@ def group_cummin_max(groupby_t[:, ::1] out,
Array to store cummin/max in.
values : array
Values to take cummin/max of.
mask : array[uint8_t]
If `use_mask`, then indices represent missing values,
otherwise will be passed as a zeroed array
labels : np.ndarray[np.intp]
Labels to group by.
ngroups : int
Number of groups, larger than all entries of `labels`.
is_datetimelike : bool
True if `values` contains datetime-like entries.
use_mask : bool
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this actually worth it? we don't do this anywhere else. can you show head to head where you pass this always as True (vs your change method)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another option is to allow mask to be None, so use_mask can be defined inside the function as use_mask = mask is not None (a similar approach is used hashtable_class_helper.pxi for the unique/factorize implementations)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, definitely a cleaner solution.

Addressing your comment @jreback, will take a look back at perf diff with forcing mask usage. IIRC the main reason for the optional mask was a slowdown in float with the cost of isna along with mask lookups in the algo ending up larger than existing null checking in cython.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't seem like as much of a slowdown as feared, ASV's with forced mask usage (from 09b73fc)

       before           after         ratio
     [e69df38c]       [293dc6ea]
     <master>         <not_optional_mask>
          324±8ms         277±10ms    ~0.86  groupby.CumminMax.time_frame_transform('Float64', 'cummax')
         326±10ms         294±10ms    ~0.90  groupby.CumminMax.time_frame_transform('Float64', 'cummin')
-         692±8ms         393±10ms     0.57  groupby.CumminMax.time_frame_transform('Int64', 'cummax')
-        688±20ms          387±7ms     0.56  groupby.CumminMax.time_frame_transform('Int64', 'cummin')
          368±7ms         372±20ms     1.01  groupby.CumminMax.time_frame_transform('float64', 'cummax')
          382±3ms         380±20ms     0.99  groupby.CumminMax.time_frame_transform('float64', 'cummin')
-         598±5ms          536±3ms     0.90  groupby.CumminMax.time_frame_transform('int64', 'cummax')
          606±4ms         575±30ms     0.95  groupby.CumminMax.time_frame_transform('int64', 'cummin')
-         645±6ms          266±6ms     0.41  groupby.CumminMax.time_frame_transform_many_nulls('Float64', 'cummax')
-        643±10ms         275±10ms     0.43  groupby.CumminMax.time_frame_transform_many_nulls('Float64', 'cummin')
-         734±7ms          259±6ms     0.35  groupby.CumminMax.time_frame_transform_many_nulls('Int64', 'cummax')
-         752±5ms          251±9ms     0.33  groupby.CumminMax.time_frame_transform_many_nulls('Int64', 'cummin')
          364±6ms         382±10ms     1.05  groupby.CumminMax.time_frame_transform_many_nulls('float64', 'cummax')
          362±7ms          378±8ms     1.04  groupby.CumminMax.time_frame_transform_many_nulls('float64', 'cummin')
       1.14±0.02s       1.13±0.02s     0.99  groupby.CumminMax.time_frame_transform_many_nulls('int64', 'cummax')
       1.14±0.02s       1.12±0.03s     0.98  groupby.CumminMax.time_frame_transform_many_nulls('int64', 'cummin')

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this forced mask solution is cleaner because along with the simpler cython algo, it would eventually allow cleaning up some of the kludge around NaT/datetimelike. Which can lose information on edge cases with forced casts to float from

elif is_integer_dtype(dtype):
# we use iNaT for the missing value on ints
# so pre-convert to guard this condition
if (values == iNaT).any():
values = ensure_float64(values)

True if the mask should be used (otherwise we continue
as if it is not a masked algorithm). Avoids the cost
of checking for a completely zeroed mask.
compute_max : bool
True if cumulative maximum should be computed, False
if cumulative minimum should be computed
Expand All @@ -1287,6 +1296,7 @@ def group_cummin_max(groupby_t[:, ::1] out,
groupby_t val, mval
ndarray[groupby_t, ndim=2] accum
intp_t lab
bint val_is_nan

N, K = (<object>values).shape
accum = np.empty((ngroups, K), dtype=np.asarray(values).dtype)
Expand All @@ -1304,11 +1314,29 @@ def group_cummin_max(groupby_t[:, ::1] out,
if lab < 0:
continue
for j in range(K):
val = values[i, j]
val_is_nan = False

if use_mask:
if mask[i, j]:

# `out` does not need to be set since it
# will be masked anyway
val_is_nan = True
else:

# If using the mask, we can avoid grabbing the
# value unless necessary
val = values[i, j]

if _treat_as_na(val, is_datetimelike):
out[i, j] = val
# Otherwise, `out` must be set accordingly if the
# value is missing
else:
val = values[i, j]
if _treat_as_na(val, is_datetimelike):
val_is_nan = True
out[i, j] = val

if not val_is_nan:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it make sense to implement this as a separate function?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't have a strong opinion about this. The question would be the tradeoff between a bit more complexity/branching vs duplication/increased package size (if we end up adding masked support to a lot more of these grouped algos)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any guess what the impact on package size is?

potential duplication might be addressed by making e.g. refactoring L1340-1347 or L 1302-1308 into helper functions

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Based on rough estimate, binaries generated from groupby.pyx take up ~5% of total _libs. So based on the figure of _libs taking up 17MB from #30741, cost of full duplication would be around 0.8-0.9 MB. But like you mentioned above, some duplication could be avoided, 1 MB should be an upper bound.

mval = accum[lab, j]
if compute_max:
if val > mval:
Expand All @@ -1323,19 +1351,41 @@ def group_cummin_max(groupby_t[:, ::1] out,
@cython.wraparound(False)
def group_cummin(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
uint8_t[:, ::1] mask,
const intp_t[:] labels,
int ngroups,
bint is_datetimelike):
bint is_datetimelike,
bint use_mask):
"""See group_cummin_max.__doc__"""
group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=False)
group_cummin_max(
out,
values,
mask,
labels,
ngroups,
is_datetimelike,
use_mask,
compute_max=False
)


@cython.boundscheck(False)
@cython.wraparound(False)
def group_cummax(groupby_t[:, ::1] out,
ndarray[groupby_t, ndim=2] values,
uint8_t[:, ::1] mask,
const intp_t[:] labels,
int ngroups,
bint is_datetimelike):
bint is_datetimelike,
bint use_mask):
"""See group_cummin_max.__doc__"""
group_cummin_max(out, values, labels, ngroups, is_datetimelike, compute_max=True)
group_cummin_max(
out,
values,
mask,
labels,
ngroups,
is_datetimelike,
use_mask,
compute_max=True
)
120 changes: 103 additions & 17 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,10 @@
maybe_fill,
)

from pandas.core.arrays.masked import (
BaseMaskedArray,
BaseMaskedDtype,
)
from pandas.core.base import SelectionMixin
import pandas.core.common as com
from pandas.core.frame import DataFrame
Expand Down Expand Up @@ -118,6 +122,11 @@
},
}

_CYTHON_MASKED_FUNCTIONS = {
"cummin",
"cummax",
}


@functools.lru_cache(maxsize=None)
def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool):
Expand Down Expand Up @@ -155,6 +164,10 @@ def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool)
return func


def cython_function_uses_mask(kind: str) -> bool:
    """
    Return True if the groupby cython algorithm named ``kind`` accepts an
    explicit mask of missing values (i.e. is one of
    ``_CYTHON_MASKED_FUNCTIONS``).

    Note: callers pass the ``how`` string (e.g. ``"cummin"``) as ``kind``.
    """
    return kind in _CYTHON_MASKED_FUNCTIONS


class BaseGrouper:
"""
This is an internal Grouper class, which actually holds
Expand Down Expand Up @@ -574,9 +587,52 @@ def _ea_wrap_cython_operation(
f"function is not implemented for this dtype: {values.dtype}"
)

@final
def _masked_ea_wrap_cython_operation(
self,
kind: str,
values: BaseMaskedArray,
how: str,
axis: int,
min_count: int = -1,
**kwargs,
) -> BaseMaskedArray:
"""
Equivalent of `_ea_wrap_cython_operation`, but optimized for masked EA's
and cython algorithms which accept a mask.
"""
orig_values = values

# isna just directly returns self._mask, so copy here to prevent
# modifying the original
mask = isna(values).copy()
arr = values._data

if is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this an entirely different function? pls integrate to existing infra

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a different function because it's specific for MaskedArrays. Having this as a separate function is consistent with how it's currently implemented IMO (with a similar separate function for generic EAs).

It could also be another elif check in _ea_wrap_cython_operation, but it's not that it would result in less code or so, and since _ea_wrap_cython_operation already gets quite complicated, I think this separate function is good.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In an initial version I tried to fold this into _ea_wrap_cython_operation, but thought this smaller function was a cleaner solution since the conditionals here can remain much simpler. While adding a function for just supporting masked cummax, cummin seems wasteful, this infrastructure should extend to more masked groupby algos.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see my comments. This MUST integrate with the existing infrastructure (or refactor that). Duplicating is -1

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The whole point of this PR is to add a special handling for masked arrays, to ensure to pass through the mask to the groupby cython algo. This will always add some code (and the code in _masked_ea_wrap_cython_operation is specific to masked arrays).
  2. IMO this is integrated with the existing infrastructure: it integrates nicely into the existing _cython_operation and follows the same pattern as we already have for _ea_wrap_cython_operation

If you don't like how the added code is structured right now, please do a concrete suggestion of how you would do it differently.

# IntegerArray or BooleanArray
arr = ensure_int_or_float(arr)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FWIW im planning to kill off this function; for EAs this is always just arr.to_numpy(dtype="float64", na_value=np.nan)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for bringing up - realized this whole condition can be simplified since we actually have an ndarray at this point


res_values = self._cython_operation(
kind, arr, how, axis, min_count, mask=mask, **kwargs
)
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
assert isinstance(dtype, BaseMaskedDtype)
cls = dtype.construct_array_type()

return cls(
res_values.astype(dtype.type, copy=False), mask.astype(bool, copy=False)
)

@final
def _cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
self,
kind: str,
values,
how: str,
axis: int,
min_count: int = -1,
mask: np.ndarray | None = None,
**kwargs,
) -> ArrayLike:
"""
Returns the values of a cython operation.
Expand All @@ -598,10 +654,16 @@ def _cython_operation(
# if not raise NotImplementedError
self._disallow_invalid_ops(dtype, how, is_numeric)

func_uses_mask = cython_function_uses_mask(how)
if is_extension_array_dtype(dtype):
return self._ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
)
if isinstance(values, BaseMaskedArray) and func_uses_mask:
return self._masked_ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i really don't understand all of this code duplication. this is adding huge complexity. pls reduce it.

Copy link
Member

@jorisvandenbossche jorisvandenbossche Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Jeff, did you actually read the previous responses to your similar comment? (https://github.com/pandas-dev/pandas/pull/40651/files#r603319910) Can you then please answer there to the concrete reasons given.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes and its a terrible pattern.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this duplication of code is ridiculous. We have a VERY large codebase. Having this kind of separate logic is amazingly confusing & is humongous tech debt. This is heavily used code and needs to be carefully modified.

Copy link
Member Author

@mzeitlin11 mzeitlin11 Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand the concern about adding code complexity - my thinking was that if the goal is for nullable types to become the default in pandas, then direct support makes sense. And in that case, nullable types would need to be special-cased somewhere, and I think the separate function is cleaner than interleaving in _ea_wrap_cython_operation.

If direct support for nullable dtypes is not desired, we can just close this. If it is, I'll keep trying to think of ways to achieve this without adding more code, but any suggestions there would be welcome!

Copy link
Member

@jorisvandenbossche jorisvandenbossche Apr 13, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Proper support for nullable dtypes is certainly desired (how to add it exactly can of course be discussed), so thanks a lot @mzeitlin11 for your efforts here.

AFAIK, it's correct we need some special casing for it somewhere (that's the whole point of this PR is to add special handling for it).
Where exactly to put this special casing can of course be discussed, but to me the separate helper method instead of interleaving it in _ea_wrap_cython_operation seems good (I don't think that interleaving it into the existing _ea_wrap_cython_operation would result in fewer added lines of code (and would be harder to read)).

@jreback please try to stay constructive (eg answer to our arguments or provide concrete suggestions on where you would put it / how you would do it differently) and please mind your language (there is no need to call the approach taken by a contributor "terrible").

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. I agree with @jorisvandenbossche on phrasing concerns. Even the best of us slip up here from time to time.

  2. if the goal is for nullable types to become the default in pandas

This decision has not been made.

  1. I think the separate function is cleaner than interleaving in _ea_wrap_cython_operation.

Agreed.

  1. My preferred dispatch logic would look something like:
def _cython_operation(...)
    if is_ea_dtype(...):
       return self. _ea_wrap_cython_operation(...)
    [status quo]

def _ea_wrap_cython_operation(...):
    if should_use_mask(...):
        return self._masked_ea_wrap_cython_operation(...)
    [status quo]

as Joris correctly pointed out, that is not viable ATM. I think a lot of this dispatch logic eventually belongs in WrappedCythonOp (which i've been vaguely planning on doing next time there aren't any open PRs touching this code), at which point we can reconsider flattening this

  1. My other preferred dispatch logic would not be in this file at all, but be implemented as a method on the EA subclass. I'm really uncomfortable with this code depending on MaskedArray implementation details, seeing as how there has been discussion of swapping them out for something arrow-based.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jbrockmendel if you plan further refactoring of this code, I'm happy to just mothball this pr for now. The real benefit won't come in until more groupby algos allow a mask on this path anyway, so not worth adding now if it's just going to cause more pain in future refactoring.

I also like the idea of approach 5 instead of this - could start looking into that if you think it's a promising direction.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you plan further refactoring of this code, I'm happy to just mothball this pr for now.

From today's call, I think the plan is to move forward with this first.

I also like the idea of approach 5 instead of this - could start looking into that if you think it's a promising direction.

Long-term I think this is the right way to go to get the general case right, so I'd encourage you if you're interested in trying to implement this on the EA- separate PR(s).

)
else:
return self._ea_wrap_cython_operation(
kind, values, how, axis, min_count, **kwargs
)

is_datetimelike = needs_i8_conversion(dtype)

Expand All @@ -613,7 +675,7 @@ def _cython_operation(
elif is_integer_dtype(dtype):
# we use iNaT for the missing value on ints
# so pre-convert to guard this condition
if (values == iNaT).any():
if mask is None and (values == iNaT).any():
values = ensure_float64(values)
else:
values = ensure_int_or_float(values)
Expand All @@ -628,6 +690,9 @@ def _cython_operation(
swapped = False
if vdim == 1:
values = values[:, None]
if mask is not None:
mask = mask[:, None]

out_shape = (self.ngroups, arity)
else:
if axis > 0:
Expand All @@ -641,6 +706,10 @@ def _cython_operation(
out_shape = (self.ngroups,) + values.shape[1:]

func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)
use_mask = mask is not None
if func_uses_mask:
if mask is None:
mask = np.zeros_like(values, dtype=np.uint8, order="C")

if how == "rank":
out_dtype = "float"
Expand All @@ -650,25 +719,23 @@ def _cython_operation(
else:
out_dtype = "object"

codes, _, _ = self.group_info

if kind == "aggregate":
codes, _, _ = self.group_info
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
counts = np.zeros(self.ngroups, dtype=np.int64)
result = self._aggregate(result, counts, values, codes, func, min_count)
elif kind == "transform":
result = maybe_fill(np.empty(values.shape, dtype=out_dtype))

# TODO: min_count
result = self._transform(
result, values, codes, func, is_datetimelike, **kwargs
result, values, func, is_datetimelike, use_mask, mask, **kwargs
)

if is_integer_dtype(result.dtype) and not is_datetimelike:
mask = result == iNaT
if mask.any():
if not use_mask and is_integer_dtype(result.dtype) and not is_datetimelike:
result_mask = result == iNaT
if result_mask.any():
result = result.astype("float64")
result[mask] = np.nan
result[result_mask] = np.nan

if kind == "aggregate" and self._filter_empty_groups and not counts.all():
assert result.ndim != 2
Expand Down Expand Up @@ -704,11 +771,30 @@ def _aggregate(

@final
def _transform(
self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs
):
self,
result: np.ndarray,
values: np.ndarray,
transform_func,
is_datetimelike: bool,
use_mask: bool,
mask: np.ndarray | None,
**kwargs,
) -> np.ndarray:

_, _, ngroups = self.group_info
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
comp_ids, _, ngroups = self.group_info
if mask is not None:
transform_func(
result,
values,
mask,
comp_ids,
ngroups,
is_datetimelike,
use_mask,
**kwargs,
)
else:
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

return result

Expand Down
Loading