Skip to content

Commit 3c4ab29

Browse files
jbrockmendelJulianWgs
authored andcommitted
CLN: groupby.ops follow-up cleanup (pandas-dev#41204)
1 parent 9e5eecc commit 3c4ab29

File tree

2 files changed

+130
-87
lines changed

2 files changed

+130
-87
lines changed

pandas/core/groupby/groupby.py

+1
Original file line numberDiff line numberDiff line change
@@ -814,6 +814,7 @@ class GroupBy(BaseGroupBy[FrameOrSeries]):
814814
grouper: ops.BaseGrouper
815815
as_index: bool
816816

817+
@final
817818
def __init__(
818819
self,
819820
obj: FrameOrSeries,

pandas/core/groupby/ops.py

+129-87
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
Hashable,
1515
Iterator,
1616
Sequence,
17+
overload,
1718
)
1819

1920
import numpy as np
@@ -47,23 +48,35 @@
4748
is_categorical_dtype,
4849
is_complex_dtype,
4950
is_datetime64_any_dtype,
50-
is_datetime64tz_dtype,
5151
is_extension_array_dtype,
52-
is_float_dtype,
5352
is_integer_dtype,
5453
is_numeric_dtype,
55-
is_period_dtype,
5654
is_sparse,
5755
is_timedelta64_dtype,
5856
needs_i8_conversion,
5957
)
58+
from pandas.core.dtypes.dtypes import ExtensionDtype
6059
from pandas.core.dtypes.generic import ABCCategoricalIndex
6160
from pandas.core.dtypes.missing import (
6261
isna,
6362
maybe_fill,
6463
)
6564

