@@ -61,8 +61,7 @@ class BaseGrouper:
61
61
62
62
Parameters
63
63
----------
64
- axis : int
65
- the axis to group
64
+ axis : Index
66
65
groupings : array of grouping
67
66
all the grouping instances to handle in this grouper
68
67
for example for grouper list to groupby, need to pass the list
@@ -78,8 +77,15 @@ class BaseGrouper:
78
77
"""
79
78
80
79
def __init__ (
81
- self , axis , groupings , sort = True , group_keys = True , mutated = False , indexer = None
80
+ self ,
81
+ axis : Index ,
82
+ groupings ,
83
+ sort = True ,
84
+ group_keys = True ,
85
+ mutated = False ,
86
+ indexer = None ,
82
87
):
88
+ assert isinstance (axis , Index ), axis
83
89
self ._filter_empty_groups = self .compressed = len (groupings ) != 1
84
90
self .axis = axis
85
91
self .groupings = groupings
@@ -623,7 +629,7 @@ def _aggregate_series_pure_python(self, obj, func):
623
629
counts = np .zeros (ngroups , dtype = int )
624
630
result = None
625
631
626
- splitter = get_splitter (obj , group_index , ngroups , axis = self . axis )
632
+ splitter = get_splitter (obj , group_index , ngroups , axis = 0 )
627
633
628
634
for label , group in splitter :
629
635
res = func (group )
@@ -635,8 +641,12 @@ def _aggregate_series_pure_python(self, obj, func):
635
641
counts [label ] = group .shape [0 ]
636
642
result [label ] = res
637
643
638
- result = lib .maybe_convert_objects (result , try_float = 0 )
639
- # TODO: try_cast back to EA?
644
+ if result is not None :
645
+ # if splitter is empty, result can be None, in which case
646
+ # maybe_convert_objects would raise TypeError
647
+ result = lib .maybe_convert_objects (result , try_float = 0 )
648
+ # TODO: try_cast back to EA?
649
+
640
650
return result , counts
641
651
642
652
@@ -781,6 +791,11 @@ def groupings(self):
781
791
]
782
792
783
793
def agg_series (self , obj : Series , func ):
794
+ if is_extension_array_dtype (obj .dtype ):
795
+ # pre-empty SeriesBinGrouper from raising TypeError
796
+ # TODO: watch out, this can return None
797
+ return self ._aggregate_series_pure_python (obj , func )
798
+
784
799
dummy = obj [:0 ]
785
800
grouper = libreduction .SeriesBinGrouper (obj , func , self .bins , dummy )
786
801
return grouper .get_result ()
@@ -809,12 +824,13 @@ def _is_indexed_like(obj, axes) -> bool:
809
824
810
825
811
826
class DataSplitter :
812
- def __init__ (self , data , labels , ngroups , axis = 0 ):
827
+ def __init__ (self , data , labels , ngroups , axis : int = 0 ):
813
828
self .data = data
814
829
self .labels = ensure_int64 (labels )
815
830
self .ngroups = ngroups
816
831
817
832
self .axis = axis
833
+ assert isinstance (axis , int ), axis
818
834
819
835
@cache_readonly
820
836
def slabels (self ):
@@ -837,12 +853,6 @@ def __iter__(self):
837
853
starts , ends = lib .generate_slices (self .slabels , self .ngroups )
838
854
839
855
for i , (start , end ) in enumerate (zip (starts , ends )):
840
- # Since I'm now compressing the group ids, it's now not "possible"
841
- # to produce empty slices because such groups would not be observed
842
- # in the data
843
- # if start >= end:
844
- # raise AssertionError('Start %s must be less than end %s'
845
- # % (str(start), str(end)))
846
856
yield i , self ._chop (sdata , slice (start , end ))
847
857
848
858
def _get_sorted_data (self ):
0 commit comments