Skip to content

Commit fde773e

Browse files
dcherianIllviljanpre-commit-ci[bot]
authored
Introduce Grouper objects internally (#7561)
* Introduce Grouper objects. * Remove a copy after stacking for a groupby. Upstream bug pandas-dev/pandas#12813 is fixed * Fix typing * [WIP] typing * Cleanup * [WIP] * group as Variable? * Revert "group as Variable?" This reverts commit 2a36e21a031b9e061b932682758551956f3f06d2. * Small cleanup * De-duplicate alignment check * Fix resampling * Bugfix * Partial reverts commit 22ad7fa. * fix tests * small cleanup * more cleanup * Apply suggestions from code review * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add ResolvedGrouper class * GroupBy only handles ResolvedGrouper objects. Much cleaner! * review feedback * minimize diff * dataclass * moar dataclass Co-authored-by: Illviljan <[email protected]> * Add typing * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Ignore type checking error. * Update groupby.py * Move factorize to _factorize * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update groupby.py * Update xarray/core/groupby.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Calculate group_indices only when necessary * Revert "Calculate group_indices only when necessary" This reverts commit 917c77efb05bacffcf901e61eabb9defc9a429d7. * Fix regression from deep copy --------- Co-authored-by: Illviljan <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent da8746b commit fde773e

File tree

7 files changed

+550
-354
lines changed

7 files changed

+550
-354
lines changed

xarray/core/common.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -949,7 +949,7 @@ def _resample(
949949
# TODO support non-string indexer after removing the old API.
950950

951951
from xarray.core.dataarray import DataArray
952-
from xarray.core.groupby import TimeResampleGrouper
952+
from xarray.core.groupby import ResolvedTimeResampleGrouper, TimeResampleGrouper
953953
from xarray.core.resample import RESAMPLE_DIM
954954

955955
if keep_attrs is not None:
@@ -1012,11 +1012,13 @@ def _resample(
10121012
group = DataArray(
10131013
dim_coord, coords=dim_coord.coords, dims=dim_coord.dims, name=RESAMPLE_DIM
10141014
)
1015+
1016+
rgrouper = ResolvedTimeResampleGrouper(grouper, group, self)
1017+
10151018
return resample_cls(
10161019
self,
1017-
group=group,
1020+
(rgrouper,),
10181021
dim=dim_name,
1019-
grouper=grouper,
10201022
resample_dim=RESAMPLE_DIM,
10211023
restore_coord_dims=restore_coord_dims,
10221024
)

xarray/core/computation.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -515,15 +515,16 @@ def apply_groupby_func(func, *args):
515515
groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
516516
assert groupbys, "must have at least one groupby to iterate over"
517517
first_groupby = groupbys[0]
518-
if any(not first_groupby._group.equals(gb._group) for gb in groupbys[1:]):
518+
(grouper,) = first_groupby.groupers
519+
if any(not grouper.group.equals(gb.groupers[0].group) for gb in groupbys[1:]):
519520
raise ValueError(
520521
"apply_ufunc can only perform operations over "
521522
"multiple GroupBy objects at once if they are all "
522523
"grouped the same way"
523524
)
524525

525-
grouped_dim = first_groupby._group.name
526-
unique_values = first_groupby._unique_coord.values
526+
grouped_dim = grouper.name
527+
unique_values = grouper.unique_coord.values
527528

528529
iterators = []
529530
for arg in args:

xarray/core/dataarray.py

+28-19
Original file line numberDiff line numberDiff line change
@@ -6478,21 +6478,20 @@ def groupby(
64786478
core.groupby.DataArrayGroupBy
64796479
pandas.DataFrame.groupby
64806480
"""
6481-
from xarray.core.groupby import DataArrayGroupBy
6482-
6483-
# While we don't generally check the type of every arg, passing
6484-
# multiple dimensions as multiple arguments is common enough, and the
6485-
# consequences hidden enough (strings evaluate as true) to warrant
6486-
# checking here.
6487-
# A future version could make squeeze kwarg only, but would face
6488-
# backward-compat issues.
6489-
if not isinstance(squeeze, bool):
6490-
raise TypeError(
6491-
f"`squeeze` must be True or False, but {squeeze} was supplied"
6492-
)
6481+
from xarray.core.groupby import (
6482+
DataArrayGroupBy,
6483+
ResolvedUniqueGrouper,
6484+
UniqueGrouper,
6485+
_validate_groupby_squeeze,
6486+
)
64936487

6488+
_validate_groupby_squeeze(squeeze)
6489+
rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self)
64946490
return DataArrayGroupBy(
6495-
self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
6491+
self,
6492+
(rgrouper,),
6493+
squeeze=squeeze,
6494+
restore_coord_dims=restore_coord_dims,
64966495
)
64976496

64986497
def groupby_bins(
@@ -6563,21 +6562,31 @@ def groupby_bins(
65636562
----------
65646563
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
65656564
"""
6566-
from xarray.core.groupby import DataArrayGroupBy
6565+
from xarray.core.groupby import (
6566+
BinGrouper,
6567+
DataArrayGroupBy,
6568+
ResolvedBinGrouper,
6569+
_validate_groupby_squeeze,
6570+
)
65676571

6568-
return DataArrayGroupBy(
6569-
self,
6570-
group,
6571-
squeeze=squeeze,
6572+
_validate_groupby_squeeze(squeeze)
6573+
grouper = BinGrouper(
65726574
bins=bins,
6573-
restore_coord_dims=restore_coord_dims,
65746575
cut_kwargs={
65756576
"right": right,
65766577
"labels": labels,
65776578
"precision": precision,
65786579
"include_lowest": include_lowest,
65796580
},
65806581
)
6582+
rgrouper = ResolvedBinGrouper(grouper, group, self)
6583+
6584+
return DataArrayGroupBy(
6585+
self,
6586+
(rgrouper,),
6587+
squeeze=squeeze,
6588+
restore_coord_dims=restore_coord_dims,
6589+
)
65816590

65826591
def weighted(self, weights: DataArray) -> DataArrayWeighted:
65836592
"""

xarray/core/dataset.py

+29-19
Original file line numberDiff line numberDiff line change
@@ -8958,21 +8958,21 @@ def groupby(
89588958
Dataset.resample
89598959
DataArray.resample
89608960
"""
8961-
from xarray.core.groupby import DatasetGroupBy
8962-
8963-
# While we don't generally check the type of every arg, passing
8964-
# multiple dimensions as multiple arguments is common enough, and the
8965-
# consequences hidden enough (strings evaluate as true) to warrant
8966-
# checking here.
8967-
# A future version could make squeeze kwarg only, but would face
8968-
# backward-compat issues.
8969-
if not isinstance(squeeze, bool):
8970-
raise TypeError(
8971-
f"`squeeze` must be True or False, but {squeeze} was supplied"
8972-
)
8961+
from xarray.core.groupby import (
8962+
DatasetGroupBy,
8963+
ResolvedUniqueGrouper,
8964+
UniqueGrouper,
8965+
_validate_groupby_squeeze,
8966+
)
8967+
8968+
_validate_groupby_squeeze(squeeze)
8969+
rgrouper = ResolvedUniqueGrouper(UniqueGrouper(), group, self)
89738970

89748971
return DatasetGroupBy(
8975-
self, group, squeeze=squeeze, restore_coord_dims=restore_coord_dims
8972+
self,
8973+
(rgrouper,),
8974+
squeeze=squeeze,
8975+
restore_coord_dims=restore_coord_dims,
89768976
)
89778977

89788978
def groupby_bins(
@@ -9043,21 +9043,31 @@ def groupby_bins(
90439043
----------
90449044
.. [1] http://pandas.pydata.org/pandas-docs/stable/generated/pandas.cut.html
90459045
"""
9046-
from xarray.core.groupby import DatasetGroupBy
9046+
from xarray.core.groupby import (
9047+
BinGrouper,
9048+
DatasetGroupBy,
9049+
ResolvedBinGrouper,
9050+
_validate_groupby_squeeze,
9051+
)
90479052

9048-
return DatasetGroupBy(
9049-
self,
9050-
group,
9051-
squeeze=squeeze,
9053+
_validate_groupby_squeeze(squeeze)
9054+
grouper = BinGrouper(
90529055
bins=bins,
9053-
restore_coord_dims=restore_coord_dims,
90549056
cut_kwargs={
90559057
"right": right,
90569058
"labels": labels,
90579059
"precision": precision,
90589060
"include_lowest": include_lowest,
90599061
},
90609062
)
9063+
rgrouper = ResolvedBinGrouper(grouper, group, self)
9064+
9065+
return DatasetGroupBy(
9066+
self,
9067+
(rgrouper,),
9068+
squeeze=squeeze,
9069+
restore_coord_dims=restore_coord_dims,
9070+
)
90619071

90629072
def weighted(self, weights: DataArray) -> DatasetWeighted:
90639073
"""

0 commit comments

Comments
 (0)