Skip to content

Commit e69df38

Browse files
authored
REF: implement groupby.ops.WrappedCythonFunc (#40733)
1 parent fa2da60 commit e69df38

File tree

1 file changed

+172
-170
lines changed

1 file changed

+172
-170
lines changed

pandas/core/groupby/ops.py

+172-170
Original file line numberDiff line numberDiff line change
@@ -97,54 +97,172 @@
9797
get_indexer_dict,
9898
)
9999

100-
_CYTHON_FUNCTIONS = {
101-
"aggregate": {
102-
"add": "group_add",
103-
"prod": "group_prod",
104-
"min": "group_min",
105-
"max": "group_max",
106-
"mean": "group_mean",
107-
"median": "group_median",
108-
"var": "group_var",
109-
"first": "group_nth",
110-
"last": "group_last",
111-
"ohlc": "group_ohlc",
112-
},
113-
"transform": {
114-
"cumprod": "group_cumprod",
115-
"cumsum": "group_cumsum",
116-
"cummin": "group_cummin",
117-
"cummax": "group_cummax",
118-
"rank": "group_rank",
119-
},
120-
}
121-
122-
123-
@functools.lru_cache(maxsize=None)
124-
def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool):
125-
126-
dtype_str = dtype.name
127-
ftype = _CYTHON_FUNCTIONS[kind][how]
128-
129-
# see if there is a fused-type version of function
130-
# only valid for numeric
131-
f = getattr(libgroupby, ftype, None)
132-
if f is not None:
133-
if is_numeric:
134-
return f
135-
elif dtype == object:
136-
if "object" not in f.__signatures__:
137-
# raise NotImplementedError here rather than TypeError later
100+
101+
class WrappedCythonOp:
102+
"""
103+
Dispatch logic for functions defined in _libs.groupby
104+
"""
105+
106+
def __init__(self, kind: str, how: str):
107+
self.kind = kind
108+
self.how = how
109+
110+
_CYTHON_FUNCTIONS = {
111+
"aggregate": {
112+
"add": "group_add",
113+
"prod": "group_prod",
114+
"min": "group_min",
115+
"max": "group_max",
116+
"mean": "group_mean",
117+
"median": "group_median",
118+
"var": "group_var",
119+
"first": "group_nth",
120+
"last": "group_last",
121+
"ohlc": "group_ohlc",
122+
},
123+
"transform": {
124+
"cumprod": "group_cumprod",
125+
"cumsum": "group_cumsum",
126+
"cummin": "group_cummin",
127+
"cummax": "group_cummax",
128+
"rank": "group_rank",
129+
},
130+
}
131+
132+
_cython_arity = {"ohlc": 4} # OHLC
133+
134+
# Note: we make this a classmethod and pass kind+how so that caching
135+
# works at the class level and not the instance level
136+
@classmethod
137+
@functools.lru_cache(maxsize=None)
138+
def _get_cython_function(
139+
cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
140+
):
141+
142+
dtype_str = dtype.name
143+
ftype = cls._CYTHON_FUNCTIONS[kind][how]
144+
145+
# see if there is a fused-type version of function
146+
# only valid for numeric
147+
f = getattr(libgroupby, ftype, None)
148+
if f is not None:
149+
if is_numeric:
150+
return f
151+
elif dtype == object:
152+
if "object" not in f.__signatures__:
153+
# raise NotImplementedError here rather than TypeError later
154+
raise NotImplementedError(
155+
f"function is not implemented for this dtype: "
156+
f"[how->{how},dtype->{dtype_str}]"
157+
)
158+
return f
159+
160+
raise NotImplementedError(
161+
f"function is not implemented for this dtype: "
162+
f"[how->{how},dtype->{dtype_str}]"
163+
)
164+
165+
def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
166+
"""
167+
Find the appropriate cython function, casting if necessary.
168+
169+
Parameters
170+
----------
171+
values : np.ndarray
172+
is_numeric : bool
173+
174+
Returns
175+
-------
176+
func : callable
177+
values : np.ndarray
178+
"""
179+
how = self.how
180+
kind = self.kind
181+
182+
if how in ["median", "cumprod"]:
183+
# these two only have float64 implementations
184+
if is_numeric:
185+
values = ensure_float64(values)
186+
else:
138187
raise NotImplementedError(
139188
f"function is not implemented for this dtype: "
140-
f"[how->{how},dtype->{dtype_str}]"
189+
f"[how->{how},dtype->{values.dtype.name}]"
141190
)
142-
return f
191+
func = getattr(libgroupby, f"group_{how}_float64")
192+
return func, values
143193

