-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
REF: implement groupby.ops.WrappedCythonFunc #40733
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -97,54 +97,172 @@ | |
get_indexer_dict, | ||
) | ||
|
||
_CYTHON_FUNCTIONS = { | ||
"aggregate": { | ||
"add": "group_add", | ||
"prod": "group_prod", | ||
"min": "group_min", | ||
"max": "group_max", | ||
"mean": "group_mean", | ||
"median": "group_median", | ||
"var": "group_var", | ||
"first": "group_nth", | ||
"last": "group_last", | ||
"ohlc": "group_ohlc", | ||
}, | ||
"transform": { | ||
"cumprod": "group_cumprod", | ||
"cumsum": "group_cumsum", | ||
"cummin": "group_cummin", | ||
"cummax": "group_cummax", | ||
"rank": "group_rank", | ||
}, | ||
} | ||
|
||
|
||
@functools.lru_cache(maxsize=None) | ||
def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool): | ||
|
||
dtype_str = dtype.name | ||
ftype = _CYTHON_FUNCTIONS[kind][how] | ||
|
||
# see if there is a fused-type version of function | ||
# only valid for numeric | ||
f = getattr(libgroupby, ftype, None) | ||
if f is not None: | ||
if is_numeric: | ||
return f | ||
elif dtype == object: | ||
if "object" not in f.__signatures__: | ||
# raise NotImplementedError here rather than TypeError later | ||
|
||
class WrappedCythonOp: | ||
""" | ||
Dispatch logic for functions defined in _libs.groupby | ||
""" | ||
|
||
def __init__(self, kind: str, how: str): | ||
self.kind = kind | ||
self.how = how | ||
|
||
_CYTHON_FUNCTIONS = { | ||
"aggregate": { | ||
"add": "group_add", | ||
"prod": "group_prod", | ||
"min": "group_min", | ||
"max": "group_max", | ||
"mean": "group_mean", | ||
"median": "group_median", | ||
"var": "group_var", | ||
"first": "group_nth", | ||
"last": "group_last", | ||
"ohlc": "group_ohlc", | ||
}, | ||
"transform": { | ||
"cumprod": "group_cumprod", | ||
"cumsum": "group_cumsum", | ||
"cummin": "group_cummin", | ||
"cummax": "group_cummax", | ||
"rank": "group_rank", | ||
}, | ||
} | ||
|
||
_cython_arity = {"ohlc": 4} # OHLC | ||
|
||
# Note: we make this a classmethod and pass kind+how so that caching | ||
# works at the class level and not the instance level | ||
@classmethod | ||
@functools.lru_cache(maxsize=None) | ||
def _get_cython_function( | ||
cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool | ||
): | ||
|
||
dtype_str = dtype.name | ||
ftype = cls._CYTHON_FUNCTIONS[kind][how] | ||
|
||
# see if there is a fused-type version of function | ||
# only valid for numeric | ||
f = getattr(libgroupby, ftype, None) | ||
if f is not None: | ||
if is_numeric: | ||
return f | ||
elif dtype == object: | ||
if "object" not in f.__signatures__: | ||
# raise NotImplementedError here rather than TypeError later | ||
raise NotImplementedError( | ||
f"function is not implemented for this dtype: " | ||
f"[how->{how},dtype->{dtype_str}]" | ||
) | ||
return f | ||
|
||
raise NotImplementedError( | ||
f"function is not implemented for this dtype: " | ||
f"[how->{how},dtype->{dtype_str}]" | ||
) | ||
|
||
def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool): | ||
""" | ||
Find the appropriate cython function, casting if necessary. | ||
|
||
Parameters | ||
---------- | ||
values : np.ndarray | ||
is_numeric : bool | ||
|
||
Returns | ||
------- | ||
func : callable | ||
values : np.ndarray | ||
""" | ||
how = self.how | ||
kind = self.kind | ||
|
||
if how in ["median", "cumprod"]: | ||
# these two only have float64 implementations | ||
if is_numeric: | ||
values = ensure_float64(values) | ||
else: | ||
raise NotImplementedError( | ||
f"function is not implemented for this dtype: " | ||
f"[how->{how},dtype->{dtype_str}]" | ||
f"[how->{how},dtype->{values.dtype.name}]" | ||
) | ||
return f | ||
func = getattr(libgroupby, f"group_{how}_float64") | ||
return func, values | ||
|
||
raise NotImplementedError( | ||
f"function is not implemented for this dtype: " | ||
f"[how->{how},dtype->{dtype_str}]" | ||
) | ||
func = self._get_cython_function(kind, how, values.dtype, is_numeric) | ||
|
||
if values.dtype.kind in ["i", "u"]: | ||
if how in ["add", "var", "prod", "mean", "ohlc"]: | ||
# result may still include NaN, so we have to cast | ||
values = ensure_float64(values) | ||
|
||
return func, values | ||
|
||
def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False): | ||
""" | ||
Check if we can do this operation with our cython functions. | ||
|
||
Raises | ||
------ | ||
NotImplementedError | ||
This is either not a valid function for this dtype, or | ||
valid but not implemented in cython. | ||
""" | ||
how = self.how | ||
|
||
if is_numeric: | ||
# never an invalid op for those dtypes, so return early as fastpath | ||
return | ||
|
||
if is_categorical_dtype(dtype) or is_sparse(dtype): | ||
# categoricals are only 1d, so we | ||
# are not setup for dim transforming | ||
raise NotImplementedError(f"{dtype} dtype not supported") | ||
elif is_datetime64_any_dtype(dtype): | ||
# we raise NotImplemented if this is an invalid operation | ||
# entirely, e.g. adding datetimes | ||
if how in ["add", "prod", "cumsum", "cumprod"]: | ||
raise NotImplementedError( | ||
f"datetime64 type does not support {how} operations" | ||
) | ||
elif is_timedelta64_dtype(dtype): | ||
if how in ["prod", "cumprod"]: | ||
raise NotImplementedError( | ||
f"timedelta64 type does not support {how} operations" | ||
) | ||
|
||
def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape: | ||
how = self.how | ||
kind = self.kind | ||
|
||
arity = self._cython_arity.get(how, 1) | ||
|
||
out_shape: Shape | ||
if how == "ohlc": | ||
out_shape = (ngroups, 4) | ||
elif arity > 1: | ||
raise NotImplementedError( | ||
"arity of more than 1 is not supported for the 'how' argument" | ||
) | ||
elif kind == "transform": | ||
out_shape = values.shape | ||
else: | ||
out_shape = (ngroups,) + values.shape[1:] | ||
return out_shape | ||
|
||
def get_out_dtype(self, dtype: np.dtype) -> np.dtype: | ||
how = self.how | ||
|
||
if how == "rank": | ||
out_dtype = "float64" | ||
else: | ||
if is_numeric_dtype(dtype): | ||
out_dtype = f"{dtype.kind}{dtype.itemsize}" | ||
else: | ||
out_dtype = "object" | ||
return np.dtype(out_dtype) | ||
|
||
|
||
class BaseGrouper: | ||
|
@@ -437,8 +555,6 @@ def get_group_levels(self) -> List[Index]: | |
# ------------------------------------------------------------ | ||
# Aggregation functions | ||
|
||
_cython_arity = {"ohlc": 4} # OHLC | ||
|
||
@final | ||
def _is_builtin_func(self, arg): | ||
""" | ||
|
@@ -447,80 +563,6 @@ def _is_builtin_func(self, arg): | |
""" | ||
return SelectionMixin._builtin_table.get(arg, arg) | ||
|
||
@final | ||
def _get_cython_func_and_vals( | ||
self, kind: str, how: str, values: np.ndarray, is_numeric: bool | ||
): | ||
""" | ||
Find the appropriate cython function, casting if necessary. | ||
|
||
Parameters | ||
---------- | ||
kind : str | ||
how : str | ||
values : np.ndarray | ||
is_numeric : bool | ||
|
||
Returns | ||
------- | ||
func : callable | ||
values : np.ndarray | ||
""" | ||
if how in ["median", "cumprod"]: | ||
# these two only have float64 implementations | ||
if is_numeric: | ||
values = ensure_float64(values) | ||
else: | ||
raise NotImplementedError( | ||
f"function is not implemented for this dtype: " | ||
f"[how->{how},dtype->{values.dtype.name}]" | ||
) | ||
func = getattr(libgroupby, f"group_{how}_float64") | ||
return func, values | ||
|
||
func = _get_cython_function(kind, how, values.dtype, is_numeric) | ||
|
||
if values.dtype.kind in ["i", "u"]: | ||
if how in ["add", "var", "prod", "mean", "ohlc"]: | ||
# result may still include NaN, so we have to cast | ||
values = ensure_float64(values) | ||
|
||
return func, values | ||
|
||
@final | ||
def _disallow_invalid_ops( | ||
self, dtype: DtypeObj, how: str, is_numeric: bool = False | ||
): | ||
""" | ||
Check if we can do this operation with our cython functions. | ||
|
||
Raises | ||
------ | ||
NotImplementedError | ||
This is either not a valid function for this dtype, or | ||
valid but not implemented in cython. | ||
""" | ||
if is_numeric: | ||
# never an invalid op for those dtypes, so return early as fastpath | ||
return | ||
|
||
if is_categorical_dtype(dtype) or is_sparse(dtype): | ||
# categoricals are only 1d, so we | ||
# are not setup for dim transforming | ||
raise NotImplementedError(f"{dtype} dtype not supported") | ||
elif is_datetime64_any_dtype(dtype): | ||
# we raise NotImplemented if this is an invalid operation | ||
# entirely, e.g. adding datetimes | ||
if how in ["add", "prod", "cumsum", "cumprod"]: | ||
raise NotImplementedError( | ||
f"datetime64 type does not support {how} operations" | ||
) | ||
elif is_timedelta64_dtype(dtype): | ||
if how in ["prod", "cumprod"]: | ||
raise NotImplementedError( | ||
f"timedelta64 type does not support {how} operations" | ||
) | ||
|
||
@final | ||
def _ea_wrap_cython_operation( | ||
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs | ||
|
@@ -593,9 +635,11 @@ def _cython_operation( | |
dtype = values.dtype | ||
is_numeric = is_numeric_dtype(dtype) | ||
|
||
cy_op = WrappedCythonOp(kind=kind, how=how) | ||
|
||
# can we do this operation with our cython functions | ||
# if not raise NotImplementedError | ||
self._disallow_invalid_ops(dtype, how, is_numeric) | ||
cy_op.disallow_invalid_ops(dtype, is_numeric) | ||
|
||
if is_extension_array_dtype(dtype): | ||
return self._ea_wrap_cython_operation( | ||
|
@@ -637,43 +681,23 @@ def _cython_operation( | |
if not is_complex_dtype(dtype): | ||
values = ensure_float64(values) | ||
|
||
arity = self._cython_arity.get(how, 1) | ||
ngroups = self.ngroups | ||
comp_ids, _, _ = self.group_info | ||
|
||
assert axis == 1 | ||
values = values.T | ||
if how == "ohlc": | ||
out_shape = (ngroups, 4) | ||
elif arity > 1: | ||
raise NotImplementedError( | ||
"arity of more than 1 is not supported for the 'how' argument" | ||
) | ||
elif kind == "transform": | ||
out_shape = values.shape | ||
else: | ||
out_shape = (ngroups,) + values.shape[1:] | ||
|
||
func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric) | ||
|
||
if how == "rank": | ||
out_dtype = "float" | ||
else: | ||
if is_numeric: | ||
out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}" | ||
else: | ||
out_dtype = "object" | ||
|
||
codes, _, _ = self.group_info | ||
out_shape = cy_op.get_output_shape(ngroups, values) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Review comment: In the future, I would just have a single method here where you pass … (the rest of the comment is cut off in this capture)
|
||
func, values = cy_op.get_cython_func_and_vals(values, is_numeric) | ||
out_dtype = cy_op.get_out_dtype(values.dtype) | ||
|
||
result = maybe_fill(np.empty(out_shape, dtype=out_dtype)) | ||
if kind == "aggregate": | ||
counts = np.zeros(self.ngroups, dtype=np.int64) | ||
result = self._aggregate(result, counts, values, codes, func, min_count) | ||
counts = np.zeros(ngroups, dtype=np.int64) | ||
func(result, counts, values, comp_ids, min_count) | ||
elif kind == "transform": | ||
# TODO: min_count | ||
result = self._transform( | ||
result, values, codes, func, is_datetimelike, **kwargs | ||
) | ||
func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) | ||
|
||
if is_integer_dtype(result.dtype) and not is_datetimelike: | ||
mask = result == iNaT | ||
|
@@ -697,28 +721,6 @@ def _cython_operation( | |
|
||
return op_result | ||
|
||
@final | ||
def _aggregate( | ||
self, result, counts, values, comp_ids, agg_func, min_count: int = -1 | ||
): | ||
if agg_func is libgroupby.group_nth: | ||
# different signature from the others | ||
agg_func(result, counts, values, comp_ids, min_count, rank=1) | ||
else: | ||
agg_func(result, counts, values, comp_ids, min_count) | ||
|
||
return result | ||
|
||
@final | ||
def _transform( | ||
self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs | ||
): | ||
|
||
_, _, ngroups = self.group_info | ||
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs) | ||
|
||
return result | ||
|
||
def agg_series(self, obj: Series, func: F): | ||
# Caller is responsible for checking ngroups != 0 | ||
assert self.ngroups != 0 | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should probably be placed after the class variables / methods, but it's no big deal.