Skip to content

Commit 46e5f66

Browse files
committed
move DataFrameGroupBy._transform_general logic to _set_result_index_ordered
1 parent 16544ea commit 46e5f66

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

pandas/core/groupby/generic.py

+4-7
Original file line numberDiff line numberDiff line change
@@ -1397,6 +1397,9 @@ def _wrap_applied_output(self, keys, values, not_indexed_same=False):
13971397
def _transform_general(
13981398
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
13991399
):
1400+
"""
1401+
Transform with a non-str `func`.
1402+
"""
14001403
from pandas.core.reshape.concat import concat
14011404

14021405
applied = []
@@ -1455,17 +1458,11 @@ def _transform_general(
14551458
applied.append(r)
14561459
else:
14571460
applied.append(res)
1458-
14591461
concat_index = obj.columns if self.axis == 0 else obj.index
14601462
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
14611463
concatenated = concat(applied, axis=self.axis, verify_integrity=False)
14621464
concatenated = concatenated.reindex(concat_index, axis=other_axis, copy=False)
1463-
if not self.dropna or not has_nan:
1464-
return self._set_result_index_ordered(concatenated)
1465-
else:
1466-
concatenated.sort_index(inplace=True)
1467-
concatenated.index = obj.index[concatenated.index.asi8]
1468-
return concatenated
1465+
return self._set_result_index_ordered(concatenated)
14691466

14701467
@Substitution(klass="DataFrame")
14711468
@Appender(_transform_template)

pandas/core/groupby/groupby.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -690,7 +690,10 @@ def _set_result_index_ordered(self, result):
690690
result.set_axis(index, axis=self.axis, inplace=True)
691691
result = result.sort_index(axis=self.axis)
692692

693-
result.set_axis(self.obj._get_axis(self.axis), axis=self.axis, inplace=True)
693+
result_idx, obj_idx = result.index, self.obj._get_axis(self.axis)
694+
intersection = result_idx.intersection(obj_idx)
695+
indexer = obj_idx if intersection.empty else intersection
696+
result.set_axis(indexer, axis=self.axis, inplace=True)
694697
return result
695698

696699
def _dir_additions(self):

0 commit comments

Comments
 (0)