Skip to content

TYP: GroupBy #43806

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits on
Sep 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pandas/_libs/lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ from typing import (
Any,
Callable,
Generator,
Hashable,
Literal,
overload,
)
Expand Down Expand Up @@ -197,7 +198,7 @@ def indices_fast(
labels: np.ndarray, # const int64_t[:]
keys: list,
sorted_labels: list[npt.NDArray[np.int64]],
) -> dict: ...
) -> dict[Hashable, npt.NDArray[np.intp]]: ...
def generate_slices(
labels: np.ndarray, ngroups: int # const intp_t[:]
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: ...
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/groupby/grouper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from pandas._typing import (
ArrayLike,
NDFrameT,
npt,
)
from pandas.errors import InvalidIndexError
from pandas.util._decorators import cache_readonly
Expand Down Expand Up @@ -604,7 +605,7 @@ def ngroups(self) -> int:
return len(self.group_index)

@cache_readonly
def indices(self):
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
# we have a list of groupers
if isinstance(self.grouping_vector, ops.BaseGrouper):
return self.grouping_vector.indices
Expand Down
20 changes: 12 additions & 8 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ def apply(
return result_values, mutated

@cache_readonly
def indices(self):
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
"""dict {group name -> group indices}"""
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
# This shows unused categories in indices GH#38642
Expand Down Expand Up @@ -807,7 +807,7 @@ def is_monotonic(self) -> bool:
return Index(self.group_info[0]).is_monotonic

@cache_readonly
def group_info(self):
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
comp_ids, obs_group_ids = self._get_compressed_codes()

ngroups = len(obs_group_ids)
Expand All @@ -817,22 +817,26 @@ def group_info(self):

@final
@cache_readonly
def codes_info(self) -> np.ndarray:
def codes_info(self) -> npt.NDArray[np.intp]:
# return the codes of items in original grouped axis
ids, _, _ = self.group_info
if self.indexer is not None:
sorter = np.lexsort((ids, self.indexer))
ids = ids[sorter]
ids = ensure_platform_int(ids)
# TODO: if numpy annotates np.lexsort, this ensure_platform_int
# may become unnecessary
return ids

@final
def _get_compressed_codes(self) -> tuple[np.ndarray, np.ndarray]:
def _get_compressed_codes(self) -> tuple[np.ndarray, npt.NDArray[np.intp]]:
# The first returned ndarray may have any signed integer dtype
if len(self.groupings) > 1:
group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
return compress_group_index(group_index, sort=self._sort)

ping = self.groupings[0]
return ping.codes, np.arange(len(ping.group_index))
return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)

@final
@cache_readonly
Expand Down Expand Up @@ -1017,7 +1021,7 @@ class BinGrouper(BaseGrouper):

"""

bins: np.ndarray # np.ndarray[np.int64]
bins: npt.NDArray[np.int64]
binlabels: Index
mutated: bool

Expand Down Expand Up @@ -1101,9 +1105,9 @@ def indices(self):
return indices

@cache_readonly
def group_info(self):
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
ngroups = self.ngroups
obs_group_ids = np.arange(ngroups, dtype=np.int64)
obs_group_ids = np.arange(ngroups, dtype=np.intp)
rep = np.diff(np.r_[0, self.bins])

rep = ensure_platform_int(rep)
Expand Down
22 changes: 11 additions & 11 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def __repr__(self) -> str:
def __len__(self) -> int:
return len(self.values)

def _slice(self, slicer):
def _slice(self, slicer) -> ArrayLike:
"""return a slice of my values"""

return self.values[slicer]
Expand Down Expand Up @@ -344,7 +344,7 @@ def dtype(self) -> DtypeObj:
def iget(self, i):
return self.values[i]

def set_inplace(self, locs, values):
def set_inplace(self, locs, values) -> None:
"""
Modify block values in-place with new item value.

Expand Down Expand Up @@ -563,13 +563,13 @@ def _downcast_2d(self) -> list[Block]:
return [self.make_block(new_values)]

@final
def astype(self, dtype, copy: bool = False, errors: str = "raise"):
def astype(self, dtype: DtypeObj, copy: bool = False, errors: str = "raise"):
"""
Coerce to the new dtype.

Parameters
----------
dtype : str, dtype convertible
dtype : np.dtype or ExtensionDtype
copy : bool, default False
copy if indicated
errors : str, {'raise', 'ignore'}, default 'raise'
Expand Down Expand Up @@ -1441,7 +1441,7 @@ def iget(self, col):
raise IndexError(f"{self} only contains one item")
return self.values

def set_inplace(self, locs, values):
def set_inplace(self, locs, values) -> None:
# NB: This is a misnomer, is supposed to be inplace but is not,
# see GH#33457
assert locs.tolist() == [0]
Expand Down Expand Up @@ -1509,7 +1509,7 @@ def setitem(self, indexer, value):
# https://github.com/pandas-dev/pandas/issues/24020
# Need a dedicated setitem until GH#24020 (type promotion in setitem
# for extension arrays) is designed and implemented.
return self.astype(object).setitem(indexer, value)
return self.astype(_dtype_obj).setitem(indexer, value)

if isinstance(indexer, tuple):
# TODO(EA2D): not needed with 2D EAs
Expand Down Expand Up @@ -1547,7 +1547,7 @@ def take_nd(

return self.make_block_same_class(new_values, new_mgr_locs)

def _slice(self, slicer):
def _slice(self, slicer) -> ExtensionArray:
"""
Return a slice of my values.

Expand All @@ -1558,7 +1558,7 @@ def _slice(self, slicer):

Returns
-------
np.ndarray or ExtensionArray
ExtensionArray
"""
# return same dims as we currently have
if not isinstance(slicer, tuple) and self.ndim == 2:
Expand Down Expand Up @@ -1736,7 +1736,7 @@ def is_view(self) -> bool:
def setitem(self, indexer, value):
if not self._can_hold_element(value):
# TODO: general case needs casting logic.
return self.astype(object).setitem(indexer, value)
return self.astype(_dtype_obj).setitem(indexer, value)

values = self.values
if self.ndim > 1:
Expand All @@ -1750,7 +1750,7 @@ def putmask(self, mask, new) -> list[Block]:
mask = extract_bool_array(mask)

if not self._can_hold_element(new):
return self.astype(object).putmask(mask, new)
return self.astype(_dtype_obj).putmask(mask, new)

arr = self.values
arr.T.putmask(mask, new)
Expand Down Expand Up @@ -1808,7 +1808,7 @@ def fillna(
# We support filling a DatetimeTZ with a `value` whose timezone
# is different by coercing to object.
# TODO: don't special-case td64
return self.astype(object).fillna(value, limit, inplace, downcast)
return self.astype(_dtype_obj).fillna(value, limit, inplace, downcast)

values = self.values
values = values if inplace else values.copy()
Expand Down
3 changes: 2 additions & 1 deletion pandas/core/sorting.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
TYPE_CHECKING,
Callable,
DefaultDict,
Hashable,
Iterable,
Sequence,
)
Expand Down Expand Up @@ -576,7 +577,7 @@ def get_flattened_list(

def get_indexer_dict(
label_list: list[np.ndarray], keys: list[Index]
) -> dict[str | tuple, np.ndarray]:
) -> dict[Hashable, npt.NDArray[np.intp]]:
"""
Returns
-------
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/groupby/test_grouping.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ def test_groupby_empty(self):
)

tm.assert_numpy_array_equal(
gr.grouper.group_info[1], np.array([], dtype=np.dtype("int"))
gr.grouper.group_info[1], np.array([], dtype=np.dtype(np.intp))
)

assert gr.grouper.group_info[2] == 0
Expand Down