7
7
"""
8
8
9
9
import collections
10
- from typing import Any , Dict , List , Optional , Sequence , Tuple , Type
10
+ from typing import Any , Dict , Generic , List , Optional , Sequence , Tuple , Type
11
11
12
12
import numpy as np
13
13
41
41
from pandas .core .base import SelectionMixin
42
42
import pandas .core .common as com
43
43
from pandas .core .frame import DataFrame
44
- from pandas .core .generic import NDFrame
45
44
from pandas .core .groupby import base , grouper
46
45
from pandas .core .index import Index , MultiIndex , ensure_index
47
46
from pandas .core .series import Series
@@ -861,7 +860,7 @@ def _is_indexed_like(obj, axes) -> bool:
861
860
# Splitting / application
862
861
863
862
864
- class DataSplitter :
863
+ class DataSplitter ( Generic [ FrameOrSeries ]) :
865
864
def __init__ (self , data : FrameOrSeries , labels , ngroups : int , axis : int = 0 ):
866
865
self .data = data
867
866
self .labels = ensure_int64 (labels )
@@ -896,7 +895,7 @@ def __iter__(self):
896
895
def _get_sorted_data (self ) -> FrameOrSeries :
897
896
return self .data .take (self .sort_idx , axis = self .axis )
898
897
899
- def _chop (self , sdata , slice_obj : slice ) -> NDFrame :
898
+ def _chop (self , sdata : FrameOrSeries , slice_obj : slice ) -> FrameOrSeries :
900
899
raise AbstractMethodError (self )
901
900
902
901
@@ -920,7 +919,7 @@ def _chop(self, sdata: DataFrame, slice_obj: slice) -> DataFrame:
920
919
return sdata ._slice (slice_obj , axis = 1 )
921
920
922
921
923
- def get_splitter (data : FrameOrSeries , * args , ** kwargs ) -> DataSplitter :
922
+ def get_splitter (data : FrameOrSeries , * args , ** kwargs ) -> " DataSplitter[FrameOrSeries]" :
924
923
klass : Type [DataSplitter ]
925
924
if isinstance (data , Series ):
926
925
klass = SeriesSplitter
0 commit comments