66-
from pandas.core.arrays import ExtensionArray
65+
from pandas.core.arrays import (
66+
DatetimeArray,
67+
ExtensionArray,
68+
PeriodArray,
69+
TimedeltaArray,
70+
)
71+
from pandas.core.arrays.boolean import BooleanDtype
72+
from pandas.core.arrays.floating import (
73+
Float64Dtype,
74+
FloatingDtype,
75+
)
76+
from pandas.core.arrays.integer import (
77+
Int64Dtype,
78+
_IntegerDtype,
79+
)
6780
from pandas.core.arrays.masked import (
6881
BaseMaskedArray,
6982
BaseMaskedDtype,
@@ -194,7 +207,7 @@ def get_cython_func_and_vals(self, values: np.ndarray, is_numeric: bool):
194207

195208
return func, values
196209

197-
def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
210+
def _disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
198211
"""
199212
Check if we can do this operation with our cython functions.
200213
@@ -230,7 +243,7 @@ def disallow_invalid_ops(self, dtype: DtypeObj, is_numeric: bool = False):
230243
if how in ["prod", "cumprod"]:
231244
raise TypeError(f"timedelta64 type does not support {how} operations")
232245

233-
def get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
246+
def _get_output_shape(self, ngroups: int, values: np.ndarray) -> Shape:
234247
how = self.how
235248
kind = self.kind
236249

@@ -261,7 +274,15 @@ def get_out_dtype(self, dtype: np.dtype) -> np.dtype:
261274
out_dtype = "object"
262275
return np.dtype(out_dtype)
263276

264-
def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
277+
@overload
278+
def _get_result_dtype(self, dtype: np.dtype) -> np.dtype:
279+
...
280+
281+
@overload
282+
def _get_result_dtype(self, dtype: ExtensionDtype) -> ExtensionDtype:
283+
...
284+
285+
def _get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
265286
"""
266287
Get the desired dtype of a result based on the
267288
input dtype and how it was computed.
@@ -276,13 +297,6 @@ def get_result_dtype(self, dtype: DtypeObj) -> DtypeObj:
276297
np.dtype or ExtensionDtype
277298
The desired dtype of the result.
278299
"""
279-
from pandas.core.arrays.boolean import BooleanDtype
280-
from pandas.core.arrays.floating import Float64Dtype
281-
from pandas.core.arrays.integer import (
282-
Int64Dtype,
283-
_IntegerDtype,
284-
)
285-
286300
how = self.how
287301

288302
if how in ["add", "cumsum", "sum", "prod"]:
@@ -315,15 +329,12 @@ def _ea_wrap_cython_operation(
315329
# TODO: general case implementation overridable by EAs.
316330
orig_values = values
317331

318-
if is_datetime64tz_dtype(values.dtype) or is_period_dtype(values.dtype):
332+
if isinstance(orig_values, (DatetimeArray, PeriodArray)):
319333
# All of the functions implemented here are ordinal, so we can
320334
# operate on the tz-naive equivalents
321-
npvalues = values.view("M8[ns]")
335+
npvalues = orig_values._ndarray.view("M8[ns]")
322336
res_values = self._cython_op_ndim_compat(
323-
# error: Argument 1 to "_cython_op_ndim_compat" of
324-
# "WrappedCythonOp" has incompatible type
325-
# "Union[ExtensionArray, ndarray]"; expected "ndarray"
326-
npvalues, # type: ignore[arg-type]
337+
npvalues,
327338
min_count=min_count,
328339
ngroups=ngroups,
329340
comp_ids=comp_ids,
@@ -336,14 +347,31 @@ def _ea_wrap_cython_operation(
336347
# preserve float64 dtype
337348
return res_values
338349

339-
res_values = res_values.astype("i8", copy=False)
340-
# error: Too many arguments for "ExtensionArray"
341-
result = type(orig_values)( # type: ignore[call-arg]
342-
res_values, dtype=orig_values.dtype
350+
res_values = res_values.view("i8")
351+
result = type(orig_values)(res_values, dtype=orig_values.dtype)
352+
return result
353+
354+
elif isinstance(orig_values, TimedeltaArray):
355+
# We have an ExtensionArray but not ExtensionDtype
356+
res_values = self._cython_op_ndim_compat(
357+
orig_values._ndarray,
358+
min_count=min_count,
359+
ngroups=ngroups,
360+
comp_ids=comp_ids,
361+
mask=None,
362+
**kwargs,
343363
)
364+
if self.how in ["rank"]:
365+
# i.e. how in WrappedCythonOp.cast_blocklist, since
366+
# other cast_blocklist methods dont go through cython_operation
367+
# preserve float64 dtype
368+
return res_values
369+
370+
# otherwise res_values has the same dtype as original values
371+
result = type(orig_values)(res_values)
344372
return result
345373

346-
elif is_integer_dtype(values.dtype) or is_bool_dtype(values.dtype):
374+
elif isinstance(values.dtype, (BooleanDtype, _IntegerDtype)):
347375
# IntegerArray or BooleanArray
348376
npvalues = values.to_numpy("float64", na_value=np.nan)
349377
res_values = self._cython_op_ndim_compat(
@@ -359,17 +387,14 @@ def _ea_wrap_cython_operation(
359387
# other cast_blocklist methods dont go through cython_operation
360388
return res_values
361389

362-
dtype = self.get_result_dtype(orig_values.dtype)
363-
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
364-
# has no attribute "construct_array_type"
365-
cls = dtype.construct_array_type() # type: ignore[union-attr]
390+
dtype = self._get_result_dtype(orig_values.dtype)
391+
cls = dtype.construct_array_type()
366392
return cls._from_sequence(res_values, dtype=dtype)
367393

368-
elif is_float_dtype(values.dtype):
394+
elif isinstance(values.dtype, FloatingDtype):
369395
# FloatingArray
370-
# error: "ExtensionDtype" has no attribute "numpy_dtype"
371396
npvalues = values.to_numpy(
372-
values.dtype.numpy_dtype, # type: ignore[attr-defined]
397+
values.dtype.numpy_dtype,
373398
na_value=np.nan,
374399
)
375400
res_values = self._cython_op_ndim_compat(
@@ -385,10 +410,8 @@ def _ea_wrap_cython_operation(
385410
# other cast_blocklist methods dont go through cython_operation
386411
return res_values
387412

388-
dtype = self.get_result_dtype(orig_values.dtype)
389-
# error: Item "dtype[Any]" of "Union[dtype[Any], ExtensionDtype]"
390-
# has no attribute "construct_array_type"
391-
cls = dtype.construct_array_type() # type: ignore[union-attr]
413+
dtype = self._get_result_dtype(orig_values.dtype)
414+
cls = dtype.construct_array_type()
392415
return cls._from_sequence(res_values, dtype=dtype)
393416

394417
raise NotImplementedError(
@@ -422,12 +445,13 @@ def _masked_ea_wrap_cython_operation(
422445
mask=mask,
423446
**kwargs,
424447
)
425-
dtype = self.get_result_dtype(orig_values.dtype)
448+
dtype = self._get_result_dtype(orig_values.dtype)
426449
assert isinstance(dtype, BaseMaskedDtype)
427450
cls = dtype.construct_array_type()
428451

429452
return cls(res_values.astype(dtype.type, copy=False), mask)
430453

454+
@final
431455
def _cython_op_ndim_compat(
432456
self,
433457
values: np.ndarray,
@@ -500,7 +524,7 @@ def _call_cython_op(
500524
if mask is not None:
501525
mask = mask.reshape(values.shape, order="C")
502526

503-
out_shape = self.get_output_shape(ngroups, values)
527+
out_shape = self._get_output_shape(ngroups, values)
504528
func, values = self.get_cython_func_and_vals(values, is_numeric)
505529
out_dtype = self.get_out_dtype(values.dtype)
506530

@@ -550,19 +574,71 @@ def _call_cython_op(
550574
if self.how not in self.cast_blocklist:
551575
# e.g. if we are int64 and need to restore to datetime64/timedelta64
552576
# "rank" is the only member of cast_blocklist we get here
553-
res_dtype = self.get_result_dtype(orig_values.dtype)
554-
# error: Argument 2 to "maybe_downcast_to_dtype" has incompatible type
555-
# "Union[dtype[Any], ExtensionDtype]"; expected "Union[str, dtype[Any]]"
556-
op_result = maybe_downcast_to_dtype(
557-
result, res_dtype # type: ignore[arg-type]
558-
)
577+
res_dtype = self._get_result_dtype(orig_values.dtype)
578+
op_result = maybe_downcast_to_dtype(result, res_dtype)
559579
else:
560580
op_result = result
561581

562582
# error: Incompatible return value type (got "Union[ExtensionArray, ndarray]",
563583
# expected "ndarray")
564584
return op_result # type: ignore[return-value]
565585

586+
@final
587+
def cython_operation(
588+
self,
589+
*,
590+
values: ArrayLike,
591+
axis: int,
592+
min_count: int = -1,
593+
comp_ids: np.ndarray,
594+
ngroups: int,
595+
**kwargs,
596+
) -> ArrayLike:
597+
"""
598+
Call our cython function, with appropriate pre- and post- processing.
599+
"""
600+
if values.ndim > 2:
601+
raise NotImplementedError("number of dimensions is currently limited to 2")
602+
elif values.ndim == 2:
603+
# Note: it is *not* the case that axis is always 0 for 1-dim values,
604+
# as we can have 1D ExtensionArrays that we need to treat as 2D
605+
assert axis == 1, axis
606+
607+
dtype = values.dtype
608+
is_numeric = is_numeric_dtype(dtype)
609+
610+
# can we do this operation with our cython functions
611+
# if not raise NotImplementedError
612+
self._disallow_invalid_ops(dtype, is_numeric)
613+
614+
if not isinstance(values, np.ndarray):
615+
# i.e. ExtensionArray
616+
if isinstance(values, BaseMaskedArray) and self.uses_mask():
617+
return self._masked_ea_wrap_cython_operation(
618+
values,
619+
min_count=min_count,
620+
ngroups=ngroups,
621+
comp_ids=comp_ids,
622+
**kwargs,
623+
)
624+
else:
625+
return self._ea_wrap_cython_operation(
626+
values,
627+
min_count=min_count,
628+
ngroups=ngroups,
629+
comp_ids=comp_ids,
630+
**kwargs,
631+
)
632+
633+
return self._cython_op_ndim_compat(
634+
values,
635+
min_count=min_count,
636+
ngroups=ngroups,
637+
comp_ids=comp_ids,
638+
mask=None,
639+
**kwargs,
640+
)
641+
566642

567643
class BaseGrouper:
568644
"""
@@ -799,6 +875,7 @@ def group_info(self):
799875

800876
ngroups = len(obs_group_ids)
801877
comp_ids = ensure_platform_int(comp_ids)
878+
802879
return comp_ids, obs_group_ids, ngroups
803880

804881
@final
@@ -868,58 +945,23 @@ def _cython_operation(
868945
how: str,
869946
axis: int,
870947
min_count: int = -1,
871-
mask: np.ndarray | None = None,
872948
**kwargs,
873949
) -> ArrayLike:
874950
"""
875951
Returns the values of a cython operation.
876952
"""
877953
assert kind in ["transform", "aggregate"]
878954

879-
if values.ndim > 2:
880-
raise NotImplementedError("number of dimensions is currently limited to 2")
881-
elif values.ndim == 2:
882-
# Note: it is *not* the case that axis is always 0 for 1-dim values,
883-
# as we can have 1D ExtensionArrays that we need to treat as 2D
884-
assert axis == 1, axis
885-
886-
dtype = values.dtype
887-
is_numeric = is_numeric_dtype(dtype)
888-
889955
cy_op = WrappedCythonOp(kind=kind, how=how)
890956

891-
# can we do this operation with our cython functions
892-
# if not raise NotImplementedError
893-
cy_op.disallow_invalid_ops(dtype, is_numeric)
894-
895957
comp_ids, _, _ = self.group_info
896958
ngroups = self.ngroups
897-
898-
func_uses_mask = cy_op.uses_mask()
899-
if is_extension_array_dtype(dtype):
900-
if isinstance(values, BaseMaskedArray) and func_uses_mask:
901-
return cy_op._masked_ea_wrap_cython_operation(
902-
values,
903-
min_count=min_count,
904-
ngroups=ngroups,
905-
comp_ids=comp_ids,
906-
**kwargs,
907-
)
908-
else:
909-
return cy_op._ea_wrap_cython_operation(
910-
values,
911-
min_count=min_count,
912-
ngroups=ngroups,
913-
comp_ids=comp_ids,
914-
**kwargs,
915-
)
916-
917-
return cy_op._cython_op_ndim_compat(
918-
values,
959+
return cy_op.cython_operation(
960+
values=values,
961+
axis=axis,
919962
min_count=min_count,
920-
ngroups=self.ngroups,
921963
comp_ids=comp_ids,
922-
mask=mask,
964+
ngroups=ngroups,
923965
**kwargs,
924966
)
925967

@@ -969,8 +1011,8 @@ def _aggregate_series_fast(
9691011
indexer = get_group_index_sorter(group_index, ngroups)
9701012
obj = obj.take(indexer)
9711013
group_index = group_index.take(indexer)
972-
grouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups)
973-
result, counts = grouper.get_result()
1014+
sgrouper = libreduction.SeriesGrouper(obj, func, group_index, ngroups)
1015+
result, counts = sgrouper.get_result()
9741016
return result, counts
9751017

9761018
@final
@@ -1167,8 +1209,8 @@ def _aggregate_series_fast(
11671209
# - obj is backed by an ndarray, not ExtensionArray
11681210
# - ngroups != 0
11691211
# - len(self.bins) > 0
1170-
grouper = libreduction.SeriesBinGrouper(obj, func, self.bins)
1171-
return grouper.get_result()
1212+
sbg = libreduction.SeriesBinGrouper(obj, func, self.bins)
1213+
return sbg.get_result()
11721214

11731215

11741216
def _is_indexed_like(obj, axes, axis: int) -> bool:

0 commit comments

Comments
 (0)