@@ -17,13 +17,15 @@ class providing the base-class of operations.
17
17
Callable ,
18
18
Dict ,
19
19
FrozenSet ,
20
+ Generic ,
20
21
Hashable ,
21
22
Iterable ,
22
23
List ,
23
24
Mapping ,
24
25
Optional ,
25
26
Tuple ,
26
27
Type ,
28
+ TypeVar ,
27
29
Union ,
28
30
)
29
31
@@ -376,13 +378,14 @@ def _group_selection_context(groupby):
376
378
]
377
379
378
380
379
- class _GroupBy (PandasObject , SelectionMixin ):
381
+ class _GroupBy (PandasObject , SelectionMixin , Generic [ FrameOrSeries ] ):
380
382
_group_selection = None
381
383
_apply_whitelist : FrozenSet [str ] = frozenset ()
384
+ obj : FrameOrSeries
382
385
383
386
def __init__ (
384
387
self ,
385
- obj : NDFrame ,
388
+ obj : FrameOrSeries ,
386
389
keys : Optional [_KeysArgType ] = None ,
387
390
axis : int = 0 ,
388
391
level = None ,
@@ -1079,7 +1082,11 @@ def _apply_filter(self, indices, dropna):
1079
1082
return filtered
1080
1083
1081
1084
1082
- class GroupBy (_GroupBy ):
1085
+ # We require another typevar to track operations that expand dimensions, like ohlc
1086
+ FrameOrSeries2 = TypeVar ("FrameOrSeries2" , bound = NDFrame )
1087
+
1088
+
1089
+ class GroupBy (_GroupBy [FrameOrSeries ]):
1083
1090
"""
1084
1091
Class for grouping and aggregating relational data.
1085
1092
@@ -1390,25 +1397,25 @@ def size(self):
1390
1397
return self ._reindex_output (result , fill_value = 0 )
1391
1398
1392
1399
@doc (_agg_template , fname = "sum" , no = True , mc = 0 )
1393
- def sum (self , numeric_only : bool = True , min_count : int = 0 ):
1400
+ def sum (self , numeric_only : bool = True , min_count : int = 0 ) -> FrameOrSeries :
1394
1401
return self ._agg_general (
1395
1402
numeric_only = numeric_only , min_count = min_count , alias = "add" , npfunc = np .sum
1396
1403
)
1397
1404
1398
1405
@doc (_agg_template , fname = "prod" , no = True , mc = 0 )
1399
- def prod (self , numeric_only : bool = True , min_count : int = 0 ):
1406
+ def prod (self , numeric_only : bool = True , min_count : int = 0 ) -> FrameOrSeries :
1400
1407
return self ._agg_general (
1401
1408
numeric_only = numeric_only , min_count = min_count , alias = "prod" , npfunc = np .prod
1402
1409
)
1403
1410
1404
1411
@doc (_agg_template , fname = "min" , no = False , mc = - 1 )
1405
- def min (self , numeric_only : bool = False , min_count : int = - 1 ):
1412
+ def min (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
1406
1413
return self ._agg_general (
1407
1414
numeric_only = numeric_only , min_count = min_count , alias = "min" , npfunc = np .min
1408
1415
)
1409
1416
1410
1417
@doc (_agg_template , fname = "max" , no = False , mc = - 1 )
1411
- def max (self , numeric_only : bool = False , min_count : int = - 1 ):
1418
+ def max (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
1412
1419
return self ._agg_general (
1413
1420
numeric_only = numeric_only , min_count = min_count , alias = "max" , npfunc = np .max
1414
1421
)
@@ -1431,7 +1438,7 @@ def get_loc_notna(x, loc: int):
1431
1438
return get_loc_notna (x , loc = loc )
1432
1439
1433
1440
@doc (_agg_template , fname = "first" , no = False , mc = - 1 )
1434
- def first (self , numeric_only : bool = False , min_count : int = - 1 ):
1441
+ def first (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
1435
1442
first_compat = partial (self ._get_loc , loc = 0 )
1436
1443
1437
1444
return self ._agg_general (
@@ -1441,8 +1448,7 @@ def first(self, numeric_only: bool = False, min_count: int = -1):
1441
1448
npfunc = first_compat ,
1442
1449
)
1443
1450
1444
- @doc (_agg_template , fname = "last" , no = False , mc = - 1 )
1445
- def last (self , numeric_only : bool = False , min_count : int = - 1 ):
1451
+ def last (self , numeric_only : bool = False , min_count : int = - 1 ) -> FrameOrSeries :
1446
1452
last_compat = partial (self ._get_loc , loc = - 1 )
1447
1453
1448
1454
return self ._agg_general (
@@ -2467,8 +2473,8 @@ def tail(self, n=5):
2467
2473
return self ._selected_obj [mask ]
2468
2474
2469
2475
def _reindex_output (
2470
- self , output : FrameOrSeries , fill_value : Scalar = np .NaN
2471
- ) -> FrameOrSeries :
2476
+ self , output : FrameOrSeries2 , fill_value : Scalar = np .NaN
2477
+ ) -> FrameOrSeries2 :
2472
2478
"""
2473
2479
If we have categorical groupers, then we might want to make sure that
2474
2480
we have a fully re-indexed output to the levels. This means expanding
0 commit comments