
Commit c4f6c1b

Revert "CLN/BUG: Clean/Simplify _wrap_applied_output (pandas-dev#35792)"
This reverts commit 1dc0795.
Parent: 1dc0795

4 files changed (+71, -34 lines)


doc/source/whatsnew/v1.1.4.rst (-1)

@@ -35,7 +35,6 @@ Bug fixes
 - Bug in :meth:`Series.isin` and :meth:`DataFrame.isin` raising a ``ValueError`` when the target was read-only (:issue:`37174`)
 - Bug in :meth:`GroupBy.fillna` that introduced a performance regression after 1.0.5 (:issue:`36757`)
 - Bug in :meth:`DataFrame.info` was raising a ``KeyError`` when the DataFrame has integer column names (:issue:`37245`)
-- Bug in :meth:`DataFrameGroupby.apply` would drop a :class:`CategoricalIndex` when grouped on (:issue:`35792`)
 
 .. ---------------------------------------------------------------------------
 
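For context, the entry removed above describes the behaviour exercised by the test updated at the bottom of this commit: grouping on a categorical index level and calling apply. A minimal sketch of that scenario (not part of the commit; written against pandas 1.1, and the resulting index type depends on whether the reverted change is present):

import numpy as np
import pandas as pd

# Mirrors test_apply_multi_level_name with category=True.
b = pd.Categorical([1, 2] * 5, categories=[1, 2, 3])
df = pd.DataFrame(
    {"A": np.arange(10), "B": b, "C": list(range(10)), "D": list(range(10))}
).set_index(["A", "B"])

result = df.groupby("B").apply(lambda x: x.sum())
# With the reverted change in place, the test expected a
# CategoricalIndex([1, 2], categories=[1, 2, 3], name="B") here; after this
# revert it expects a plain Index([1, 2], name="B") again.
print(result.index)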

pandas/core/groupby/generic.py (+65, -25)

@@ -1219,25 +1219,57 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
         if len(keys) == 0:
             return self.obj._constructor(index=keys)
 
+        key_names = self.grouper.names
+
         # GH12824
         first_not_none = next(com.not_none(*values), None)
 
         if first_not_none is None:
-            # GH9684 - All values are None, return an empty frame.
+            # GH9684. If all values are None, then this will throw an error.
+            # We'd prefer it return an empty dataframe.
             return self.obj._constructor()
         elif isinstance(first_not_none, DataFrame):
             return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
         else:
-            key_index = self.grouper.result_index if self.as_index else None
+            if len(self.grouper.groupings) > 1:
+                key_index = self.grouper.result_index
+
+            else:
+                ping = self.grouper.groupings[0]
+                if len(keys) == ping.ngroups:
+                    key_index = ping.group_index
+                    key_index.name = key_names[0]
+
+                    key_lookup = Index(keys)
+                    indexer = key_lookup.get_indexer(key_index)
+
+                    # reorder the values
+                    values = [values[i] for i in indexer]
+
+                    # update due to the potential reorder
+                    first_not_none = next(com.not_none(*values), None)
+                else:
+
+                    key_index = Index(keys, name=key_names[0])
+
+                # don't use the key indexer
+                if not self.as_index:
+                    key_index = None
 
-            if isinstance(first_not_none, Series):
+            # make Nones an empty object
+            if first_not_none is None:
+                return self.obj._constructor()
+            elif isinstance(first_not_none, NDFrame):
 
                 # this is to silence a DeprecationWarning
                 # TODO: Remove when default dtype of empty Series is object
                 kwargs = first_not_none._construct_axes_dict()
-                backup = create_series_with_explicit_dtype(
-                    **kwargs, dtype_if_empty=object
-                )
+                if isinstance(first_not_none, Series):
+                    backup = create_series_with_explicit_dtype(
+                        **kwargs, dtype_if_empty=object
+                    )
+                else:
+                    backup = first_not_none._constructor(**kwargs)
 
                 values = [x if (x is not None) else backup for x in values]
 
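The hunk above restores the realignment of the applied results against the grouper's group_index. A minimal sketch of that get_indexer idiom (a standalone illustration with made-up names, not code from the commit):

import pandas as pd

# Order of groups as the grouper reports them (ping.group_index in the code
# above) versus the order in which apply happened to produce results.
key_index = pd.CategoricalIndex(["a", "b", "c"], name="key")
keys = ["b", "c", "a"]
values = ["result_b", "result_c", "result_a"]  # one applied result per key

# Build an Index over the observed keys, find where each entry of key_index
# sits in it, and reorder the values to match key_index.
indexer = pd.Index(keys).get_indexer(key_index)
values = [values[i] for i in indexer]
print(list(zip(key_index, values)))
# [('a', 'result_a'), ('b', 'result_b'), ('c', 'result_c')]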
@@ -1246,7 +1278,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
             if isinstance(v, (np.ndarray, Index, Series)) or not self.as_index:
                 if isinstance(v, Series):
                     applied_index = self._selected_obj._get_axis(self.axis)
-                    all_indexed_same = all_indexes_same((x.index for x in values))
+                    all_indexed_same = all_indexes_same([x.index for x in values])
                     singular_series = len(values) == 1 and applied_index.nlevels == 1
 
                     # GH3596
@@ -1278,6 +1310,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
                     # GH 8467
                     return self._concat_objects(keys, values, not_indexed_same=True)
 
+                if self.axis == 0 and isinstance(v, ABCSeries):
                     # GH6124 if the list of Series have a consistent name,
                     # then propagate that name to the result.
                     index = v.index.copy()
@@ -1290,27 +1323,34 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
                         if len(names) == 1:
                             index.name = list(names)[0]
 
-                    # Combine values
-                    # vstack+constructor is faster than concat and handles MI-columns
-                    stacked_values = np.vstack([np.asarray(v) for v in values])
-
-                    if self.axis == 0:
-                        index = key_index
-                        columns = v.index.copy()
-                        if columns.name is None:
-                            # GH6124 - propagate name of Series when it's consistent
-                            names = {v.name for v in values}
-                            if len(names) == 1:
-                                columns.name = list(names)[0]
+                    # normally use vstack as its faster than concat
+                    # and if we have mi-columns
+                    if (
+                        isinstance(v.index, MultiIndex)
+                        or key_index is None
+                        or isinstance(key_index, MultiIndex)
+                    ):
+                        stacked_values = np.vstack([np.asarray(v) for v in values])
+                        result = self.obj._constructor(
+                            stacked_values, index=key_index, columns=index
+                        )
                     else:
-                        index = v.index
-                        columns = key_index
-                        stacked_values = stacked_values.T
-
+                        # GH5788 instead of stacking; concat gets the
+                        # dtypes correct
+                        from pandas.core.reshape.concat import concat
+
+                        result = concat(
+                            values,
+                            keys=key_index,
+                            names=key_index.names,
+                            axis=self.axis,
+                        ).unstack()
+                        result.columns = index
+                elif isinstance(v, ABCSeries):
+                    stacked_values = np.vstack([np.asarray(v) for v in values])
                     result = self.obj._constructor(
-                        stacked_values, index=index, columns=columns
+                        stacked_values.T, index=v.index, columns=key_index
                     )
-
             elif not self.as_index:
                 # We add grouping column below, so create a frame here
                 result = DataFrame(
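The last hunk restores the branch that combines per-group Series results with concat(...).unstack() when both the per-group index and the key index are regular (non-MultiIndex) indexes (GH5788). A minimal sketch of what that branch computes, reusing the same values as the expected frame in the test further below (illustrative only, not code from the commit):

import pandas as pd

# Two per-group Series, each indexed by the output columns, plus the key index.
values = [
    pd.Series({"C": 20, "D": 20}),
    pd.Series({"C": 25, "D": 25}),
]
key_index = pd.Index([1, 2], name="B")

# concat stacks the per-group Series under key_index; unstack pivots the inner
# level out into columns -> a frame with index B=[1, 2] and columns ["C", "D"].
result = pd.concat(values, keys=key_index, names=key_index.names, axis=0).unstack()
print(result)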

pandas/core/indexes/api.py (+3, -4)

@@ -298,16 +298,15 @@ def all_indexes_same(indexes):
 
     Parameters
     ----------
-    indexes : iterable of Index objects
+    indexes : list of Index objects
 
     Returns
     -------
     bool
         True if all indexes contain the same elements, False otherwise.
     """
-    itr = iter(indexes)
-    first = next(itr)
-    for index in itr:
+    first = indexes[0]
+    for index in indexes[1:]:
         if not first.equals(index):
             return False
     return True
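The restored implementation indexes its argument directly, so the docstring now asks for a list and the call site in generic.py above goes back to a list comprehension; a generator, which the reverted iterator-based version accepted, would no longer work. A small usage sketch (all_indexes_same is internal pandas API, shown only for illustration):

import pandas as pd
from pandas.core.indexes.api import all_indexes_same

idx = pd.Index([1, 2, 3])
print(all_indexes_same([idx, idx.copy()]))        # True
print(all_indexes_same([idx, pd.Index([4, 5])]))  # False
# Passing a generator would fail with this version, since a generator supports
# neither indexes[0] nor indexes[1:].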

pandas/tests/groupby/test_apply.py (+3, -4)

@@ -868,14 +868,13 @@ def test_apply_multi_level_name(category):
     b = [1, 2] * 5
     if category:
         b = pd.Categorical(b, categories=[1, 2, 3])
-        expected_index = pd.CategoricalIndex([1, 2], categories=[1, 2, 3], name="B")
-    else:
-        expected_index = pd.Index([1, 2], name="B")
     df = pd.DataFrame(
         {"A": np.arange(10), "B": b, "C": list(range(10)), "D": list(range(10))}
     ).set_index(["A", "B"])
     result = df.groupby("B").apply(lambda x: x.sum())
-    expected = pd.DataFrame({"C": [20, 25], "D": [20, 25]}, index=expected_index)
+    expected = pd.DataFrame(
+        {"C": [20, 25], "D": [20, 25]}, index=pd.Index([1, 2], name="B")
+    )
     tm.assert_frame_equal(result, expected)
     assert df.index.names == ["A", "B"]
 
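To check the restored expectation locally (assuming a pandas development checkout with pytest installed), the updated test can be run on its own; it covers both values of its category parametrization:

pytest pandas/tests/groupby/test_apply.py -k test_apply_multi_level_name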