BUG/TST: ensure groupby.agg preserves extension dtype #29144

Merged
4 changes: 2 additions & 2 deletions pandas/core/groupby/ops.py
@@ -655,15 +655,15 @@ def agg_series(self, obj, func):
            return self._aggregate_series_fast(obj, func)
        except AssertionError:
            raise
-       except ValueError as err:
+       except (ValueError, AttributeError) as err:
Contributor: what is the AttributeError in reference?

Member Author: See issue #29141. Now that #29100 has been merged, catching AttributeError may no longer be needed for our tests to pass.

Member: can you confirm whether #29100 makes this change unnecessary?

Member Author: It only removes the need to catch an AttributeError (I removed that catch in the last commit); it does not make the tests pass (as I mentioned yesterday).

            if "No result." in str(err):
                # raised in libreduction
                pass
            elif "Function does not reduce" in str(err):
                # raised in libreduction
Member Author: @jbrockmendel is it the intention to keep this long term, or is this also planned to be cleaned up in a follow-up?

Member: cleaned

Member Author: And what's the idea of how to clean this up? (We could also raise a more specific internal error (subclass) to catch.)

Member:

> (we could also raise a more specific internal error (subclass) to catch)

That's what I'm leaning towards, yeah

                pass
            else:
                raise
            pass
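The "more specific internal error" idea floated above could look roughly like the sketch below. Everything here is hypothetical — the `NoResult` name and the stand-in fast/slow paths are illustrations, not pandas' actual implementation:

```python
class NoResult(Exception):
    """Hypothetical internal signal: the fast path could not produce a result."""


def _fast_path(values, func):
    # Stand-in for the libreduction fast path: it only handles plain numbers.
    if not all(isinstance(v, (int, float)) for v in values):
        raise NoResult  # a dedicated type, instead of ValueError("No result.")
    return func(values)


def _pure_python_path(values, func):
    # Stand-in for the slower but fully general fallback.
    return func(values)


def agg_series(values, func):
    try:
        return _fast_path(values, func)
    except NoResult:
        # No fragile '"No result." in str(err)' string matching needed here.
        pass
    return _pure_python_path(values, func)
```

Catching a dedicated exception subclass keeps the fallback decision independent of error-message wording, which is exactly what the string checks above are standing in for.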
Member Author: @jbrockmendel the whole block is now basically an except (..): pass (so I could make it much shorter), but you might have put in those specific checks as pointers to what to clean up later?

Member: so something is being done with a DecimalArray that raises a TypeError? And I guess by not catching it here it is getting caught somewhere above that isn't doing re-casting appropriately? Can we track down where that is?

Member: It looks like in _aggregate_series_fast there is a call to libreduction that tries to assign this DecimalArray to a name that libreduction has typed as an ndarray, which raises TypeError.
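The failure mode described here can be mimicked in plain Python. This is only an illustration of the mechanism — the error-message text is made up, and real Cython raises its own conversion error when an argument declared as `ndarray` receives something else:

```python
import numpy as np


class FakeDecimalArray:
    """Minimal stand-in for an ExtensionArray that is not an ndarray."""

    def __init__(self, values):
        self._values = list(values)


def typed_slot_assign(arr):
    # Mimics a cython function whose argument is typed `ndarray`:
    # handing it anything else fails before the function body even runs.
    if not isinstance(arr, np.ndarray):
        raise TypeError(f"Cannot convert {type(arr).__name__} to numpy.ndarray")
    return arr
```

Passing `FakeDecimalArray([1])` raises the TypeError, while an actual ndarray passes through untouched — the same shape of failure the fast path hits with a real extension array.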

Member Author: But that is expected, I think?

Member: we get to decide what is "expected"; maybe I don't understand what you're getting at.

Member: actually, can you do the same str(err) checking so that we only let through the relevant TypeErrors? One of the other PRs afoot is specifically targeting other TypeErrors.

Member Author: That's an error message coming from another library though (Cython?). Do we have a guarantee that it is stable? (The other two errors that are caught that way are raised by our own code in reduction.pyx.)

Member:

> Do we have guarantee that that is stable?

If nothing else, you can check that both "ndarray" and "DecimalArray" are present. I'm sure you can figure something out.

Member Author: This is not only for DecimalArray, but for any kind of internal/external EA.

Member Author: But I can add the check for just ndarray.
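The check being discussed might be sketched like this. As the thread itself notes, the exact wording of Cython's conversion error is not guaranteed to be stable, so this is a heuristic, nothing more:

```python
def is_libreduction_conversion_error(err):
    """Heuristic filter: let through only TypeErrors that look like a
    cython ndarray-conversion failure, not arbitrary TypeErrors raised
    by user-supplied aggregation functions."""
    return isinstance(err, TypeError) and "ndarray" in str(err)
```

A TypeError like "Cannot convert DecimalArray to numpy.ndarray" would match, while an unrelated TypeError from a user's lambda (or any ValueError) would not, so only the relevant errors trigger the pure-Python fallback.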

        return self._aggregate_series_pure_python(obj, func)

    def _aggregate_series_fast(self, obj, func):
34 changes: 34 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
@@ -426,3 +426,37 @@ def test_array_ufunc_series_defer():

tm.assert_series_equal(r1, expected)
tm.assert_series_equal(r2, expected)


def test_groupby_agg():
# Ensure that the result of agg is inferred to be decimal dtype
# https://github.com/pandas-dev/pandas/issues/29141

data = make_data()[:5]
df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
expected = pd.Series(to_decimal([data[0], data[3]]))

result = df.groupby("id")["decimals"].agg(lambda x: x.iloc[0])
tm.assert_series_equal(result, expected, check_names=False)
result = df["decimals"].groupby(df["id"]).agg(lambda x: x.iloc[0])
tm.assert_series_equal(result, expected, check_names=False)


def test_groupby_agg_ea_method(monkeypatch):
# Ensure that the result of agg is inferred to be decimal dtype
# https://github.com/pandas-dev/pandas/issues/29141

def DecimalArray__my_sum(self):
return np.sum(np.array(self))

monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)

data = make_data()[:5]
df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))

result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
tm.assert_series_equal(result, expected, check_names=False)
s = pd.Series(DecimalArray(data))
result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum())
tm.assert_series_equal(result, expected, check_names=False)