|
56 | 56 | ExtensionArray,
|
57 | 57 | ExtensionArraySupportsAnyAll,
|
58 | 58 | )
|
| 59 | +from pandas.core.arrays.string_ import StringDtype |
59 | 60 | import pandas.core.common as com
|
60 | 61 | from pandas.core.indexers import (
|
61 | 62 | check_array_indexer,
|
@@ -1655,6 +1656,82 @@ def _replace_with_mask(
|
1655 | 1656 | result[mask] = replacements
|
1656 | 1657 | return pa.array(result, type=values.type, from_pandas=True)
|
1657 | 1658 |
|
| 1659 | + # ------------------------------------------------------------------ |
| 1660 | + # GroupBy Methods |
| 1661 | + |
| 1662 | + def _to_masked(self): |
| 1663 | + pa_dtype = self._pa_array.type |
| 1664 | + na_value = 1 |
| 1665 | + from pandas.core.arrays import ( |
| 1666 | + BooleanArray, |
| 1667 | + FloatingArray, |
| 1668 | + IntegerArray, |
| 1669 | + ) |
| 1670 | + |
| 1671 | + arr_cls: type[FloatingArray | IntegerArray | BooleanArray] |
| 1672 | + if pa.types.is_floating(pa_dtype): |
| 1673 | + nbits = pa_dtype.bit_width |
| 1674 | + dtype = f"Float{nbits}" |
| 1675 | + np_dtype = dtype.lower() |
| 1676 | + arr_cls = FloatingArray |
| 1677 | + elif pa.types.is_unsigned_integer(pa_dtype): |
| 1678 | + nbits = pa_dtype.bit_width |
| 1679 | + dtype = f"UInt{nbits}" |
| 1680 | + np_dtype = dtype.lower() |
| 1681 | + arr_cls = IntegerArray |
| 1682 | + |
| 1683 | + elif pa.types.is_signed_integer(pa_dtype): |
| 1684 | + nbits = pa_dtype.bit_width |
| 1685 | + dtype = f"Int{nbits}" |
| 1686 | + np_dtype = dtype.lower() |
| 1687 | + arr_cls = IntegerArray |
| 1688 | + |
| 1689 | + elif pa.types.is_boolean(pa_dtype): |
| 1690 | + dtype = "boolean" |
| 1691 | + np_dtype = "bool" |
| 1692 | + na_value = True |
| 1693 | + arr_cls = BooleanArray |
| 1694 | + else: |
| 1695 | + raise NotImplementedError |
| 1696 | + |
| 1697 | + mask = self.isna() |
| 1698 | + arr = self.to_numpy(dtype=np_dtype, na_value=na_value) |
| 1699 | + return arr_cls(arr, mask) |
| 1700 | + |
| 1701 | + def _groupby_op( |
| 1702 | + self, |
| 1703 | + *, |
| 1704 | + how: str, |
| 1705 | + has_dropped_na: bool, |
| 1706 | + min_count: int, |
| 1707 | + ngroups: int, |
| 1708 | + ids: npt.NDArray[np.intp], |
| 1709 | + **kwargs, |
| 1710 | + ): |
| 1711 | + if isinstance(self.dtype, StringDtype): |
| 1712 | + return super()._groupby_op( |
| 1713 | + how=how, |
| 1714 | + has_dropped_na=has_dropped_na, |
| 1715 | + min_count=min_count, |
| 1716 | + ngroups=ngroups, |
| 1717 | + ids=ids, |
| 1718 | + **kwargs, |
| 1719 | + ) |
| 1720 | + |
| 1721 | + masked = self._to_masked() |
| 1722 | + |
| 1723 | + result = masked._groupby_op( |
| 1724 | + how=how, |
| 1725 | + has_dropped_na=has_dropped_na, |
| 1726 | + min_count=min_count, |
| 1727 | + ngroups=ngroups, |
| 1728 | + ids=ids, |
| 1729 | + **kwargs, |
| 1730 | + ) |
| 1731 | + if isinstance(result, np.ndarray): |
| 1732 | + return result |
| 1733 | + return type(self)._from_sequence(result, copy=False) |
| 1734 | + |
1658 | 1735 | def _str_count(self, pat: str, flags: int = 0):
|
1659 | 1736 | if flags:
|
1660 | 1737 | raise NotImplementedError(f"count not implemented with {flags=}")
|
|
0 commit comments