Skip to content

Commit c905b74

Browse files
committed
Cleanup
1 parent 1168ab7 commit c905b74

File tree

1 file changed

+34
-26
lines changed

1 file changed

+34
-26
lines changed

xarray/core/groupby.py

+34-26
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
from xarray.core.utils import Frozen
4545

4646
GroupKey = Any
47+
GroupIndex = int | slice | list[int]
4748

4849
T_GroupIndicesListInt = list[list[int]]
4950
T_GroupIndices = Union[T_GroupIndicesListInt, list[slice], np.ndarray]
@@ -129,11 +130,11 @@ def _dummy_copy(xarray_obj):
129130
return res
130131

131132

132-
def _is_one_or_none(obj):
133+
def _is_one_or_none(obj) -> bool:
133134
return obj == 1 or obj is None
134135

135136

136-
def _consolidate_slices(slices):
137+
def _consolidate_slices(slices: list[slice]) -> list[slice]:
137138
"""Consolidate adjacent slices in a list of slices."""
138139
result = []
139140
last_slice = slice(None)
@@ -191,7 +192,6 @@ def __init__(self, obj: T_Xarray, name: Hashable, coords) -> None:
191192
self.name = name
192193
self.coords = coords
193194
self.size = obj.sizes[name]
194-
self.dataarray = obj[name]
195195

196196
@property
197197
def dims(self) -> tuple[Hashable]:
@@ -228,6 +228,13 @@ def __getitem__(self, key):
228228
def copy(self, deep: bool = True, data: Any = None):
229229
raise NotImplementedError
230230

231+
def as_dataarray(self) -> DataArray:
232+
from xarray.core.dataarray import DataArray
233+
234+
return DataArray(
235+
data=self.data, dims=(self.name,), coords=self.coords, name=self.name
236+
)
237+
231238

232239
T_Group = TypeVar("T_Group", bound=Union["DataArray", "IndexVariable", _DummyGroup])
233240

@@ -294,14 +301,16 @@ def _apply_loffset(
294301

295302

296303
class Grouper:
297-
def __init__(self, group: T_Group):
298-
self.group : T_Group | None = group
299-
self.codes : np.ndarry | None = None
304+
def __init__(self, group: T_Group | Hashable):
305+
self.group: T_Group | Hashable = group
306+
300307
self.labels = None
301-
self.group_indices : list[list[int, ...]] | None= None
302-
self.unique_coord = None
303-
self.full_index : pd.Index | None = None
304-
self._group_as_index = None
308+
self._group_as_index: pd.Index | None = None
309+
310+
self.codes: DataArray
311+
self.group_indices: list[int] | list[slice] | list[list[int]]
312+
self.unique_coord: IndexVariable | _DummyGroup
313+
self.full_index: pd.Index
305314

306315
@property
307316
def name(self) -> Hashable:
@@ -334,10 +343,9 @@ def group_as_index(self) -> pd.Index:
334343
self._group_as_index = safe_cast_to_index(self.group1d)
335344
return self._group_as_index
336345

337-
def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
346+
def _resolve_group(self, obj: T_Xarray):
338347
from xarray.core.dataarray import DataArray
339348

340-
group: T_Group
341349
group = self.group
342350
if not isinstance(group, (DataArray, IndexVariable)):
343351
if not hashable(group):
@@ -346,15 +354,14 @@ def _resolve_group(self, obj: T_DataArray | T_Dataset) -> None:
346354
"name of an xarray variable or dimension. "
347355
f"Received {group!r} instead."
348356
)
349-
group_da : T_DataArray = obj[group]
350-
if len(group_da) == 0:
351-
raise ValueError(f"{group_da.name} must not be empty")
352-
353-
if group_da.name not in obj.coords and group_da.name in obj.dims:
357+
group = obj[group]
358+
if len(group) == 0:
359+
raise ValueError(f"{group.name} must not be empty")
360+
if group.name not in obj._indexes and group.name in obj.dims:
354361
# DummyGroups should not appear on groupby results
355362
group = _DummyGroup(obj, group.name, group.coords)
356363

357-
if getattr(group, "name", None) is None:
364+
elif getattr(group, "name", None) is None:
358365
group.name = "group"
359366

360367
self.group = group
@@ -408,10 +415,10 @@ def _factorize_dummy(self, squeeze) -> None:
408415
# equivalent to: group_indices = group_indices.reshape(-1, 1)
409416
self.group_indices = [slice(i, i + 1) for i in range(size)]
410417
else:
411-
self.group_indices = np.arange(size)
418+
self.group_indices = list(range(size))
412419
codes = np.arange(size)
413420
if isinstance(self.group, _DummyGroup):
414-
self.codes = self.group.dataarray.copy(data=codes)
421+
self.codes = self.group.as_dataarray().copy(data=codes)
415422
else:
416423
self.codes = self.group.copy(data=codes)
417424
self.unique_coord = self.group
@@ -489,7 +496,7 @@ def __init__(
489496
raise ValueError("index must be monotonic for resampling")
490497

491498
if isinstance(group_as_index, CFTimeIndex):
492-
self.grouper = CFTimeGrouper(
499+
grouper = CFTimeGrouper(
493500
freq=self.freq,
494501
closed=self.closed,
495502
label=self.label,
@@ -498,15 +505,16 @@ def __init__(
498505
loffset=self.loffset,
499506
)
500507
else:
501-
self.grouper = pd.Grouper(
508+
grouper = pd.Grouper(
502509
freq=self.freq,
503510
closed=self.closed,
504511
label=self.label,
505512
origin=self.origin,
506513
offset=self.offset,
507514
)
515+
self.grouper: CFTimeGrouper | pd.Grouper = grouper
508516

509-
def _get_index_and_items(self):
517+
def _get_index_and_items(self) -> tuple[pd.Index, pd.Series, np.ndarray]:
510518
first_items, codes = self.first_items()
511519
full_index = first_items.index
512520
if first_items.isnull().any():
@@ -515,7 +523,7 @@ def _get_index_and_items(self):
515523
full_index = full_index.rename("__resample_dim__")
516524
return full_index, first_items, codes
517525

518-
def first_items(self):
526+
def first_items(self) -> tuple[pd.Series, np.ndarray]:
519527
from xarray import CFTimeIndex
520528

521529
if isinstance(self.group_as_index, CFTimeIndex):
@@ -670,7 +678,7 @@ def reduce(
670678
raise NotImplementedError()
671679

672680
@property
673-
def groups(self) -> dict[GroupKey, slice | int | list[int]]:
681+
def groups(self) -> dict[GroupKey, GroupIndex]:
674682
"""
675683
Mapping from group labels to indices. The indices can be used to index the underlying object.
676684
"""
@@ -735,7 +743,7 @@ def _binary_op(self, other, f, reflexive=False):
735743
dims = group.dims
736744

737745
if isinstance(group, _DummyGroup):
738-
group = coord = group.dataarray
746+
group = coord = group.as_dataarray()
739747
else:
740748
coord = grouper.unique_coord
741749
if not isinstance(coord, DataArray):

0 commit comments

Comments
 (0)