
Commit e1bb74f

jbrockmendel authored and topper-123 committed
REF: let EAs override WrappedCythonOp groupby implementations (pandas-dev#51166)
1 parent b66b9e3 commit e1bb74f

6 files changed, +297 -256 lines

pandas/core/arrays/base.py

+76
@@ -1728,6 +1728,82 @@ def map(self, mapper, na_action=None):
         """
         return map_array(self, mapper, na_action=na_action)

+    # ------------------------------------------------------------------------
+    # GroupBy Methods
+
+    def _groupby_op(
+        self,
+        *,
+        how: str,
+        has_dropped_na: bool,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ) -> ArrayLike:
+        """
+        Dispatch GroupBy reduction or transformation operation.
+
+        This is an *experimental* API to allow ExtensionArray authors to implement
+        reductions and transformations. The API is subject to change.
+
+        Parameters
+        ----------
+        how : {'any', 'all', 'sum', 'prod', 'min', 'max', 'mean', 'median',
+               'median', 'var', 'std', 'sem', 'nth', 'last', 'ohlc',
+               'cumprod', 'cumsum', 'cummin', 'cummax', 'rank'}
+        has_dropped_na : bool
+        min_count : int
+        ngroups : int
+        ids : np.ndarray[np.intp]
+            ids[i] gives the integer label for the group that self[i] belongs to.
+        **kwargs : operation-specific
+            'any', 'all' -> ['skipna']
+            'var', 'std', 'sem' -> ['ddof']
+            'cumprod', 'cumsum', 'cummin', 'cummax' -> ['skipna']
+            'rank' -> ['ties_method', 'ascending', 'na_option', 'pct']
+
+        Returns
+        -------
+        np.ndarray or ExtensionArray
+        """
+        from pandas.core.arrays.string_ import StringDtype
+        from pandas.core.groupby.ops import WrappedCythonOp
+
+        kind = WrappedCythonOp.get_kind_from_how(how)
+        op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
+
+        # GH#43682
+        if isinstance(self.dtype, StringDtype):
+            # StringArray
+            npvalues = self.to_numpy(object, na_value=np.nan)
+        else:
+            raise NotImplementedError(
+                f"function is not implemented for this dtype: {self.dtype}"
+            )
+
+        res_values = op._cython_op_ndim_compat(
+            npvalues,
+            min_count=min_count,
+            ngroups=ngroups,
+            comp_ids=ids,
+            mask=None,
+            **kwargs,
+        )
+
+        if op.how in op.cast_blocklist:
+            # i.e. how in ["rank"], since other cast_blocklist methods don't go
+            # through cython_operation
+            return res_values
+
+        if isinstance(self.dtype, StringDtype):
+            dtype = self.dtype
+            string_array_cls = dtype.construct_array_type()
+            return string_array_cls._from_sequence(res_values, dtype=dtype)
+
+        else:
+            raise NotImplementedError
+


 class ExtensionArraySupportsAnyAll(ExtensionArray):
     def any(self, *, skipna: bool = True) -> bool:
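Not part of the commit, for orientation only: a minimal sketch of how a third-party ExtensionArray author might override this experimental hook, mirroring the StringDtype path above. `MyFloatArray` and its float64/NaN backing are hypothetical, the rest of the EA interface is elided, and the hook itself is private and subject to change.

import numpy as np
from numpy import typing as npt
from pandas.api.extensions import ExtensionArray


class MyFloatArray(ExtensionArray):  # hypothetical EA, illustration only
    def _groupby_op(
        self,
        *,
        how: str,
        has_dropped_na: bool,
        min_count: int,
        ngroups: int,
        ids: npt.NDArray[np.intp],
        **kwargs,
    ):
        from pandas.core.groupby.ops import WrappedCythonOp

        kind = WrappedCythonOp.get_kind_from_how(how)
        op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)

        # Assumption: this array can hand its data to the cython kernels
        # as float64, with NaN marking missing values.
        npvalues = self.to_numpy(dtype=np.float64, na_value=np.nan)
        res_values = op._cython_op_ndim_compat(
            npvalues,
            min_count=min_count,
            ngroups=ngroups,
            comp_ids=ids,
            mask=None,
            **kwargs,
        )
        if op.how in op.cast_blocklist:
            # e.g. "rank": keep the numpy result, do not cast back
            return res_values
        # Re-wrap the cython output in this extension type.
        return type(self)._from_sequence(res_values, dtype=self.dtype)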

pandas/core/arrays/categorical.py

+59
@@ -2389,6 +2389,65 @@ def _str_get_dummies(self, sep: str = "|"):

         return PandasArray(self.astype(str))._str_get_dummies(sep)

+    # ------------------------------------------------------------------------
+    # GroupBy Methods
+
+    def _groupby_op(
+        self,
+        *,
+        how: str,
+        has_dropped_na: bool,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ):
+        from pandas.core.groupby.ops import WrappedCythonOp
+
+        kind = WrappedCythonOp.get_kind_from_how(how)
+        op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
+
+        dtype = self.dtype
+        if how in ["sum", "prod", "cumsum", "cumprod", "skew"]:
+            raise TypeError(f"{dtype} type does not support {how} operations")
+        if how in ["min", "max", "rank"] and not dtype.ordered:
+            # raise TypeError instead of NotImplementedError to ensure we
+            # don't go down a group-by-group path, since in the empty-groups
+            # case that would fail to raise
+            raise TypeError(f"Cannot perform {how} with non-ordered Categorical")
+        if how not in ["rank", "any", "all", "first", "last", "min", "max"]:
+            if kind == "transform":
+                raise TypeError(f"{dtype} type does not support {how} operations")
+            raise TypeError(f"{dtype} dtype does not support aggregation '{how}'")
+
+        result_mask = None
+        mask = self.isna()
+        if how == "rank":
+            assert self.ordered  # checked earlier
+            npvalues = self._ndarray
+        elif how in ["first", "last", "min", "max"]:
+            npvalues = self._ndarray
+            result_mask = np.zeros(ngroups, dtype=bool)
+        else:
+            # any/all
+            npvalues = self.astype(bool)
+
+        res_values = op._cython_op_ndim_compat(
+            npvalues,
+            min_count=min_count,
+            ngroups=ngroups,
+            comp_ids=ids,
+            mask=mask,
+            result_mask=result_mask,
+            **kwargs,
+        )
+
+        if how in op.cast_blocklist:
+            return res_values
+        elif how in ["first", "last", "min", "max"]:
+            res_values[result_mask == 1] = -1
+        return self._from_backing_data(res_values)
+

 # The Series.cat accessor
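An illustration (not part of the diff) of the ordered/unordered split that the checks above encode; the expected outcomes are inferred from those error branches, assuming a pandas build that includes this change.

import pandas as pd

df = pd.DataFrame(
    {
        "key": [1, 1, 2],
        "cat": pd.Categorical(["a", "b", "a"], categories=["a", "b"], ordered=True),
    }
)

df.groupby("key")["cat"].min()  # allowed: min/max require an ordered Categorical

unordered = df.assign(cat=df["cat"].cat.as_unordered())
# unordered.groupby("key")["cat"].min()
# -> TypeError: Cannot perform min with non-ordered Categorical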

pandas/core/arrays/datetimelike.py

+82
@@ -1559,6 +1559,88 @@ def _mode(self, dropna: bool = True):
         npmodes = cast(np.ndarray, npmodes)
         return self._from_backing_data(npmodes)

+    # ------------------------------------------------------------------
+    # GroupBy Methods
+
+    def _groupby_op(
+        self,
+        *,
+        how: str,
+        has_dropped_na: bool,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ):
+        dtype = self.dtype
+        if dtype.kind == "M":
+            # Adding/multiplying datetimes is not valid
+            if how in ["sum", "prod", "cumsum", "cumprod", "var", "skew"]:
+                raise TypeError(f"datetime64 type does not support {how} operations")
+            if how in ["any", "all"]:
+                # GH#34479
+                warnings.warn(
+                    f"'{how}' with datetime64 dtypes is deprecated and will raise in a "
+                    f"future version. Use (obj != pd.Timestamp(0)).{how}() instead.",
+                    FutureWarning,
+                    stacklevel=find_stack_level(),
+                )
+
+        elif isinstance(dtype, PeriodDtype):
+            # Adding/multiplying Periods is not valid
+            if how in ["sum", "prod", "cumsum", "cumprod", "var", "skew"]:
+                raise TypeError(f"Period type does not support {how} operations")
+            if how in ["any", "all"]:
+                # GH#34479
+                warnings.warn(
+                    f"'{how}' with PeriodDtype is deprecated and will raise in a "
+                    f"future version. Use (obj != pd.Period(0, freq)).{how}() instead.",
+                    FutureWarning,
+                    stacklevel=find_stack_level(),
+                )
+        else:
+            # timedeltas we can add but not multiply
+            if how in ["prod", "cumprod", "skew"]:
+                raise TypeError(f"timedelta64 type does not support {how} operations")
+
+        # All of the functions implemented here are ordinal, so we can
+        # operate on the tz-naive equivalents
+        npvalues = self._ndarray.view("M8[ns]")
+
+        from pandas.core.groupby.ops import WrappedCythonOp
+
+        kind = WrappedCythonOp.get_kind_from_how(how)
+        op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
+
+        res_values = op._cython_op_ndim_compat(
+            npvalues,
+            min_count=min_count,
+            ngroups=ngroups,
+            comp_ids=ids,
+            mask=None,
+            **kwargs,
+        )
+
+        if op.how in op.cast_blocklist:
+            # i.e. how in ["rank"], since other cast_blocklist methods don't go
+            # through cython_operation
+            return res_values
+
+        # We did a view to M8[ns] above, now we go the other direction
+        assert res_values.dtype == "M8[ns]"
+        if how in ["std", "sem"]:
+            from pandas.core.arrays import TimedeltaArray
+
+            if isinstance(self.dtype, PeriodDtype):
+                raise TypeError("'std' and 'sem' are not valid for PeriodDtype")
+            self = cast("DatetimeArray | TimedeltaArray", self)
+            new_dtype = f"m8[{self.unit}]"
+            res_values = res_values.view(new_dtype)
+            return TimedeltaArray(res_values)
+
+        res_values = res_values.view(self._ndarray.dtype)
+        return self._from_backing_data(res_values)
+


 class DatelikeOps(DatetimeLikeArrayMixin):
     """

pandas/core/arrays/masked.py

+43
@@ -1381,3 +1381,46 @@ def _accumulate(
         data, mask = op(data, mask, skipna=skipna, **kwargs)

         return type(self)(data, mask, copy=False)
+
+    # ------------------------------------------------------------------
+    # GroupBy Methods
+
+    def _groupby_op(
+        self,
+        *,
+        how: str,
+        has_dropped_na: bool,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ):
+        from pandas.core.groupby.ops import WrappedCythonOp
+
+        kind = WrappedCythonOp.get_kind_from_how(how)
+        op = WrappedCythonOp(how=how, kind=kind, has_dropped_na=has_dropped_na)
+
+        # libgroupby functions are responsible for NOT altering mask
+        mask = self._mask
+        if op.kind != "aggregate":
+            result_mask = mask.copy()
+        else:
+            result_mask = np.zeros(ngroups, dtype=bool)
+
+        res_values = op._cython_op_ndim_compat(
+            self._data,
+            min_count=min_count,
+            ngroups=ngroups,
+            comp_ids=ids,
+            mask=mask,
+            result_mask=result_mask,
+            **kwargs,
+        )
+
+        if op.how == "ohlc":
+            arity = op._cython_arity.get(op.how, 1)
+            result_mask = np.tile(result_mask, (arity, 1)).T
+
+        # res_values should already have the correct dtype, we just need to
+        # wrap in a MaskedArray
+        return self._maybe_mask_result(res_values, result_mask)
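An illustration (not part of the diff) of the masked path above: because `_maybe_mask_result` re-wraps the cython output together with the aggregated mask, a nullable integer column keeps its extension dtype and its missing-value semantics through a groupby reduction.

import pandas as pd

df = pd.DataFrame(
    {
        "key": [1, 1, 2],
        "val": pd.array([1, pd.NA, 3], dtype="Int64"),
    }
)

out = df.groupby("key")["val"].sum(min_count=1)
# out keeps the nullable Int64 dtype; a group with no valid values would be <NA>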

pandas/core/arrays/sparse/array.py

+15
@@ -1805,6 +1805,21 @@ def _formatter(self, boxed: bool = False):
         # This will infer the correct formatter from the dtype of the values.
         return None

+    # ------------------------------------------------------------------------
+    # GroupBy Methods
+
+    def _groupby_op(
+        self,
+        *,
+        how: str,
+        has_dropped_na: bool,
+        min_count: int,
+        ngroups: int,
+        ids: npt.NDArray[np.intp],
+        **kwargs,
+    ):
+        raise NotImplementedError(f"{self.dtype} dtype not supported")
+


 def _make_sparse(
     arr: np.ndarray,
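For completeness (not part of the diff): calling the stub directly shows the opt-out; whether higher-level groupby code falls back to a slower non-cython path for sparse values is outside this commit. The snippet assumes a pandas build that includes this change and exercises a private, experimental method.

import numpy as np
import pandas as pd

sp = pd.arrays.SparseArray([1.0, 0.0, 0.0, 2.0])
try:
    sp._groupby_op(
        how="sum",
        has_dropped_na=False,
        min_count=-1,
        ngroups=2,
        ids=np.array([0, 0, 1, 1], dtype=np.intp),
    )
except NotImplementedError as err:
    print(err)  # the dtype-not-supported message raised above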
