Skip to content

Commit dfe958c

Browse files
authored
TYP: GroupBy (#43806)
1 parent 742ab04 commit dfe958c

File tree

6 files changed

+30
-23
lines changed

6 files changed

+30
-23
lines changed

pandas/_libs/lib.pyi

+2 -1
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@ from typing import (
55
Any,
66
Callable,
77
Generator,
8+
Hashable,
89
Literal,
910
overload,
1011
)
@@ -197,7 +198,7 @@ def indices_fast(
197198
labels: np.ndarray, # const int64_t[:]
198199
keys: list,
199200
sorted_labels: list[npt.NDArray[np.int64]],
200-
) -> dict: ...
201+
) -> dict[Hashable, npt.NDArray[np.intp]]: ...
201202
def generate_slices(
202203
labels: np.ndarray, ngroups: int # const intp_t[:]
203204
) -> tuple[npt.NDArray[np.int64], npt.NDArray[np.int64]]: ...

pandas/core/groupby/grouper.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
from pandas._typing import (
1818
ArrayLike,
1919
NDFrameT,
20+
npt,
2021
)
2122
from pandas.errors import InvalidIndexError
2223
from pandas.util._decorators import cache_readonly
@@ -604,7 +605,7 @@ def ngroups(self) -> int:
604605
return len(self.group_index)
605606

606607
@cache_readonly
607-
def indices(self):
608+
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
608609
# we have a list of groupers
609610
if isinstance(self.grouping_vector, ops.BaseGrouper):
610611
return self.grouping_vector.indices

pandas/core/groupby/ops.py

