Skip to content

Commit 091bb06

Browse files
jorisvandenbosscheBlake Hawkins
authored and
Blake Hawkins
committed
BUG/TST: ensure groupby.agg preserves extension dtype (pandas-dev#29144)
1 parent dd5d1c5 commit 091bb06

File tree

2 files changed

+59
-1
lines changed

2 files changed

+59
-1
lines changed

pandas/core/groupby/ops.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,13 @@ def agg_series(self, obj, func):
672672
pass
673673
else:
674674
raise
675-
return self._aggregate_series_pure_python(obj, func)
675+
except TypeError as err:
676+
if "ndarray" in str(err):
677+
# raised in libreduction if obj's values is no ndarray
678+
pass
679+
else:
680+
raise
681+
return self._aggregate_series_pure_python(obj, func)
676682

677683
def _aggregate_series_fast(self, obj, func):
678684
func = self._is_builtin_func(func)

pandas/tests/extension/decimal/test_decimal.py

+52
Original file line numberDiff line numberDiff line change
@@ -426,3 +426,55 @@ def test_array_ufunc_series_defer():
426426

427427
tm.assert_series_equal(r1, expected)
428428
tm.assert_series_equal(r2, expected)
429+
430+
431+
def test_groupby_agg():
432+
# Ensure that the result of agg is inferred to be decimal dtype
433+
# https://github.com/pandas-dev/pandas/issues/29141
434+
435+
data = make_data()[:5]
436+
df = pd.DataFrame(
437+
{"id1": [0, 0, 0, 1, 1], "id2": [0, 1, 0, 1, 1], "decimals": DecimalArray(data)}
438+
)
439+
440+
# single key, selected column
441+
expected = pd.Series(to_decimal([data[0], data[3]]))
442+
result = df.groupby("id1")["decimals"].agg(lambda x: x.iloc[0])
443+
tm.assert_series_equal(result, expected, check_names=False)
444+
result = df["decimals"].groupby(df["id1"]).agg(lambda x: x.iloc[0])
445+
tm.assert_series_equal(result, expected, check_names=False)
446+
447+
# multiple keys, selected column
448+
expected = pd.Series(
449+
to_decimal([data[0], data[1], data[3]]),
450+
index=pd.MultiIndex.from_tuples([(0, 0), (0, 1), (1, 1)]),
451+
)
452+
result = df.groupby(["id1", "id2"])["decimals"].agg(lambda x: x.iloc[0])
453+
tm.assert_series_equal(result, expected, check_names=False)
454+
result = df["decimals"].groupby([df["id1"], df["id2"]]).agg(lambda x: x.iloc[0])
455+
tm.assert_series_equal(result, expected, check_names=False)
456+
457+
# multiple columns
458+
expected = pd.DataFrame({"id2": [0, 1], "decimals": to_decimal([data[0], data[3]])})
459+
result = df.groupby("id1").agg(lambda x: x.iloc[0])
460+
tm.assert_frame_equal(result, expected, check_names=False)
461+
462+
463+
def test_groupby_agg_ea_method(monkeypatch):
464+
# Ensure that the result of agg is inferred to be decimal dtype
465+
# https://github.com/pandas-dev/pandas/issues/29141
466+
467+
def DecimalArray__my_sum(self):
468+
return np.sum(np.array(self))
469+
470+
monkeypatch.setattr(DecimalArray, "my_sum", DecimalArray__my_sum, raising=False)
471+
472+
data = make_data()[:5]
473+
df = pd.DataFrame({"id": [0, 0, 0, 1, 1], "decimals": DecimalArray(data)})
474+
expected = pd.Series(to_decimal([data[0] + data[1] + data[2], data[3] + data[4]]))
475+
476+
result = df.groupby("id")["decimals"].agg(lambda x: x.values.my_sum())
477+
tm.assert_series_equal(result, expected, check_names=False)
478+
s = pd.Series(DecimalArray(data))
479+
result = s.groupby(np.array([0, 0, 0, 1, 1])).agg(lambda x: x.values.my_sum())
480+
tm.assert_series_equal(result, expected, check_names=False)

0 commit comments

Comments
 (0)