Skip to content

REF: de-duplicate WrappedCythonOp._ea_wrap_cython_operation #42521

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 13, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 38 additions & 78 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,95 +341,54 @@ def _ea_wrap_cython_operation(
comp_ids=comp_ids,
**kwargs,
)
orig_values = values

if isinstance(orig_values, (DatetimeArray, PeriodArray)):
if isinstance(values, (DatetimeArray, PeriodArray, TimedeltaArray)):
# All of the functions implemented here are ordinal, so we can
# operate on the tz-naive equivalents
npvalues = orig_values._ndarray.view("M8[ns]")
res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)
if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
# preserve float64 dtype
return res_values

res_values = res_values.view("i8")
result = type(orig_values)(res_values, dtype=orig_values.dtype)
return result

elif isinstance(orig_values, TimedeltaArray):
# We have an ExtensionArray but not ExtensionDtype
res_values = self._cython_op_ndim_compat(
orig_values._ndarray,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)
if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
# preserve float64 dtype
return res_values

# otherwise res_values has the same dtype as original values
return type(orig_values)(res_values)

npvalues = values._ndarray.view("M8[ns]")
elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
# IntegerArray or BooleanArray
npvalues = values.to_numpy("float64", na_value=np.nan)
res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)
if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = self._get_result_dtype(orig_values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

elif isinstance(values.dtype, FloatingDtype):
# FloatingArray
npvalues = values.to_numpy(
values.dtype.numpy_dtype,
na_value=np.nan,
)
res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
npvalues = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
else:
raise NotImplementedError(
f"function is not implemented for this dtype: {values.dtype}"
)
if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
return res_values

dtype = self._get_result_dtype(orig_values.dtype)
res_values = self._cython_op_ndim_compat(
npvalues,
min_count=min_count,
ngroups=ngroups,
comp_ids=comp_ids,
mask=None,
**kwargs,
)

if self.how in ["rank"]:
# i.e. how in WrappedCythonOp.cast_blocklist, since
# other cast_blocklist methods dont go through cython_operation
return res_values

return self._reconstruct_ea_result(values, res_values)

def _reconstruct_ea_result(self, values, res_values):
"""
Construct an ExtensionArray result from an ndarray result.
"""
# TODO: allow EAs to override this logic

if isinstance(values.dtype, (BooleanDtype, _IntegerDtype, FloatingDtype)):
dtype = self._get_result_dtype(values.dtype)
cls = dtype.construct_array_type()
return cls._from_sequence(res_values, dtype=dtype)

raise NotImplementedError(
f"function is not implemented for this dtype: {values.dtype}"
)
elif needs_i8_conversion(values.dtype):
i8values = res_values.view("i8")
return type(values)(i8values, dtype=values.dtype)

raise NotImplementedError
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

prob should expand this error message (followon), though maybe this cannot be hit at all? e.g. L357 hits first?


@final
def _masked_ea_wrap_cython_operation(
Expand Down Expand Up @@ -478,6 +437,8 @@ def _cython_op_ndim_compat(
if values.ndim == 1:
# expand to 2d, dispatch, then squeeze if appropriate
values2d = values[None, :]
if mask is not None:
mask = mask[None, :]
res = self._call_cython_op(
values2d,
min_count=min_count,
Expand Down Expand Up @@ -533,9 +494,8 @@ def _call_cython_op(
values = ensure_float64(values)

values = values.T

if mask is not None:
mask = mask.reshape(values.shape, order="C")
mask = mask.T

out_shape = self._get_output_shape(ngroups, values)
func, values = self.get_cython_func_and_vals(values, is_numeric)
Expand Down