Skip to content

Commit c94f9af

Browse files
authored
PERF: groupby reductions with pyarrow dtypes (#52469)
* REF: move groupby reduction methods to EA * REF: move EA-specific checks to EAs * REF: dont pass op to groupby_op * mypy fixup * groupby_op -> _groupby_op * mypy fixup * mypy fixup * PERF: Groupby reductions with pyarrow dtypes * mypy fixup
1 parent 38c57ce commit c94f9af

File tree

3 files changed

+84
-5
lines changed

3 files changed

+84
-5
lines changed

pandas/core/arrays/arrow/array.py

+77
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
ExtensionArray,
5757
ExtensionArraySupportsAnyAll,
5858
)
59+
from pandas.core.arrays.string_ import StringDtype
5960
import pandas.core.common as com
6061
from pandas.core.indexers import (
6162
check_array_indexer,
@@ -1655,6 +1656,82 @@ def _replace_with_mask(
16551656
result[mask] = replacements
16561657
return pa.array(result, type=values.type, from_pandas=True)
16571658

1659+
# ------------------------------------------------------------------
1660+
# GroupBy Methods
1661+
1662+
def _to_masked(self):
1663+
pa_dtype = self._pa_array.type
1664+
na_value = 1
1665+
from pandas.core.arrays import (
1666+
BooleanArray,
1667+
FloatingArray,
1668+
IntegerArray,
1669+
)
1670+
1671+
arr_cls: type[FloatingArray | IntegerArray | BooleanArray]
1672+
if pa.types.is_floating(pa_dtype):
1673+
nbits = pa_dtype.bit_width
1674+
dtype = f"Float{nbits}"
1675+
np_dtype = dtype.lower()
1676+
arr_cls = FloatingArray
1677+
elif pa.types.is_unsigned_integer(pa_dtype):
1678+
nbits = pa_dtype.bit_width
1679+
dtype = f"UInt{nbits}"
1680+
np_dtype = dtype.lower()
1681+
arr_cls = IntegerArray
1682+
1683+
elif pa.types.is_signed_integer(pa_dtype):
1684+
nbits = pa_dtype.bit_width
1685+
dtype = f"Int{nbits}"
1686+
np_dtype = dtype.lower()
1687+
arr_cls = IntegerArray
1688+
1689+
elif pa.types.is_boolean(pa_dtype):
1690+
dtype = "boolean"
1691+
np_dtype = "bool"
1692+
na_value = True
1693+
arr_cls = BooleanArray
1694+
else:
1695+
raise NotImplementedError
1696+
1697+
mask = self.isna()
1698+
arr = self.to_numpy(dtype=np_dtype, na_value=na_value)
1699+
return arr_cls(arr, mask)
1700+
1701+
def _groupby_op(
1702+
self,
1703+
*,
1704+
how: str,
1705+
has_dropped_na: bool,
1706+
min_count: int,
1707+
ngroups: int,
1708+
ids: npt.NDArray[np.intp],
1709+
**kwargs,
1710+
):
1711+
if isinstance(self.dtype, StringDtype):
1712+
return super()._groupby_op(
1713+
how=how,
1714+
has_dropped_na=has_dropped_na,
1715+
min_count=min_count,
1716+
ngroups=ngroups,
1717+
ids=ids,
1718+
**kwargs,
1719+
)
1720+
1721+
masked = self._to_masked()
1722+
1723+
result = masked._groupby_op(
1724+
how=how,
1725+
has_dropped_na=has_dropped_na,
1726+
min_count=min_count,
1727+
ngroups=ngroups,
1728+
ids=ids,
1729+
**kwargs,
1730+
)
1731+
if isinstance(result, np.ndarray):
1732+
return result
1733+
return type(self)._from_sequence(result, copy=False)
1734+
16581735
def _str_count(self, pat: str, flags: int = 0):
16591736
if flags:
16601737
raise NotImplementedError(f"count not implemented with {flags=}")

pandas/core/arrays/masked.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979
)
8080
from pandas.core.array_algos.quantile import quantile_with_mask
8181
from pandas.core.arraylike import OpsMixin
82-
from pandas.core.arrays import ExtensionArray
82+
from pandas.core.arrays.base import ExtensionArray
8383
from pandas.core.construction import ensure_wrapped_if_datetimelike
8484
from pandas.core.indexers import check_array_indexer
8585
from pandas.core.ops import invalid_comparison

pandas/core/arrays/string_.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,15 @@
3535

3636
from pandas.core import ops
3737
from pandas.core.array_algos import masked_reductions
38-
from pandas.core.arrays import (
39-
ExtensionArray,
38+
from pandas.core.arrays.base import ExtensionArray
39+
from pandas.core.arrays.floating import (
4040
FloatingArray,
41+
FloatingDtype,
42+
)
43+
from pandas.core.arrays.integer import (
4144
IntegerArray,
45+
IntegerDtype,
4246
)
43-
from pandas.core.arrays.floating import FloatingDtype
44-
from pandas.core.arrays.integer import IntegerDtype
4547
from pandas.core.arrays.numpy_ import PandasArray
4648
from pandas.core.construction import extract_array
4749
from pandas.core.indexers import check_array_indexer

0 commit comments

Comments
 (0)