144-
raise NotImplementedError(
145-
f"function is not implemented for this dtype: "
146-
f"[how->{how},dtype->{dtype_str}]"
147-
)
194+
func = self._get_cython_function(kind, how, values.dtype, is_numeric)
195+
196+
if values.dtype.kind in ["i", "u"]:
197+
if how in ["add", "var", "prod", "mean", "ohlc"]:
198+
# result may still include NaN, so we have to cast
199+
values = ensure_float64(values)
200+
201+
return func, values
202+
203+
def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
204+
"""
205+
Check if we can do this operation with our cython functions.
206+
207+
Raises
208+
------
209+
NotImplementedError
210+
This is either not a valid function for this dtype, or
211+
valid but not implemented in cython.
212+
"""
213+
how = self.how
214+
215+
if is_numeric:
216+
# never an invalid op for those dtypes, so return early as fastpath
217+
return
218+
219+
if is_categorical_dtype(dtype) or is_sparse(dtype):
220+
# categoricals are only 1d, so we
221+
# are not setup for dim transforming
222+
raise NotImplementedError(f"{dtype} dtype not supported")
223+
elif is_datetime64_any_dtype(dtype):
224+
# we raise NotImplemented if this is an invalid operation
225+
# entirely, e.g. adding datetimes
226+
if how in ["add", "prod", "cumsum", "cumprod"]:
227+
raise NotImplementedError(
228+
f"datetime64 type does not support {how} operations"
229+
)
230+
elif is_timedelta64_dtype(dtype):
231+
if how in ["prod", "cumprod"]:
232+
raise NotImplementedError(
233+
f"timedelta64 type does not support {how} operations"
234+
)
235+
236+
def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
237+
how = self.how
238+
kind = self.kind
239+
240+
arity = self._cython_arity.get(how, 1)
241+
242+
out_shape: Shape
243+
if how == "ohlc":
244+
out_shape = (ngroups, 4)
245+
elif arity > 1:
246+
raise NotImplementedError(
247+
"arity of more than 1 is not supported for the 'how' argument"
248+
)
249+
elif kind == "transform":
250+
out_shape = values.shape
251+
else:
252+
out_shape = (ngroups,) + values.shape[1:]
253+
return out_shape
254+
255+
def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
256+
how = self.how
257+
258+
if how == "rank":
259+
out_dtype = "float64"
260+
else:
261+
if is_numeric_dtype(dtype):
262+
out_dtype = f"{dtype.kind}{dtype.itemsize}"
263+
else:
264+
out_dtype = "object"
265+
return np.dtype(out_dtype)
148266

149267

150268
class BaseGrouper:
@@ -437,8 +555,6 @@ def get_group_levels(self) -> List[Index]:
437555
# ------------------------------------------------------------
438556
# Aggregation functions
439557

440-
_cython_arity = {"ohlc": 4} # OHLC
441-
442558
@final
443559
def _is_builtin_func(self, arg):
444560
"""
@@ -447,80 +563,6 @@ def _is_builtin_func(self, arg):
447563
"""
448564
return SelectionMixin._builtin_table.get(arg, arg)
449565

