BUG/TST: ensure groupby.agg preserves extension dtype #29144

Merged
8 changes: 7 additions & 1 deletion pandas/core/groupby/ops.py
@@ -672,7 +672,13 @@ def agg_series(self, obj, func):
                pass
            else:
                raise
-            return self._aggregate_series_pure_python(obj, func)
+        except TypeError as err:
+            if "ndarray" in str(err):
+                # raised in libreduction if obj's values is not an ndarray
+                pass

Member Author

@jbrockmendel the whole block is now basically an except (..): pass (so I could make it much shorter), but you might have put in those specific checks as pointers to what to clean up later?
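
For illustration, a minimal sketch of the "much shorter" version this comment alludes to (hypothetical; not what was merged, since the specific message checks were kept):

```python
def agg_series(self, obj, func):
    # Hypothetical collapsed form (not the merged code): any failure of the
    # fast cython path simply falls back to the pure-Python aggregation.
    try:
        return self._aggregate_series_fast(obj, func)
    except (ValueError, TypeError):
        pass
    return self._aggregate_series_pure_python(obj, func)
```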

Member

So something is being done with a DecimalArray that raises a TypeError? And I guess by not catching it here, it is getting caught somewhere above that isn't doing the re-casting appropriately? Can we track down where that is?

Member

It looks like in _aggregate_series_fast there is a call into libreduction that tries to assign this DecimalArray to a name that libreduction has typed as an ndarray, which raises a TypeError.
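
To illustrate the mechanism described here (a pure-Python stand-in, not the actual libreduction code): a Cython function whose argument is declared as ndarray rejects anything else at call time with a TypeError whose message names the expected type, which is what the "ndarray" in str(err) check above relies on.

```python
import numpy as np


def fast_path(values):
    # Stand-in for a libreduction routine whose argument is declared as
    # ndarray in Cython: anything else (e.g. a DecimalArray) is rejected at
    # call time with a TypeError naming the expected and actual types.
    if not isinstance(values, np.ndarray):
        raise TypeError(
            "Argument 'values' has incorrect type "
            "(expected numpy.ndarray, got {})".format(type(values).__name__)
        )
    return values.sum()


# agg_series relies on "ndarray" appearing in that message to tell this
# failure apart from a TypeError raised by the aggregation function itself.
try:
    fast_path([1, 2, 3])  # a plain list stands in for an extension array here
except TypeError as err:
    assert "ndarray" in str(err)
```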

Member Author

But that is expected I think?

Member

We get to decide what is "expected"; maybe I don't understand what you're getting at.

Member

Actually, can you do the same str(err) checking so that we only let through the relevant TypeErrors? One of the other PRs afoot is specifically targeting other TypeErrors.

Member Author

That's an error message coming from another library though (Cython?). Do we have a guarantee that it is stable?
(The other two errors that are caught that way are raised by our own code in reduction.pyx.)

Member

> Do we have a guarantee that it is stable?

If nothing else, you can check that both "ndarray" and "DecimalArray" are present. I'm sure you can figure something out.

Member Author

This is not only for DecimalArray, but for any kind of internal/external EA.

Member Author

But I can add the check for just ndarray.

+            else:
+                raise
+        return self._aggregate_series_pure_python(obj, func)

    def _aggregate_series_fast(self, obj, func):
        func = self._is_builtin_func(func)
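
Putting the hunk back in context, agg_series after this change reads roughly as follows (a sketch: the pre-existing ValueError branch is abbreviated because this diff does not show it in full):

```python
def agg_series(self, obj, func):
    try:
        return self._aggregate_series_fast(obj, func)
    except ValueError as err:
        # pre-existing str(err) checks for the errors raised by our own
        # reduction.pyx (abbreviated; not shown in full in this diff)
        ...
    except TypeError as err:
        if "ndarray" in str(err):
            # raised in libreduction if obj's values is not an ndarray
            pass
        else:
            raise
    return self._aggregate_series_pure_python(obj, func)
```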
52 changes: 52 additions & 0 deletions pandas/tests/extension/decimal/test_decimal.py
@@ -426,3 +426,55 @@ def test_array_ufunc_series_defer():

    tm.assert_series_equal(r1, expected)
    tm.assert_series_equal(r2, expected)


def test_groupby_agg():
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    data = make_data()[:5]
    df = pd.DataFrame(
        {"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
    )

    # single key, selected column
    expected = pd.Series(to_decimal([data[0], data[3]]))
    result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple keys, selected column
    expected = pd.Series(
        to_decimal([data[0], data[1], data[3]]),
        index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
    )
    result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)
    result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0])
    tm.assert_series_equal(result, expected, check_names=False)

    # multiple columns
    expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
    result = df.groupby("id1").agg(lambda x: x.iloc[0])
    tm.assert_frame_equal(result, expected, check_names=False)


def test_groupby_agg_ea_method(monkeypatch):
    # Ensure that the result of agg is inferred to be decimal dtype
    # https://github.com/pandas-dev/pandas/issues/29141

    def DecimalArray__my_sum(self):
        return np.sum(np.array(self))

    monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)

    data = make_data()[:5]
    df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
    expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))

    result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)
    s = pd.Series(DecimalArray(data))
    result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum())
    tm.assert_series_equal(result, expected, check_names=False)
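
For completeness, a quick way to see the behaviour these tests lock in (assuming a pandas source checkout so that the decimal test extension array used above is importable):

```python
import pandas as pd
from pandas.tests.extension.decimal.array import to_decimal

df = pd.DataFrame({"id1": [0, 0, 1], "decimals": to_decimal([1, 2, 3])})
result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
print(result.dtype)  # with this change the decimal extension dtype is preserved
```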