Skip to content

Commit d90b73b

Browse files
authored
CLN/BUG: Clean/Simplify _wrap_applied_output (#35792)
1 parent 9f3e429 commit d90b73b

File tree

4 files changed

+34
-71
lines changed

4 files changed

+34
-71
lines changed

doc/source/whatsnew/v1.2.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ Groupby/resample/rolling
254254
- Bug in :meth:`DataFrameGroupBy.apply` that would sometimes throw an erroneous ``ValueError`` if the grouping axis had duplicate entries (:issue:`16646`)
255255
- Bug when combining methods :meth:`DataFrame.groupby` with :meth:`DataFrame.resample` and :meth:`DataFrame.interpolate` raising a ``TypeError`` (:issue:`35325`)
256256
- Bug in :meth:`DataFrameGroupBy.apply` where a non-nuisance grouping column would be dropped from the output columns if another groupby method was called before ``.apply()`` (:issue:`34656`)
257+
- Bug in :meth:`DataFrameGroupBy.apply` that would drop a :class:`CategoricalIndex` when grouped on (:issue:`35792`)
257258

258259
Reshaping
259260
^^^^^^^^^

pandas/core/groupby/generic.py

+25-65
Original file line numberDiff line numberDiff line change
@@ -1197,57 +1197,25 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
11971197
if len(keys) == 0:
11981198
return self.obj._constructor(index=keys)
11991199

1200-
key_names = self.grouper.names
1201-
12021200
# GH12824
12031201
first_not_none = next(com.not_none(*values), None)
12041202

12051203
if first_not_none is None:
1206-
# GH9684. If all values are None, then this will throw an error.
1207-
# We'd prefer it return an empty dataframe.
1204+
# GH9684 - All values are None, return an empty frame.
12081205
return self.obj._constructor()
12091206
elif isinstance(first_not_none, DataFrame):
12101207
return self._concat_objects(keys, values, not_indexed_same=not_indexed_same)
12111208
else:
1212-
if len(self.grouper.groupings) > 1:
1213-
key_index = self.grouper.result_index
1214-
1215-
else:
1216-
ping = self.grouper.groupings[0]
1217-
if len(keys) == ping.ngroups:
1218-
key_index = ping.group_index
1219-
key_index.name = key_names[0]
1220-
1221-
key_lookup = Index(keys)
1222-
indexer = key_lookup.get_indexer(key_index)
1223-
1224-
# reorder the values
1225-
values = [values[i] for i in indexer]
1226-
1227-
# update due to the potential reorder
1228-
first_not_none = next(com.not_none(*values), None)
1229-
else:
1230-
1231-
key_index = Index(keys, name=key_names[0])
1232-
1233-
# don't use the key indexer
1234-
if not self.as_index:
1235-
key_index = None
1209+
key_index = self.grouper.result_index if self.as_index else None
12361210

1237-
# make Nones an empty object
1238-
if first_not_none is None:
1239-
return self.obj._constructor()
1240-
elif isinstance(first_not_none, NDFrame):
1211+
if isinstance(first_not_none, Series):
12411212

12421213
# this is to silence a DeprecationWarning
12431214
# TODO: Remove when default dtype of empty Series is object
12441215
kwargs = first_not_none._construct_axes_dict()
1245-
if isinstance(first_not_none, Series):
1246-
backup = create_series_with_explicit_dtype(
1247-
**kwargs, dtype_if_empty=object
1248-
)
1249-
else:
1250-
backup = first_not_none._constructor(**kwargs)
1216+
backup = create_series_with_explicit_dtype(
1217+
**kwargs, dtype_if_empty=object
1218+
)
12511219

12521220
values = [x if (x is not None) else backup for x in values]
12531221

@@ -1256,7 +1224,7 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12561224
if isinstance(v, (np.ndarray, Index, Series)) or not self.as_index:
12571225
if isinstance(v, Series):
12581226
applied_index = self._selected_obj._get_axis(self.axis)
1259-
all_indexed_same = all_indexes_same([x.index for x in values])
1227+
all_indexed_same = all_indexes_same((x.index for x in values))
12601228
singular_series = len(values) == 1 and applied_index.nlevels == 1
12611229

12621230
# GH3596
@@ -1288,7 +1256,6 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
12881256
# GH 8467
12891257
return self._concat_objects(keys, values, not_indexed_same=True)
12901258

1291-
if self.axis == 0 and isinstance(v, ABCSeries):
12921259
# GH6124 if the list of Series have a consistent name,
12931260
# then propagate that name to the result.
12941261
index = v.index.copy()
@@ -1301,34 +1268,27 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13011268
if len(names) == 1:
13021269
index.name = list(names)[0]
13031270

1304-
# normally use vstack as its faster than concat
1305-
# and if we have mi-columns
1306-
if (
1307-
isinstance(v.index, MultiIndex)
1308-
or key_index is None
1309-
or isinstance(key_index, MultiIndex)
1310-
):
1311-
stacked_values = np.vstack([np.asarray(v) for v in values])
1312-
result = self.obj._constructor(
1313-
stacked_values, index=key_index, columns=index
1314-
)
1315-
else:
1316-
# GH5788 instead of stacking; concat gets the
1317-
# dtypes correct
1318-
from pandas.core.reshape.concat import concat
1319-
1320-
result = concat(
1321-
values,
1322-
keys=key_index,
1323-
names=key_index.names,
1324-
axis=self.axis,
1325-
).unstack()
1326-
result.columns = index
1327-
elif isinstance(v, ABCSeries):
1271+
# Combine values
1272+
# vstack+constructor is faster than concat and handles MI-columns
13281273
stacked_values = np.vstack([np.asarray(v) for v in values])
1274+
1275+
if self.axis == 0:
1276+
index = key_index
1277+
columns = v.index.copy()
1278+
if columns.name is None:
1279+
# GH6124 - propagate name of Series when it's consistent
1280+
names = {v.name for v in values}
1281+
if len(names) == 1:
1282+
columns.name = list(names)[0]
1283+
else:
1284+
index = v.index
1285+
columns = key_index
1286+
stacked_values = stacked_values.T
1287+
13291288
result = self.obj._constructor(
1330-
stacked_values.T, index=v.index, columns=key_index
1289+
stacked_values, index=index, columns=columns
13311290
)
1291+
13321292
elif not self.as_index:
13331293
# We add grouping column below, so create a frame here
13341294
result = DataFrame(

pandas/core/indexes/api.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -297,15 +297,16 @@ def all_indexes_same(indexes):
297297
298298
Parameters
299299
----------
300-
indexes : list of Index objects
300+
indexes : iterable of Index objects
301301
302302
Returns
303303
-------
304304
bool
305305
True if all indexes contain the same elements, False otherwise.
306306
"""
307-
first = indexes[0]
308-
for index in indexes[1:]:
307+
itr = iter(indexes)
308+
first = next(itr)
309+
for index in itr:
309310
if not first.equals(index):
310311
return False
311312
return True

pandas/tests/groupby/test_apply.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -861,13 +861,14 @@ def test_apply_multi_level_name(category):
861861
b = [1, 2] * 5
862862
if category:
863863
b = pd.Categorical(b, categories=[1, 2, 3])
864+
expected_index = pd.CategoricalIndex([1, 2], categories=[1, 2, 3], name="B")
865+
else:
866+
expected_index = pd.Index([1, 2], name="B")
864867
df = pd.DataFrame(
865868
{"A": np.arange(10), "B": b, "C": list(range(10)), "D": list(range(10))}
866869
).set_index(["A", "B"])
867870
result = df.groupby("B").apply(lambda x: x.sum())
868-
expected = pd.DataFrame(
869-
{"C": [20, 25], "D": [20, 25]}, index=pd.Index([1, 2], name="B")
870-
)
871+
expected = pd.DataFrame({"C": [20, 25], "D": [20, 25]}, index=expected_index)
871872
tm.assert_frame_equal(result, expected)
872873
assert df.index.names == ["A", "B"]
873874

0 commit comments

Comments
 (0)