Skip to content

REF: implement groupby.ops.WrappedCythonFunc #40733

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Apr 2, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
342 changes: 172 additions & 170 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,54 +97,172 @@
get_indexer_dict,
)

_CYTHON_FUNCTIONS = {
"aggregate": {
"add": "group_add",
"prod": "group_prod",
"min": "group_min",
"max": "group_max",
"mean": "group_mean",
"median": "group_median",
"var": "group_var",
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
},
"transform": {
"cumprod": "group_cumprod",
"cumsum": "group_cumsum",
"cummin": "group_cummin",
"cummax": "group_cummax",
"rank": "group_rank",
},
}


@functools.lru_cache(maxsize=None)
def _get_cython_function(kind: str, how: str, dtype: np.dtype, is_numeric: bool):

dtype_str = dtype.name
ftype = _CYTHON_FUNCTIONS[kind][how]

# see if there is a fused-type version of function
# only valid for numeric
f = getattr(libgroupby, ftype, None)
if f is not None:
if is_numeric:
return f
elif dtype == object:
if "object" not in f.__signatures__:
# raise NotImplementedError here rather than TypeError later

class WrappedCythonOp:
"""
Dispatch logic for functions defined in _libs.groupby
"""

def __init__(self, kind: str, how: str):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob should be after the class variables / methods, but nbd

self.kind = kind
self.how = how

_CYTHON_FUNCTIONS = {
"aggregate": {
"add": "group_add",
"prod": "group_prod",
"min": "group_min",
"max": "group_max",
"mean": "group_mean",
"median": "group_median",
"var": "group_var",
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
},
"transform": {
"cumprod": "group_cumprod",
"cumsum": "group_cumsum",
"cummin": "group_cummin",
"cummax": "group_cummax",
"rank": "group_rank",
},
}

_cython_arity = {"ohlc": 4} # OHLC

# Note: we make this a classmethod and pass kind+how so that caching
# works at the class level and not the instance level
@classmethod
@functools.lru_cache(maxsize=None)
def _get_cython_function(
cls, kind: str, how: str, dtype: np.dtype, is_numeric: bool
):

dtype_str = dtype.name
ftype = cls._CYTHON_FUNCTIONS[kind][how]

# see if there is a fused-type version of function
# only valid for numeric
f = getattr(libgroupby, ftype, None)
if f is not None:
if is_numeric:
return f
elif dtype == object:
if "object" not in f.__signatures__:
# raise NotImplementedError here rather than TypeError later
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
return f

raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)

def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
"""
Find the appropriate cython function, casting if necessary.

Parameters
----------
values : np.ndarray
is_numeric : bool

Returns
-------
func : callable
values : np.ndarray
"""
how = self.how
kind = self.kind

if how in ["median", "cumprod"]:
# these two only have float64 implementations
if is_numeric:
values = ensure_float64(values)
else:
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
f"[how->{how},dtype->{values.dtype.name}]"
)
return f
func = getattr(libgroupby, f"group_{how}_float64")
return func, values

raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
func = self._get_cython_function(kind, how, values.dtype, is_numeric)

if values.dtype.kind in ["i", "u"]:
if how in ["add", "var", "prod", "mean", "ohlc"]:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

return func, values

def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
"""
Check if we can do this operation with our cython functions.

Raises
------
NotImplementedError
This is either not a valid function for this dtype, or
valid but not implemented in cython.
"""
how = self.how

if is_numeric:
# never an invalid op for those dtypes, so return early as fastpath
return

if is_categorical_dtype(dtype) or is_sparse(dtype):
# categoricals are only 1d, so we
# are not setup for dim transforming
raise NotImplementedError(f"{dtype} dtype not supported")
elif is_datetime64_any_dtype(dtype):
# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
)

def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
how = self.how
kind = self.kind

arity = self._cython_arity.get(how, 1)

out_shape: Shape
if how == "ohlc":
out_shape = (ngroups, 4)
elif arity > 1:
raise NotImplementedError(
"arity of more than 1 is not supported for the 'how' argument"
)
elif kind == "transform":
out_shape = values.shape
else:
out_shape = (ngroups,) + values.shape[1:]
return out_shape

def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
how = self.how

if how == "rank":
out_dtype = "float64"
else:
if is_numeric_dtype(dtype):
out_dtype = f"{dtype.kind}{dtype.itemsize}"
else:
out_dtype = "object"
return np.dtype(out_dtype)


class BaseGrouper:
Expand Down Expand Up @@ -437,8 +555,6 @@ def get_group_levels(self) -> List[Index]:
# ------------------------------------------------------------
# Aggregation functions

_cython_arity = {"ohlc": 4} # OHLC

@final
def _is_builtin_func(self, arg):
"""
Expand All @@ -447,80 +563,6 @@ def _is_builtin_func(self, arg):
"""
return SelectionMixin._builtin_table.get(arg, arg)

