Skip to content

Commit daff98c

Browse files
rhshadrach authored and JulianWgs committed
CLN: Simplify gathering of results in aggregate (pandas-dev#37227)
1 parent 8ed71da commit daff98c

File tree

4 files changed, with 21 additions and 50 deletions.

pandas/core/aggregation.py

+15 -46
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131

3232
from pandas.core.dtypes.cast import is_nested_object
3333
from pandas.core.dtypes.common import is_dict_like, is_list_like
34-
from pandas.core.dtypes.generic import ABCDataFrame, ABCSeries
34+
from pandas.core.dtypes.generic import ABCDataFrame, ABCNDFrame, ABCSeries
3535

3636
from pandas.core.base import DataError, SpecificationError
3737
import pandas.core.common as com
@@ -621,58 +621,27 @@ def aggregate(obj, arg: AggFuncType, *args, **kwargs):
621621
# set the final keys
622622
keys = list(arg.keys())
623623

624-
# combine results
625-
626-
def is_any_series() -> bool:
627-
# return a boolean if we have *any* nested series
628-
return any(isinstance(r, ABCSeries) for r in results.values())
629-
630-
def is_any_frame() -> bool:
631-
# return a boolean if we have *any* nested series
632-
return any(isinstance(r, ABCDataFrame) for r in results.values())
633-
634-
if isinstance(results, list):
635-
return concat(results, keys=keys, axis=1, sort=True), True
636-
637-
elif is_any_frame():
638-
# we have a dict of DataFrames
639-
# return a MI DataFrame
624+
# Avoid making two isinstance calls in all and any below
625+
is_ndframe = [isinstance(r, ABCNDFrame) for r in results.values()]
640626

627+
# combine results
628+
if all(is_ndframe):
641629
keys_to_use = [k for k in keys if not results[k].empty]
642630
# Have to check, if at least one DataFrame is not empty.
643631
keys_to_use = keys_to_use if keys_to_use != [] else keys
644-
return (
645-
concat([results[k] for k in keys_to_use], keys=keys_to_use, axis=1),
646-
True,
632+
axis = 0 if isinstance(obj, ABCSeries) else 1
633+
result = concat({k: results[k] for k in keys_to_use}, axis=axis)
634+
elif any(is_ndframe):
635+
# There is a mix of NDFrames and scalars
636+
raise ValueError(
637+
"cannot perform both aggregation "
638+
"and transformation operations "
639+
"simultaneously"
647640
)
641+
else:
642+
from pandas import Series
648643

649-
elif isinstance(obj, ABCSeries) and is_any_series():
650-
651-
# we have a dict of Series
652-
# return a MI Series
653-
try:
654-
result = concat(results)
655-
except TypeError as err:
656-
# we want to give a nice error here if
657-
# we have non-same sized objects, so
658-
# we don't automatically broadcast
659-
660-
raise ValueError(
661-
"cannot perform both aggregation "
662-
"and transformation operations "
663-
"simultaneously"
664-
) from err
665-
666-
return result, True
667-
668-
# fall thru
669-
from pandas import DataFrame, Series
670-
671-
try:
672-
result = DataFrame(results)
673-
except ValueError:
674644
# we have a dict of scalars
675-
676645
# GH 36212 use name only if obj is a series
677646
if obj.ndim == 1:
678647
obj = cast("Series", obj)

pandas/core/dtypes/generic.py

+1
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,7 @@ def _check(cls, inst) -> bool:
5353
},
5454
)
5555

56+
ABCNDFrame = create_pandas_abc_type("ABCNDFrame", "_typ", ("series", "dataframe"))
5657
ABCSeries = create_pandas_abc_type("ABCSeries", "_typ", ("series",))
5758
ABCDataFrame = create_pandas_abc_type("ABCDataFrame", "_typ", ("dataframe",))
5859

pandas/core/frame.py

+2 -2
Original file line number | Diff line number | Diff line change
@@ -7442,9 +7442,9 @@ def _gotitem(
74427442
74437443
>>> df.agg({'A' : ['sum', 'min'], 'B' : ['min', 'max']})
74447444
A B
7445-
max NaN 8.0
7446-
min 1.0 2.0
74477445
sum 12.0 NaN
7446+
min 1.0 2.0
7447+
max NaN 8.0
74487448
74497449
Aggregate different functions over the columns and rename the index of the resulting
74507450
DataFrame.

pandas/tests/frame/apply/test_frame_apply.py

+3 -2
Original file line number | Diff line number | Diff line change
@@ -1254,7 +1254,7 @@ def test_agg_reduce(self, axis, float_frame):
12541254
# dict input with lists with multiple
12551255
func = dict([(name1, ["mean", "sum"]), (name2, ["sum", "max"])])
12561256
result = float_frame.agg(func, axis=axis)
1257-
expected = DataFrame(
1257+
expected = pd.concat(
12581258
dict(
12591259
[
12601260
(
@@ -1278,7 +1278,8 @@ def test_agg_reduce(self, axis, float_frame):
12781278
),
12791279
),
12801280
]
1281-
)
1281+
),
1282+
axis=1,
12821283
)
12831284
expected = expected.T if axis in {1, "columns"} else expected
12841285
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)