Skip to content

CLN: BaseGrouper #59034

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandas/core/groupby/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,8 @@ def nunique(self, dropna: bool = True) -> Series | DataFrame:
b 1
dtype: int64
"""
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
val = self.obj._values
codes, uniques = algorithms.factorize(val, use_na_sentinel=dropna, sort=False)

Expand Down
17 changes: 11 additions & 6 deletions pandas/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,7 +1360,7 @@ def _wrap_applied_output(

@final
def _numba_prep(self, data: DataFrame):
ids, ngroups = self._grouper.group_info
ngroups = self._grouper.ngroups
sorted_index = self._grouper.result_ilocs
sorted_ids = self._grouper._sorted_ids

Expand Down Expand Up @@ -1969,7 +1969,8 @@ def _cumcount_array(self, ascending: bool = True) -> np.ndarray:
this is currently implementing sort=False
(though the default is sort=True) for groupby in general
"""
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
sorter = get_group_index_sorter(ids, ngroups)
ids, count = ids[sorter], len(ids)

Expand Down Expand Up @@ -2185,7 +2186,8 @@ def count(self) -> NDFrameT:
Freq: MS, dtype: int64
"""
data = self._get_data_to_aggregate()
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
mask = ids != -1

is_series = data.ndim == 1
Expand Down Expand Up @@ -3840,7 +3842,8 @@ def _fill(self, direction: Literal["ffill", "bfill"], limit: int | None = None):
if limit is None:
limit = -1

ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups

col_func = partial(
libgroupby.group_fillna_indexer,
Expand Down Expand Up @@ -4361,7 +4364,8 @@ def post_processor(
qs = np.array([q], dtype=np.float64)
pass_qs = None

ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
if self.dropna:
# splitter drops NA groups, we need to do the same
ids = ids[ids >= 0]
Expand Down Expand Up @@ -5038,7 +5042,8 @@ def shift(
else:
if fill_value is lib.no_default:
fill_value = None
ids, ngroups = self._grouper.group_info
ids = self._grouper.ids
ngroups = self._grouper.ngroups
res_indexer = np.zeros(len(ids), dtype=np.int64)

libgroupby.group_shift_indexer(res_indexer, ids, ngroups, period)
Expand Down
57 changes: 13 additions & 44 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@
Generator,
Hashable,
Iterator,
Sequence,
)

from pandas.core.generic import NDFrame
Expand Down Expand Up @@ -581,25 +580,21 @@ class BaseGrouper:
def __init__(
self,
axis: Index,
groupings: Sequence[grouper.Grouping],
groupings: list[grouper.Grouping],
sort: bool = True,
dropna: bool = True,
) -> None:
assert isinstance(axis, Index), axis

self.axis = axis
self._groupings: list[grouper.Grouping] = list(groupings)
self._groupings = groupings
self._sort = sort
self.dropna = dropna

@property
def groupings(self) -> list[grouper.Grouping]:
return self._groupings

@property
def shape(self) -> Shape:
return tuple(ping.ngroups for ping in self.groupings)

def __iter__(self) -> Iterator[Hashable]:
return iter(self.indices)

Expand Down Expand Up @@ -628,11 +623,15 @@ def _get_splitter(self, data: NDFrame) -> DataSplitter:
-------
Generator yielding subsetted objects
"""
ids, ngroups = self.group_info
return _get_splitter(
if isinstance(data, Series):
klass: type[DataSplitter] = SeriesSplitter
else:
# i.e. DataFrame
klass = FrameSplitter

return klass(
data,
ids,
ngroups,
self.ngroups,
sorted_ids=self._sorted_ids,
sort_idx=self.result_ilocs,
)
Expand Down Expand Up @@ -692,7 +691,8 @@ def size(self) -> Series:
"""
Compute group sizes.
"""
ids, ngroups = self.group_info
ids = self.ids
ngroups = self.ngroups
out: np.ndarray | list
if ngroups:
out = np.bincount(ids[ids != -1], minlength=ngroups)
Expand Down Expand Up @@ -729,12 +729,6 @@ def has_dropped_na(self) -> bool:
"""
return bool((self.ids < 0).any())

@cache_readonly
def group_info(self) -> tuple[npt.NDArray[np.intp], int]:
result_index, ids = self.result_index_and_ids
ngroups = len(result_index)
return ids, ngroups

@cache_readonly
def codes_info(self) -> npt.NDArray[np.intp]:
# return the codes of items in original grouped axis
Expand Down Expand Up @@ -1123,10 +1117,6 @@ def indices(self):
i = bin
return indices

@cache_readonly
def group_info(self) -> tuple[npt.NDArray[np.intp], int]:
return self.ids, self.ngroups

@cache_readonly
def codes(self) -> list[npt.NDArray[np.intp]]:
return [self.ids]
Expand Down Expand Up @@ -1191,29 +1181,25 @@ class DataSplitter(Generic[NDFrameT]):
def __init__(
self,
data: NDFrameT,
labels: npt.NDArray[np.intp],
ngroups: int,
*,
sort_idx: npt.NDArray[np.intp],
sorted_ids: npt.NDArray[np.intp],
) -> None:
self.data = data
self.labels = ensure_platform_int(labels) # _should_ already be np.intp
self.ngroups = ngroups

self._slabels = sorted_ids
self._sort_idx = sort_idx

def __iter__(self) -> Iterator:
sdata = self._sorted_data

if self.ngroups == 0:
# we are inside a generator; rather than raise StopIteration,
# we merely return to signal the end
return

starts, ends = lib.generate_slices(self._slabels, self.ngroups)

sdata = self._sorted_data
for start, end in zip(starts, ends):
yield self._chop(sdata, slice(start, end))

Expand Down Expand Up @@ -1241,20 +1227,3 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
mgr = sdata._mgr.get_slice(slice_obj, axis=1)
df = sdata._constructor_from_mgr(mgr, axes=mgr.axes)
return df.__finalize__(sdata, method="groupby")


def _get_splitter(
data: NDFrame,
labels: npt.NDArray[np.intp],
ngroups: int,
*,
sort_idx: npt.NDArray[np.intp],
sorted_ids: npt.NDArray[np.intp],
) -> DataSplitter:
if isinstance(data, Series):
klass: type[DataSplitter] = SeriesSplitter
else:
# i.e. DataFrame
klass = FrameSplitter

return klass(data, labels, ngroups, sort_idx=sort_idx, sorted_ids=sorted_ids)
4 changes: 3 additions & 1 deletion pandas/tests/test_sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def test_int64_overflow_groupby_large_df_shuffled(self, agg):
gr = df.groupby(list("abcde"))

# verify this is testing what it is supposed to test!
assert is_int64_overflow_possible(gr._grouper.shape)
assert is_int64_overflow_possible(
tuple(ping.ngroups for ping in gr._grouper.groupings)
)

mi = MultiIndex.from_arrays(
[ar.ravel() for ar in np.array_split(np.unique(arr, axis=0), 5, axis=1)],
Expand Down
Loading