44
44
from xarray .core .utils import Frozen
45
45
46
46
GroupKey = Any
47
+ GroupIndex = int | slice | list [int ]
47
48
48
49
T_GroupIndicesListInt = list [list [int ]]
49
50
T_GroupIndices = Union [T_GroupIndicesListInt , list [slice ], np .ndarray ]
@@ -129,11 +130,11 @@ def _dummy_copy(xarray_obj):
129
130
return res
130
131
131
132
132
- def _is_one_or_none (obj ):
133
+ def _is_one_or_none (obj ) -> bool :
133
134
return obj == 1 or obj is None
134
135
135
136
136
- def _consolidate_slices (slices ) :
137
+ def _consolidate_slices (slices : list [ slice ]) -> list [ slice ] :
137
138
"""Consolidate adjacent slices in a list of slices."""
138
139
result = []
139
140
last_slice = slice (None )
@@ -191,7 +192,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
191
192
self .name = name
192
193
self .coords = coords
193
194
self .size = obj .sizes [name ]
194
- self .dataarray = obj [name ]
195
195
196
196
@property
197
197
def dims (self ) -> tuple [Hashable ]:
@@ -228,6 +228,13 @@ def __getitem__(self, key):
228
228
def copy (self , deep : bool = True , data : Any = None ):
229
229
raise NotImplementedError
230
230
231
+ def as_dataarray (self ) -> DataArray :
232
+ from xarray .core .dataarray import DataArray
233
+
234
+ return DataArray (
235
+ data = self .data , dims = (self .name ,), coords = self .coords , name = self .name
236
+ )
237
+
231
238
232
239
T_Group = TypeVar ("T_Group" , bound = Union ["DataArray" , "IndexVariable" , _DummyGroup ])
233
240
@@ -294,14 +301,16 @@ def _apply_loffset(
294
301
295
302
296
303
class Grouper :
297
- def __init__ (self , group : T_Group ):
298
- self .group : T_Group | None = group
299
- self . codes : np . ndarry | None = None
304
+ def __init__ (self , group : T_Group | Hashable ):
305
+ self .group : T_Group | Hashable = group
306
+
300
307
self .labels = None
301
- self .group_indices : list [list [int , ...]] | None = None
302
- self .unique_coord = None
303
- self .full_index : pd .Index | None = None
304
- self ._group_as_index = None
308
+ self ._group_as_index : pd .Index | None = None
309
+
310
+ self .codes : DataArray
311
+ self .group_indices : list [int ] | list [slice ] | list [list [int ]]
312
+ self .unique_coord : IndexVariable | _DummyGroup
313
+ self .full_index : pd .Index
305
314
306
315
@property
307
316
def name (self ) -> Hashable :
@@ -334,10 +343,9 @@ def group_as_index(self) -> pd.Index:
334
343
self ._group_as_index = safe_cast_to_index (self .group1d )
335
344
return self ._group_as_index
336
345
337
- def _resolve_group (self , obj : T_DataArray | T_Dataset ) -> None :
346
+ def _resolve_group (self , obj : T_Xarray ) :
338
347
from xarray .core .dataarray import DataArray
339
348
340
- group : T_Group
341
349
group = self .group
342
350
if not isinstance (group , (DataArray , IndexVariable )):
343
351
if not hashable (group ):
@@ -346,15 +354,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
346
354
"name of an xarray variable or dimension. "
347
355
f"Received { group !r} instead."
348
356
)
349
- group_da : T_DataArray = obj [group ]
350
- if len (group_da ) == 0 :
351
- raise ValueError (f"{ group_da .name } must not be empty" )
352
-
353
- if group_da .name not in obj .coords and group_da .name in obj .dims :
357
+ group = obj [group ]
358
+ if len (group ) == 0 :
359
+ raise ValueError (f"{ group .name } must not be empty" )
360
+ if group .name not in obj ._indexes and group .name in obj .dims :
354
361
# DummyGroups should not appear on groupby results
355
362
group = _DummyGroup (obj , group .name , group .coords )
356
363
357
- if getattr (group , "name" , None ) is None :
364
+ elif getattr (group , "name" , None ) is None :
358
365
group .name = "group"
359
366
360
367
self .group = group
@@ -408,10 +415,10 @@ def _factorize_dummy(self, squeeze) -> None:
408
415
# equivalent to: group_indices = group_indices.reshape(-1, 1)
409
416
self .group_indices = [slice (i , i + 1 ) for i in range (size )]
410
417
else :
411
- self .group_indices = np . arange ( size )
418
+ self .group_indices = list ( range ( size ) )
412
419
codes = np .arange (size )
413
420
if isinstance (self .group , _DummyGroup ):
414
- self .codes = self .group .dataarray .copy (data = codes )
421
+ self .codes = self .group .as_dataarray () .copy (data = codes )
415
422
else :
416
423
self .codes = self .group .copy (data = codes )
417
424
self .unique_coord = self .group
@@ -489,7 +496,7 @@ def __init__(
489
496
raise ValueError ("index must be monotonic for resampling" )
490
497
491
498
if isinstance (group_as_index , CFTimeIndex ):
492
- self . grouper = CFTimeGrouper (
499
+ grouper = CFTimeGrouper (
493
500
freq = self .freq ,
494
501
closed = self .closed ,
495
502
label = self .label ,
@@ -498,15 +505,16 @@ def __init__(
498
505
loffset = self .loffset ,
499
506
)
500
507
else :
501
- self . grouper = pd .Grouper (
508
+ grouper = pd .Grouper (
502
509
freq = self .freq ,
503
510
closed = self .closed ,
504
511
label = self .label ,
505
512
origin = self .origin ,
506
513
offset = self .offset ,
507
514
)
515
+ self .grouper : CFTimeGrouper | pd .Grouper = grouper
508
516
509
- def _get_index_and_items (self ):
517
+ def _get_index_and_items (self ) -> tuple [ pd . Index , pd . Series , np . ndarray ] :
510
518
first_items , codes = self .first_items ()
511
519
full_index = first_items .index
512
520
if first_items .isnull ().any ():
@@ -515,7 +523,7 @@ def _get_index_and_items(self):
515
523
full_index = full_index .rename ("__resample_dim__" )
516
524
return full_index , first_items , codes
517
525
518
- def first_items (self ):
526
+ def first_items (self ) -> tuple [ pd . Series , np . ndarray ] :
519
527
from xarray import CFTimeIndex
520
528
521
529
if isinstance (self .group_as_index , CFTimeIndex ):
@@ -670,7 +678,7 @@ def reduce(
670
678
raise NotImplementedError ()
671
679
672
680
@property
673
- def groups (self ) -> dict [GroupKey , slice | int | list [ int ] ]:
681
+ def groups (self ) -> dict [GroupKey , GroupIndex ]:
674
682
"""
675
683
Mapping from group labels to indices. The indices can be used to index the underlying object.
676
684
"""
@@ -735,7 +743,7 @@ def _binary_op(self, other, f, reflexive=False):
735
743
dims = group .dims
736
744
737
745
if isinstance (group , _DummyGroup ):
738
- group = coord = group .dataarray
746
+ group = coord = group .as_dataarray ()
739
747
else :
740
748
coord = grouper .unique_coord
741
749
if not isinstance (coord , DataArray ):
0 commit comments