450-
@final
451-
def _get_cython_func_and_vals(
452-
self, kind: str, how: str, values: np.ndarray, is_numeric: bool
453-
):
454-
"""
455-
Find the appropriate cython function, casting if necessary.
456-
457-
Parameters
458-
----------
459-
kind : str
460-
how : str
461-
values : np.ndarray
462-
is_numeric : bool
463-
464-
Returns
465-
-------
466-
func : callable
467-
values : np.ndarray
468-
"""
469-
if how in ["median", "cumprod"]:
470-
# these two only have float64 implementations
471-
if is_numeric:
472-
values = ensure_float64(values)
473-
else:
474-
raise NotImplementedError(
475-
f"function is not implemented for this dtype: "
476-
f"[how->{how},dtype->{values.dtype.name}]"
477-
)
478-
func = getattr(libgroupby, f"group_{how}_float64")
479-
return func, values
480-
481-
func = _get_cython_function(kind, how, values.dtype, is_numeric)
482-
483-
if values.dtype.kind in ["i", "u"]:
484-
if how in ["add", "var", "prod", "mean", "ohlc"]:
485-
# result may still include NaN, so we have to cast
486-
values = ensure_float64(values)
487-
488-
return func, values
489-
490-
@final
491-
def _disallow_invalid_ops(
492-
self, dtype: DtypeObj, how: str, is_numeric: bool = False
493-
):
494-
"""
495-
Check if we can do this operation with our cython functions.
496-
497-
Raises
498-
------
499-
NotImplementedError
500-
This is either not a valid function for this dtype, or
501-
valid but not implemented in cython.
502-
"""
503-
if is_numeric:
504-
# never an invalid op for those dtypes, so return early as fastpath
505-
return
506-
507-
if is_categorical_dtype(dtype) or is_sparse(dtype):
508-
# categoricals are only 1d, so we
509-
# are not setup for dim transforming
510-
raise NotImplementedError(f"{dtype} dtype not supported")
511-
elif is_datetime64_any_dtype(dtype):
512-
# we raise NotImplemented if this is an invalid operation
513-
# entirely, e.g. adding datetimes
514-
if how in ["add", "prod", "cumsum", "cumprod"]:
515-
raise NotImplementedError(
516-
f"datetime64 type does not support {how} operations"
517-
)
518-
elif is_timedelta64_dtype(dtype):
519-
if how in ["prod", "cumprod"]:
520-
raise NotImplementedError(
521-
f"timedelta64 type does not support {how} operations"
522-
)
523-
524566
@final
525567
def _ea_wrap_cython_operation(
526568
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
@@ -593,9 +635,11 @@ def _cython_operation(
593635
dtype = values.dtype
594636
is_numeric = is_numeric_dtype(dtype)
595637

638+
cy_op = WrappedCythonOp(kind=kind, how=how)
639+
596640
# can we do this operation with our cython functions
597641
# if not raise NotImplementedError
598-
self._disallow_invalid_ops(dtype, how, is_numeric)
642+
cy_op.disallow_invalid_ops(dtype, is_numeric)
599643

600644
if is_extension_array_dtype(dtype):
601645
return self._ea_wrap_cython_operation(
@@ -637,43 +681,23 @@ def _cython_operation(
637681
if not is_complex_dtype(dtype):
638682
values = ensure_float64(values)
639683

640-
arity = self._cython_arity.get(how, 1)
641684
ngroups = self.ngroups
685+
comp_ids, _, _ = self.group_info
642686

643687
assert axis == 1
644688
values = values.T
645-
if how == "ohlc":
646-
out_shape = (ngroups, 4)
647-
elif arity > 1:
648-
raise NotImplementedError(
649-
"arity of more than 1 is not supported for the 'how' argument"
650-
)
651-
elif kind == "transform":
652-
out_shape = values.shape
653-
else:
654-
out_shape = (ngroups,) + values.shape[1:]
655-
656-
func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)
657-
658-
if how == "rank":
659-
out_dtype = "float"
660-
else:
661-
if is_numeric:
662-
out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}"
663-
else:
664-
out_dtype = "object"
665689

666-
codes, _, _ = self.group_info
690+
out_shape = cy_op.get_output_shape(ngroups, values)
691+
func, values = cy_op.get_cython_func_and_vals(values, is_numeric)
692+
out_dtype = cy_op.get_out_dtype(values.dtype)
667693

668694
result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
669695
if kind == "aggregate":
670-
counts = np.zeros(self.ngroups, dtype=np.int64)
671-
result = self._aggregate(result, counts, values, codes, func, min_count)
696+
counts = np.zeros(ngroups, dtype=np.int64)
697+
func(result, counts, values, comp_ids, min_count)
672698
elif kind == "transform":
673699
# TODO: min_count
674-
result = self._transform(
675-
result, values, codes, func, is_datetimelike, **kwargs
676-
)
700+
func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
677701

678702
if is_integer_dtype(result.dtype) and not is_datetimelike:
679703
mask = result == iNaT
@@ -697,28 +721,6 @@ def _cython_operation(
697721

698722
return op_result
699723

700-
@final
701-
def _aggregate(
702-
self, result, counts, values, comp_ids, agg_func, min_count: int = -1
703-
):
704-
if agg_func is libgroupby.group_nth:
705-
# different signature from the others
706-
agg_func(result, counts, values, comp_ids, min_count, rank=1)
707-
else:
708-
agg_func(result, counts, values, comp_ids, min_count)
709-
710-
return result
711-
712-
@final
713-
def _transform(
714-
self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs
715-
):
716-
717-
_, _, ngroups = self.group_info
718-
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)
719-
720-
return result
721-
722724
def agg_series(self, obj: Series, func: F):
723725
# Caller is responsible for checking ngroups != 0
724726
assert self.ngroups != 0

0 commit comments

Comments
 (0)