Skip to content

REF: simplify groupby.ops #46196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 3, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 45 additions & 36 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ def __init__(self, kind: str, how: str):
"min": "group_min",
"max": "group_max",
"mean": "group_mean",
"median": "group_median",
"median": "group_median_float64",
"var": "group_var",
"first": "group_nth",
"last": "group_last",
"ohlc": "group_ohlc",
},
"transform": {
"cumprod": "group_cumprod",
"cumprod": "group_cumprod_float64",
"cumsum": "group_cumsum",
"cummin": "group_cummin",
"cummax": "group_cummax",
Expand Down Expand Up @@ -161,52 +161,54 @@ def _get_cython_function(
if is_numeric:
return f
elif dtype == object:
if "object" not in f.__signatures__:
if how in ["median", "cumprod"]:
# no fused types -> no __signatures__
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
elif "object" not in f.__signatures__:
# raise NotImplementedError here rather than TypeError later
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{dtype_str}]"
)
return f
else:
raise NotImplementedError(
"This should not be reached. Please report a bug at "
"github.com/pandas-dev/pandas/",
dtype,
)

def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
"""
Find the appropriate cython function, casting if necessary.
Cast numeric dtypes to float64 for functions that only support that.

Parameters
----------
values : np.ndarray
is_numeric : bool

Returns
-------
func : callable
values : np.ndarray
"""
how = self.how
kind = self.kind

if how in ["median", "cumprod"]:
# these two only have float64 implementations
if is_numeric:
values = ensure_float64(values)
else:
raise NotImplementedError(
f"function is not implemented for this dtype: "
f"[how->{how},dtype->{values.dtype.name}]"
)
func = getattr(libgroupby, f"group_{how}_float64")
return func, values

func = self._get_cython_function(kind, how, values.dtype, is_numeric)
# We should only get here with is_numeric, as non-numeric cases
# should raise in _get_cython_function
values = ensure_float64(values)

if values.dtype.kind in ["i", "u"]:
elif values.dtype.kind in ["i", "u"]:
if how in ["add", "var", "prod", "mean", "ohlc"]:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

return func, values
return values

# TODO: general case implementation overridable by EAs.
def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
"""
Check if we can do this operation with our cython functions.
Expand Down Expand Up @@ -235,6 +237,7 @@ def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
# are not setup for dim transforming
raise NotImplementedError(f"{dtype} dtype not supported")
elif is_datetime64_any_dtype(dtype):
# TODO: same for period_dtype? no for these methods with Period
# we raise NotImplemented if this is an invalid operation
# entirely, e.g. adding datetimes
if how in ["add", "prod", "cumsum", "cumprod"]:
Expand Down Expand Up @@ -262,7 +265,7 @@ def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
out_shape = (ngroups,) + values.shape[1:]
return out_shape

def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
def _get_out_dtype(self, dtype: np.dtype) -> np.dtype:
how = self.how

if how == "rank":
Expand All @@ -282,6 +285,7 @@ def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
... # pragma: no cover

# TODO: general case implementation overridable by EAs.
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
"""
Get the desired dtype of a result based on the
Expand Down Expand Up @@ -329,7 +333,6 @@ def _ea_wrap_cython_operation(
If we have an ExtensionArray, unwrap, call _cython_operation, and
re-wrap if appropriate.
"""
# TODO: general case implementation overridable by EAs.
if isinstance(values, BaseMaskedArray) and self.uses_mask():
return self._masked_ea_wrap_cython_operation(
values,
Expand Down Expand Up @@ -357,7 +360,8 @@ def _ea_wrap_cython_operation(

return self._reconstruct_ea_result(values, res_values)

def _ea_to_cython_values(self, values: ExtensionArray):
# TODO: general case implementation overridable by EAs.
def _ea_to_cython_values(self, values: ExtensionArray) -> np.ndarray:
# GH#43682
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
# All of the functions implemented here are ordinal, so we can
Expand All @@ -378,23 +382,24 @@ def _ea_to_cython_values(self, values: ExtensionArray):
)
return npvalues

def _reconstruct_ea_result(self, values, res_values):
# TODO: general case implementation overridable by EAs.
def _reconstruct_ea_result(
self, values: ExtensionArray, res_values: np.ndarray
) -> ExtensionArray:
"""
Construct an ExtensionArray result from an ndarray result.
"""
# TODO: allow EAs to override this logic

if isinstance(
values.dtype, (BooleanDtype, IntegerDtype, FloatingDtype, StringDtype)
):
if isinstance(values.dtype, (BaseMaskedDtype, StringDtype)):
dtype = self._get_result_dtype(values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif needs_i8_conversion(values.dtype):
assert res_values.dtype.kind != "f" # just to be on the safe side
i8values = res_values.view("i8")
return type(values)(i8values, dtype=values.dtype)
# error: Too many arguments for "ExtensionArray"
return type(values)(i8values, dtype=values.dtype) # type: ignore[call-arg]

raise NotImplementedError

Expand Down Expand Up @@ -429,13 +434,16 @@ def _masked_ea_wrap_cython_operation(
)

dtype = self._get_result_dtype(orig_values.dtype)
assert isinstance(dtype, BaseMaskedDtype)
cls = dtype.construct_array_type()
# TODO: avoid cast as res_values *should* already have the right
# dtype; last attempt ran into trouble on 32bit linux build
res_values = res_values.astype(dtype.type, copy=False)

if self.kind != "aggregate":
return cls(res_values.astype(dtype.type, copy=False), mask)
out_mask = mask
else:
return cls(res_values.astype(dtype.type, copy=False), result_mask)
out_mask = result_mask

return orig_values._maybe_mask_result(res_values, out_mask)

@final
def _cython_op_ndim_compat(
Expand Down Expand Up @@ -521,8 +529,9 @@ def _call_cython_op(
result_mask = result_mask.T

out_shape = self._get_output_shape(ngroups, values)
func, values = self.get_cython_func_and_vals(values, is_numeric)
out_dtype = self.get_out_dtype(values.dtype)
func = self._get_cython_function(self.kind, self.how, values.dtype, is_numeric)
values = self._get_cython_vals(values)
out_dtype = self._get_out_dtype(values.dtype)

result = maybe_fill(np.empty(out_shape, dtype=out_dtype))
if self.kind == "aggregate":
Expand Down