Skip to content

Commit 16544ea

Browse files
committed
fix DataFrameGroupBy._transform_general
1 parent 77e7fc7 commit 16544ea

File tree

3 files changed: +20 -15 lines changed

pandas/core/groupby/generic.py

+10 -4
Original file line number | Diff line number | Diff line change
@@ -1409,7 +1409,9 @@ def _transform_general(
14091409
else:
14101410
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
14111411

1412+
has_nan = False
14121413
for name, group in gen:
1414+
has_nan = has_nan or isna(name)
14131415
object.__setattr__(group, "name", name)
14141416

14151417
if maybe_use_numba(engine):
@@ -1418,9 +1420,8 @@ def _transform_general(
14181420
if cache_key not in NUMBA_FUNC_CACHE:
14191421
NUMBA_FUNC_CACHE[cache_key] = numba_func
14201422
# Return the result as a DataFrame for concatenation later
1421-
res = self.obj._constructor(
1422-
res, index=group.index, columns=group.columns
1423-
)
1423+
indexer = self._get_index(name) if self.dropna else group.index
1424+
res = self.obj._constructor(res, index=indexer, columns=group.columns)
14241425
else:
14251426
# Try slow path and fast path.
14261427
try:
@@ -1459,7 +1460,12 @@ def _transform_general(
14591460
other_axis = 1 if self.axis == 0 else 0 # switches between 0 & 1
14601461
concatenated = concat(applied, axis=self.axis, verify_integrity=False)
14611462
concatenated = concatenated.reindex(concat_index, axis=other_axis, copy=False)
1462-
return self._set_result_index_ordered(concatenated)
1463+
if not self.dropna or not has_nan:
1464+
return self._set_result_index_ordered(concatenated)
1465+
else:
1466+
concatenated.sort_index(inplace=True)
1467+
concatenated.index = obj.index[concatenated.index.asi8]
1468+
return concatenated
14631469

14641470
@Substitution(klass="DataFrame")
14651471
@Appender(_transform_template)

pandas/tests/groupby/test_apply.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -404,7 +404,7 @@ def trans2(group):
404404

405405

406406
def test_apply_transform(ts):
407-
grouped = ts.groupby(lambda x: x.month)
407+
grouped = ts.groupby(lambda x: x.month, dropna=False)
408408
result = grouped.apply(lambda x: x * 2)
409409
expected = grouped.transform(lambda x: x * 2)
410410
tm.assert_series_equal(result, expected)

pandas/tests/groupby/transform/test_transform.py

+9 -10
Original file line number | Diff line number | Diff line change
@@ -309,11 +309,11 @@ def test_transform_multiple(ts):
309309
def test_dispatch_transform(tsframe):
310310
df = tsframe[::5].reindex(tsframe.index)
311311

312-
grouped = df.groupby(lambda x: x.month, dropna=False)
312+
grouped = df.groupby(lambda x: x.month)
313313

314314
filled = grouped.fillna(method="pad")
315315
fillit = lambda x: x.fillna(method="pad")
316-
expected = df.groupby(lambda x: x.month, dropna=False).transform(fillit)
316+
expected = df.groupby(lambda x: x.month).transform(fillit)
317317
tm.assert_frame_equal(filled, expected)
318318

319319

@@ -412,10 +412,10 @@ def nsum(x):
412412
return np.nansum(x)
413413

414414
results = [
415-
df.groupby("col1", dropna=False).transform(sum)["col2"],
416-
df.groupby("col1", dropna=False)["col2"].transform(sum),
417-
df.groupby("col1", dropna=False).transform(nsum)["col2"],
418-
df.groupby("col1", dropna=False)["col2"].transform(nsum),
415+
df.groupby("col1").transform(sum)["col2"],
416+
df.groupby("col1")["col2"].transform(sum),
417+
df.groupby("col1").transform(nsum)["col2"],
418+
df.groupby("col1")["col2"].transform(nsum),
419419
]
420420
for result in results:
421421
tm.assert_series_equal(result, expected, check_names=False)
@@ -448,9 +448,7 @@ def test_groupby_transform_with_int():
448448
)
449449
)
450450
with np.errstate(all="ignore"):
451-
result = df.groupby("A", dropna=False).transform(
452-
lambda x: (x - x.mean()) / x.std()
453-
)
451+
result = df.groupby("A").transform(lambda x: (x - x.mean()) / x.std())
454452
expected = DataFrame(
455453
dict(B=np.nan, C=Series([-1, 0, 1, -1, 0, 1], dtype="float64"))
456454
)
@@ -614,7 +612,8 @@ def test_cython_transform_series(op, args, targop):
614612

615613
# series
616614
for data in [s, s_missing]:
617-
expected = data.groupby(labels, dropna=False).transform(targop)
615+
# print(data.head())
616+
expected = data.groupby(labels).transform(targop)
618617

619618
tm.assert_series_equal(expected, data.groupby(labels).transform(op, *args))
620619
tm.assert_series_equal(expected, getattr(data.groupby(labels), op)(*args))

0 commit comments

Comments
 (0)