|
10 | 10 | import collections
|
11 | 11 | import functools
|
12 | 12 | from typing import (
|
| 13 | + Callable, |
13 | 14 | Generic,
|
14 | 15 | Hashable,
|
15 | 16 | Iterator,
|
|
29 | 30 | from pandas._typing import (
|
30 | 31 | ArrayLike,
|
31 | 32 | DtypeObj,
|
32 |
| - F, |
33 | 33 | FrameOrSeries,
|
34 | 34 | Shape,
|
35 | 35 | npt,
|
@@ -700,7 +700,7 @@ def get_iterator(
|
700 | 700 | yield key, group.__finalize__(data, method="groupby")
|
701 | 701 |
|
702 | 702 | @final
|
703 |
| - def _get_splitter(self, data: FrameOrSeries, axis: int = 0) -> DataSplitter: |
| 703 | + def _get_splitter(self, data: NDFrame, axis: int = 0) -> DataSplitter: |
704 | 704 | """
|
705 | 705 | Returns
|
706 | 706 | -------
|
@@ -732,7 +732,9 @@ def group_keys_seq(self):
|
732 | 732 | return get_flattened_list(ids, ngroups, self.levels, self.codes)
|
733 | 733 |
|
734 | 734 | @final
|
735 |
| - def apply(self, f: F, data: FrameOrSeries, axis: int = 0) -> tuple[list, bool]: |
| 735 | + def apply( |
| 736 | + self, f: Callable, data: DataFrame | Series, axis: int = 0 |
| 737 | + ) -> tuple[list, bool]: |
736 | 738 | mutated = self.mutated
|
737 | 739 | splitter = self._get_splitter(data, axis=axis)
|
738 | 740 | group_keys = self.group_keys_seq
|
@@ -918,7 +920,7 @@ def _cython_operation(
|
918 | 920 |
|
919 | 921 | @final
|
920 | 922 | def agg_series(
|
921 |
| - self, obj: Series, func: F, preserve_dtype: bool = False |
| 923 | + self, obj: Series, func: Callable, preserve_dtype: bool = False |
922 | 924 | ) -> ArrayLike:
|
923 | 925 | """
|
924 | 926 | Parameters
|
@@ -960,7 +962,7 @@ def agg_series(
|
960 | 962 |
|
961 | 963 | @final
|
962 | 964 | def _aggregate_series_pure_python(
|
963 |
| - self, obj: Series, func: F |
| 965 | + self, obj: Series, func: Callable |
964 | 966 | ) -> npt.NDArray[np.object_]:
|
965 | 967 | ids, _, ngroups = self.group_info
|
966 | 968 |
|
@@ -1061,7 +1063,7 @@ def _get_grouper(self):
|
1061 | 1063 | """
|
1062 | 1064 | return self
|
1063 | 1065 |
|
1064 |
| - def get_iterator(self, data: FrameOrSeries, axis: int = 0): |
| 1066 | + def get_iterator(self, data: NDFrame, axis: int = 0): |
1065 | 1067 | """
|
1066 | 1068 | Groupby iterator
|
1067 | 1069 |
|
@@ -1142,7 +1144,7 @@ def groupings(self) -> list[grouper.Grouping]:
|
1142 | 1144 | ping = grouper.Grouping(lev, lev, in_axis=False, level=None)
|
1143 | 1145 | return [ping]
|
1144 | 1146 |
|
1145 |
| - def _aggregate_series_fast(self, obj: Series, func: F) -> np.ndarray: |
| 1147 | + def _aggregate_series_fast(self, obj: Series, func: Callable) -> np.ndarray: |
1146 | 1148 | # -> np.ndarray[object]
|
1147 | 1149 | raise NotImplementedError(
|
1148 | 1150 | "This should not be reached; use _aggregate_series_pure_python"
|
@@ -1241,7 +1243,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
|
1241 | 1243 |
|
1242 | 1244 |
|
1243 | 1245 | def get_splitter(
|
1244 |
| - data: FrameOrSeries, labels: np.ndarray, ngroups: int, axis: int = 0 |
| 1246 | + data: NDFrame, labels: np.ndarray, ngroups: int, axis: int = 0 |
1245 | 1247 | ) -> DataSplitter:
|
1246 | 1248 | if isinstance(data, Series):
|
1247 | 1249 | klass: type[DataSplitter] = SeriesSplitter
|
|
0 commit comments