+12 -8
Original file line number | Diff line number | Diff line change
@@ -756,7 +756,7 @@ def apply(
756756
return result_values, mutated
757757

758758
@cache_readonly
759-
def indices(self):
759+
def indices(self) -> dict[Hashable, npt.NDArray[np.intp]]:
760760
"""dict {group name -> group indices}"""
761761
if len(self.groupings) == 1 and isinstance(self.result_index, CategoricalIndex):
762762
# This shows unused categories in indices GH#38642
@@ -807,7 +807,7 @@ def is_monotonic(self) -> bool:
807807
return Index(self.group_info[0]).is_monotonic
808808

809809
@cache_readonly
810-
def group_info(self):
810+
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
811811
comp_ids, obs_group_ids = self._get_compressed_codes()
812812

813813
ngroups = len(obs_group_ids)
@@ -817,22 +817,26 @@ def group_info(self):
817817

818818
@final
819819
@cache_readonly
820-
def codes_info(self) -> np.ndarray:
820+
def codes_info(self) -> npt.NDArray[np.intp]:
821821
# return the codes of items in original grouped axis
822822
ids, _, _ = self.group_info
823823
if self.indexer is not None:
824824
sorter = np.lexsort((ids, self.indexer))
825825
ids = ids[sorter]
826+
ids = ensure_platform_int(ids)
827+
# TODO: if numpy annotates np.lexsort, this ensure_platform_int
828+
# may become unnecessary
826829
return ids
827830

828831
@final
829-
def _get_compressed_codes(self) -> tuple[np.ndarray, np.ndarray]:
832+
def _get_compressed_codes(self) -> tuple[np.ndarray, npt.NDArray[np.intp]]:
833+
# The first returned ndarray may have any signed integer dtype
830834
if len(self.groupings) > 1:
831835
group_index = get_group_index(self.codes, self.shape, sort=True, xnull=True)
832836
return compress_group_index(group_index, sort=self._sort)
833837

834838
ping = self.groupings[0]
835-
return ping.codes, np.arange(len(ping.group_index))
839+
return ping.codes, np.arange(len(ping.group_index), dtype=np.intp)
836840

837841
@final
838842
@cache_readonly
@@ -1017,7 +1021,7 @@ class BinGrouper(BaseGrouper):
10171021
10181022
"""
10191023

1020-
bins: np.ndarray # np.ndarray[np.int64]
1024+
bins: npt.NDArray[np.int64]
10211025
binlabels: Index
10221026
mutated: bool
10231027

@@ -1101,9 +1105,9 @@ def indices(self):
11011105
return indices
11021106

11031107
@cache_readonly
1104-
def group_info(self):
1108+
def group_info(self) -> tuple[npt.NDArray[np.intp], npt.NDArray[np.intp], int]:
11051109
ngroups = self.ngroups
1106-
obs_group_ids = np.arange(ngroups, dtype=np.int64)
1110+
obs_group_ids = np.arange(ngroups, dtype=np.intp)
11071111
rep = np.diff(np.r_[0, self.bins])
11081112

11091113
rep = ensure_platform_int(rep)

pandas/core/internals/blocks.py

+11 -11
Original file line number | Diff line number | Diff line change
@@ -295,7 +295,7 @@ def __repr__(self) -> str:
295295
def __len__(self) -> int:
296296
return len(self.values)
297297

298-
def _slice(self, slicer):
298+
def _slice(self, slicer) -> ArrayLike:
299299
"""return a slice of my values"""
300300

301301
return self.values[slicer]
@@ -344,7 +344,7 @@ def dtype(self) -> DtypeObj:
344344
def iget(self, i):
345345
return self.values[i]
346346

347-
def set_inplace(self, locs, values):
347+
def set_inplace(self, locs, values) -> None:
348348
"""
349349
Modify block values in-place with new item value.
350350
@@ -563,13 +563,13 @@ def _downcast_2d(self) -> list[Block]:
563563
return [self.make_block(new_values)]
564564

565565
@final
566-
def astype(self, dtype, copy: bool = False, errors: str = "raise"):
566+
def astype(self, dtype: DtypeObj, copy: bool = False, errors: str = "raise"):
567567
"""
568568
Coerce to the new dtype.
569569
570570
Parameters
571571
----------
572-
dtype : str, dtype convertible
572+
dtype : np.dtype or ExtensionDtype
573573
copy : bool, default False
574574
copy if indicated
575575
errors : str, {'raise', 'ignore'}, default 'raise'
@@ -1441,7 +1441,7 @@ def iget(self, col):
14411441
raise IndexError(f"{self} only contains one item")
14421442
return self.values
14431443

1444-
def set_inplace(self, locs, values):
1444+
def set_inplace(self, locs, values) -> None:
14451445
# NB: This is a misnomer, is supposed to be inplace but is not,
14461446
# see GH#33457
14471447
assert locs.tolist() == [0]
@@ -1509,7 +1509,7 @@ def setitem(self, indexer, value):
15091509
# https://github.com/pandas-dev/pandas/issues/24020
15101510
# Need a dedicated setitem until GH#24020 (type promotion in setitem
15111511
# for extension arrays) is designed and implemented.
1512-
return self.astype(object).setitem(indexer, value)
1512+
return self.astype(_dtype_obj).setitem(indexer, value)
15131513

15141514
if isinstance(indexer, tuple):
15151515
# TODO(EA2D): not needed with 2D EAs
@@ -1547,7 +1547,7 @@ def take_nd(
15471547

15481548
return self.make_block_same_class(new_values, new_mgr_locs)
15491549

1550-
def _slice(self, slicer):
1550+
def _slice(self, slicer) -> ExtensionArray:
15511551
"""
15521552
Return a slice of my values.
15531553
@@ -1558,7 +1558,7 @@ def _slice(self, slicer):
15581558
15591559
Returns
15601560
-------
1561-
np.ndarray or ExtensionArray
1561+
ExtensionArray
15621562
"""
15631563
# return same dims as we currently have
15641564
if not isinstance(slicer, tuple) and self.ndim == 2:
@@ -1736,7 +1736,7 @@ def is_view(self) -> bool:
17361736
def setitem(self, indexer, value):
17371737
if not self._can_hold_element(value):
17381738
# TODO: general case needs casting logic.
1739-
return self.astype(object).setitem(indexer, value)
1739+
return self.astype(_dtype_obj).setitem(indexer, value)
17401740

17411741
values = self.values
17421742
if self.ndim > 1:
@@ -1750,7 +1750,7 @@ def putmask(self, mask, new) -> list[Block]:
17501750
mask = extract_bool_array(mask)
17511751

17521752
if not self._can_hold_element(new):
1753-
return self.astype(object).putmask(mask, new)
1753+
return self.astype(_dtype_obj).putmask(mask, new)
17541754

17551755
arr = self.values
17561756
arr.T.putmask(mask, new)
@@ -1808,7 +1808,7 @@ def fillna(
18081808
# We support filling a DatetimeTZ with a `value` whose timezone
18091809
# is different by coercing to object.
18101810
# TODO: don't special-case td64
1811-
return self.astype(object).fillna(value, limit, inplace, downcast)
1811+
return self.astype(_dtype_obj).fillna(value, limit, inplace, downcast)
18121812

18131813
values = self.values
18141814
values = values if inplace else values.copy()

pandas/core/sorting.py

+2 -1
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
TYPE_CHECKING,
77
Callable,
88
DefaultDict,
9+
Hashable,
910
Iterable,
1011
Sequence,
1112
)
@@ -576,7 +577,7 @@ def get_flattened_list(
576577

577578
def get_indexer_dict(
578579
label_list: list[np.ndarray], keys: list[Index]
579-
) -> dict[str | tuple, np.ndarray]:
580+
) -> dict[Hashable, npt.NDArray[np.intp]]:
580581
"""
581582
Returns
582583
-------

pandas/tests/groupby/test_grouping.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -688,7 +688,7 @@ def test_groupby_empty(self):
688688
)
689689

690690
tm.assert_numpy_array_equal(
691-
gr.grouper.group_info[1], np.array([], dtype=np.dtype("int"))
691+
gr.grouper.group_info[1], np.array([], dtype=np.dtype(np.intp))
692692
)
693693

694694
assert gr.grouper.group_info[2] == 0

0 commit comments

Comments
 (0)