Skip to content

Commit 2baab0c

Browse files
committed
move DataFrameGroupBy._transform_general logic to _set_result_index_ordered
1 parent 16544ea commit 2baab0c

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

pandas/core/groupby/generic.py

+2-7
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ def _transform_general(
517517
"""
518518
Transform with a non-str `func`.
519519
"""
520+
print("Calling SeriesGroupBy._transform_general")
520521
if maybe_use_numba(engine):
521522
numba_func, cache_key = generate_numba_func(
522523
func, engine_kwargs, kwargs, "groupby_transform"
@@ -1455,17 +1456,11 @@ def _transform_general(
14551456
applied.append(r)
14561457
else:
14571458
applied.append(res)
1458-
14591459
concat_index = obj.columns if self.axis == 0 else obj.index
14601460
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
14611461
concatenated = concat(applied, axis=self.axis, verify_integrity=False)
14621462
concatenated = concatenated.reindex(concat_index, axis=other_axis, copy=False)
1463-
if not self.dropna or not has_nan:
1464-
return self._set_result_index_ordered(concatenated)
1465-
else:
1466-
concatenated.sort_index(inplace=True)
1467-
concatenated.index = obj.index[concatenated.index.asi8]
1468-
return concatenated
1463+
return self._set_result_index_ordered(concatenated)
14691464

14701465
@Substitution(klass="DataFrame")
14711466
@Appender(_transform_template)

pandas/core/groupby/groupby.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -682,15 +682,17 @@ def _set_group_selection(self):
682682
def _set_result_index_ordered(self, result):
683683
# set the result index on the passed values object and
684684
# return the new object, xref 8046
685-
686685
# the values/counts are repeated according to the group index
687686
# shortcut if we have an already ordered grouper
688687
if not self.grouper.is_monotonic:
689688
index = Index(np.concatenate(self._get_indices(self.grouper.result_index)))
690689
result.set_axis(index, axis=self.axis, inplace=True)
691690
result = result.sort_index(axis=self.axis)
692691

693-
result.set_axis(self.obj._get_axis(self.axis), axis=self.axis, inplace=True)
692+
result_idx, obj_idx = result.index, self.obj._get_axis(self.axis)
693+
intersection = result_idx.intersection(obj_idx)
694+
indexer = obj_idx if intersection.empty else intersection
695+
result.set_axis(indexer, axis=self.axis, inplace=True)
694696
return result
695697

696698
def _dir_additions(self):

0 commit comments

Comments
 (0)