@final
def _get_cython_func_and_vals(
self, kind: str, how: str, values: np.ndarray, is_numeric: bool
):
"""
Find the appropriate cython function, casting if necessary.

Parameters
----------
kind : str
how : str
values : np.ndarray
is_numeric : bool

Returns
-------
func : callable
values : np.ndarray
"""
if how in ["median", "cumprod"]:
# these two only have float64 implementations
if is_numeric:
values = ensure_float64(values)
else:
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{values.dtype.name}]"
)
func = getattr(libgroupby, f"group_{how}_float64")
return func, values

func = _get_cython_function(kind, how, values.dtype, is_numeric)

if values.dtype.kind in ["i", "u"]:
if how in ["add", "var", "prod", "mean", "ohlc"]:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

return func, values

@final
def _disallow_invalid_ops(
self, dtype: DtypeObj, how: str, is_numeric: bool = False
):
"""
Check if we can do this operation with our cython functions.

Raises
------
NotImplementedError
This is either not a valid function for this dtype, or
valid but not implemented in cython.
"""
if is_numeric:
# never an invalid op for those dtypes, so return early as fastpath
return

if is_categorical_dtype(dtype) or is_sparse(dtype):
# categoricals are only 1d, so we
# are not setup for dim transforming
raise NotImplementedError(f"{dtype} dtype not supported")
elif is_datetime64_any_dtype(dtype):
# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes
if how in ["add", "prod", "cumsum", "cumprod"]:
raise NotImplementedError(
f"datetime64 type does not support {how} operations"
)
elif is_timedelta64_dtype(dtype):
if how in ["prod", "cumprod"]:
raise NotImplementedError(
f"timedelta64 type does not support {how} operations"
)

@final
def _ea_wrap_cython_operation(
self, kind: str, values, how: str, axis: int, min_count: int = -1, **kwargs
Expand Down Expand Up @@ -593,9 +635,11 @@ def _cython_operation(
dtype = values.dtype
is_numeric = is_numeric_dtype(dtype)

cy_op = WrappedCythonOp(kind=kind, how=how)

# can we do this operation with our cython functions
# if not raise NotImplementedError
self._disallow_invalid_ops(dtype, how, is_numeric)
cy_op.disallow_invalid_ops(dtype, is_numeric)

if is_extension_array_dtype(dtype):
return self._ea_wrap_cython_operation(
Expand Down Expand Up @@ -637,43 +681,23 @@ def _cython_operation(
if not is_complex_dtype(dtype):
values = ensure_float64(values)

arity = self._cython_arity.get(how, 1)
ngroups = self.ngroups
comp_ids, _, _ = self.group_info

assert axis == 1
values = values.T
if how == "ohlc":
out_shape = (ngroups, 4)
elif arity > 1:
raise NotImplementedError(
"arity of more than 1 is not supported for the 'how' argument"
)
elif kind == "transform":
out_shape = values.shape
else:
out_shape = (ngroups,) + values.shape[1:]

func, values = self._get_cython_func_and_vals(kind, how, values, is_numeric)

if how == "rank":
out_dtype = "float"
else:
if is_numeric:
out_dtype = f"{values.dtype.kind}{values.dtype.itemsize}"
else:
out_dtype = "object"

codes, _, _ = self.group_info
out_shape = cy_op.get_output_shape(ngroups, values)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

future, would just have a single method here where you pass ngroup, values, may be able to determine is_numeric inside the WrappedCythonOps (instead of here), and just return the items you need, though you can take it a step further and have a WrappedAggregateOp and WrappedTransformOp, something like

op = WrappedCythonOps(kind, how, values.....) # other args
# op is WrappedTransform / WrappedAggregate
result = op.get_result()

func, values = cy_op.get_cython_func_and_vals(values, is_numeric)
out_dtype = cy_op.get_out_dtype(values.dtype)

result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
if kind == "aggregate":
counts = np.zeros(self.ngroups, dtype=np.int64)
result = self._aggregate(result, counts, values, codes, func, min_count)
counts = np.zeros(ngroups, dtype=np.int64)
func(result, counts, values, comp_ids, min_count)
elif kind == "transform":
# TODO: min_count
result = self._transform(
result, values, codes, func, is_datetimelike, **kwargs
)
func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

if is_integer_dtype(result.dtype) and not is_datetimelike:
mask = result == iNaT
Expand All @@ -697,28 +721,6 @@ def _cython_operation(

return op_result

@final
def _aggregate(
self, result, counts, values, comp_ids, agg_func, min_count: int = -1
):
if agg_func is libgroupby.group_nth:
# different signature from the others
agg_func(result, counts, values, comp_ids, min_count, rank=1)
else:
agg_func(result, counts, values, comp_ids, min_count)

return result

@final
def _transform(
self, result, values, comp_ids, transform_func, is_datetimelike: bool, **kwargs
):

_, _, ngroups = self.group_info
transform_func(result, values, comp_ids, ngroups, is_datetimelike, **kwargs)

return result

def agg_series(self, obj: Series, func: F):
# Caller is responsible for checking ngroups != 0
assert self.ngroups != 0
Expand Down