Skip to content

Commit ca84a1e

Browse files
dcherianabrammerheadtr1ck
authored
Fix groupby binary ops when grouped array is subset relative to other (#7798)
* Fix groupby binary ops when grouped array is subset relative to other Closes #7797 * Fix tests Co-authored-by: Alan Brammer <[email protected]> Co-authored-by: Mick <[email protected]> * fix doc build * [skip-ci] Update doc/whats-new.rst --------- Co-authored-by: Alan Brammer <[email protected]> Co-authored-by: Mick <[email protected]>
1 parent 6d17fa0 commit ca84a1e

File tree

3 files changed

+53
-4
lines changed

3 files changed

+53
-4
lines changed

doc/whats-new.rst

+3-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ New Features
2424
~~~~~~~~~~~~
2525
- Added new method :py:meth:`DataArray.to_dask_dataframe`, convert a dataarray into a dask dataframe (:issue:`7409`).
2626
By `Deeksha <https://github.com/dsgreen2>`_.
27-
- Add support for lshift and rshift binary operators (`<<`, `>>`) on
27+
- Add support for lshift and rshift binary operators (``<<``, ``>>``) on
2828
:py:class:`xr.DataArray` of type :py:class:`int` (:issue:`7727` , :pull:`7741`).
2929
By `Alan Brammer <https://github.com/abrammer>`_.
3030

@@ -40,7 +40,8 @@ Deprecations
4040

4141
Bug fixes
4242
~~~~~~~~~
43-
43+
- Fix groupby binary ops when grouped array is subset relative to other. (:issue:`7797`).
44+
By `Deepak Cherian <https://github.com/dcherian>`_.
4445

4546
Documentation
4647
~~~~~~~~~~~~~

xarray/core/groupby.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -669,7 +669,9 @@ def _binary_op(self, other, f, reflexive=False):
669669
obj = obj.where(~mask, drop=True)
670670
codes = codes.where(~mask, drop=True).astype(int)
671671

672-
other, _ = align(other, coord, join="outer")
672+
# codes are defined for coord, so we align `other` with `coord`
673+
# before indexing
674+
other, _ = align(other, coord, join="right")
673675
expanded = other.isel({name: codes})
674676

675677
result = g(obj, expanded)

xarray/tests/test_groupby.py

+47-1
Original file line numberDiff line numberDiff line change
@@ -829,11 +829,33 @@ def test_groupby_math_bitshift() -> None:
829829
}
830830
)
831831

832+
left_manual = []
833+
for lev, group in ds.groupby("level"):
834+
shifter = shift.sel(level=lev)
835+
left_manual.append(group << shifter)
836+
left_actual = xr.concat(left_manual, dim="index").reset_coords(names="level")
837+
assert_equal(left_expected, left_actual)
838+
832839
left_actual = (ds.groupby("level") << shift).reset_coords(names="level")
833840
assert_equal(left_expected, left_actual)
834841

842+
right_expected = Dataset(
843+
{
844+
"x": ("index", [0, 0, 2, 2]),
845+
"y": ("index", [-1, -1, -2, -2]),
846+
"level": ("index", [0, 0, 4, 4]),
847+
"index": [0, 1, 2, 3],
848+
}
849+
)
850+
right_manual = []
851+
for lev, group in left_expected.groupby("level"):
852+
shifter = shift.sel(level=lev)
853+
right_manual.append(group >> shifter)
854+
right_actual = xr.concat(right_manual, dim="index").reset_coords(names="level")
855+
assert_equal(right_expected, right_actual)
856+
835857
right_actual = (left_expected.groupby("level") >> shift).reset_coords(names="level")
836-
assert_equal(ds, right_actual)
858+
assert_equal(right_expected, right_actual)
837859

838860

839861
@pytest.mark.parametrize("use_flox", [True, False])
@@ -1302,8 +1324,15 @@ def test_groupby_math_not_aligned(self):
13021324
expected = DataArray([10, 11, np.nan, np.nan], array.coords)
13031325
assert_identical(expected, actual)
13041326

1327+
# regression test for #7797
1328+
other = array.groupby("b").sum()
1329+
actual = array.sel(x=[0, 1]).groupby("b") - other
1330+
expected = DataArray([-1, 0], {"b": ("x", [0, 0]), "x": [0, 1]}, dims="x")
1331+
assert_identical(expected, actual)
1332+
13051333
other = DataArray([10], coords={"c": 123, "b": [0]}, dims="b")
13061334
actual = array.groupby("b") + other
1335+
expected = DataArray([10, 11, np.nan, np.nan], array.coords)
13071336
expected.coords["c"] = (["x"], [123] * 2 + [np.nan] * 2)
13081337
assert_identical(expected, actual)
13091338

@@ -2289,3 +2318,20 @@ def test_resample_cumsum(method: str, expected_array: list[float]) -> None:
22892318
actual = getattr(ds.foo.resample(time="3M"), method)(dim="time")
22902319
expected.coords["time"] = ds.time
22912320
assert_identical(expected.drop_vars(["time"]).foo, actual)
2321+
2322+
2323+
def test_groupby_binary_op_regression() -> None:
2324+
# regression test for #7797
2325+
# monthly timeseries that should return "zero anomalies" everywhere
2326+
time = xr.date_range("2023-01-01", "2023-12-31", freq="MS")
2327+
data = np.linspace(-1, 1, 12)
2328+
x = xr.DataArray(data, coords={"time": time})
2329+
clim = xr.DataArray(data, coords={"month": np.arange(1, 13, 1)})
2330+
2331+
# seems to give the correct result if we use the full x, but not with a slice
2332+
x_slice = x.sel(time=["2023-04-01"])
2333+
2334+
# two typical ways of computing anomalies
2335+
anom_gb = x_slice.groupby("time.month") - clim
2336+
2337+
assert_identical(xr.zeros_like(anom_gb), anom_gb)

0 commit comments

Comments
 (0)