Commit 30580e7

jcrist authored and jreback committed
Groupby transform preserves output dtype
Previously `transform` output was always the same dtype as the groupby object. This allows the output dtype to differ from the input. Fixes #9807.
1 parent f8ca3b7 commit 30580e7

3 files changed: +12 -1 lines changed

doc/source/whatsnew/v0.16.1.txt

+3
@@ -134,6 +134,9 @@ Bug Fixes
 
 - Bug in unequal comparisons between a ``Series`` of dtype `"category"` and a scalar (e.g. ``Series(Categorical(list("abc"), categories=list("cba"), ordered=True)) > "b"``, which wouldn't use the order of the categories but use the lexicographical order. (:issue:`9848`)
 
+
 - Bug in unequal comparisons between categorical data and a scalar, which was not in the categories (e.g. ``Series(Categorical(list("abc"), ordered=True)) > "d"``. This returned ``False`` for all elements, but now raises a ``TypeError``. Equality comparisons also now return ``False`` for ``==`` and ``True`` for ``!=``. (:issue:`9848`)
 
 - Bug in ``MultiIndex.sortlevel()`` results in unicode level name breaks (:issue:`9875`)
+
+- Bug in which ``groupby.transform`` incorrectly enforced output dtypes to match input dtypes. (:issue:`9807`)

pandas/core/groupby.py

+1 -1
@@ -3006,7 +3006,7 @@ def transform(self, func, *args, **kwargs):
             if ((not isinstance(obj.index,MultiIndex) and
                  type(result.index) != type(obj.index)) or
                 len(result.index) != len(obj.index)):
-                results = obj.values.copy()
+                results = np.empty_like(obj.values, result.values.dtype)
                 indices = self.indices
                 for (name, group), (i, row) in zip(self, result.iterrows()):
                     indexer = indices[name]
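
The one-line change above swaps a copy of the input values for an empty buffer that takes its dtype from the result. A minimal sketch of why that matters, using plain NumPy arrays as hypothetical stand-ins for `obj.values` and `result.values` (the names `obj_values`, `result_values`, `old_buffer`, and `new_buffer` are illustrative and do not appear in pandas):

```python
import numpy as np

# Hypothetical stand-ins: integer input values and a float per-group result.
obj_values = np.array([[1], [2]], dtype=np.int64)
result_values = np.array([[1.5]], dtype=np.float64)

# Old approach: the buffer is a copy of the input, so it keeps the input dtype.
# Writing float results into it silently casts them back to int64.
old_buffer = obj_values.copy()
old_buffer[:] = result_values[0]

# New approach: same shape as the input, but the dtype comes from the result,
# so the float values survive the broadcast.
new_buffer = np.empty_like(obj_values, dtype=result_values.dtype)
new_buffer[:] = result_values[0]

print(old_buffer.dtype, old_buffer.tolist())
print(new_buffer.dtype, new_buffer.tolist())
```

This only illustrates the buffer allocation; the surrounding loop that fills the buffer group by group is not reproduced here.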

pandas/tests/test_groupby.py

+8
@@ -1003,6 +1003,14 @@ def test_transform_broadcast(self):
             for idx in gp.index:
                 assert_fp_equal(res.xs(idx), agged[idx])
 
+    def test_transform_dtype(self):
+        # GH 9807
+        # Check transform dtype output is preserved
+        df = DataFrame([[1, 3], [2, 3]])
+        result = df.groupby(1).transform('mean')
+        expected = DataFrame([[1.5], [1.5]])
+        assert_frame_equal(result, expected)
+
     def test_transform_bug(self):
         # GH 5712
         # transforming on a datetime column
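
The scenario in `test_transform_dtype` can be tried outside the test suite. The sketch below assumes a recent pandas release, where `transform('mean')` may take a different internal code path than the one patched in this commit; it only checks the user-visible behaviour described in the commit message.

```python
import pandas as pd

# Integer column 0, grouped by column 1 (both rows fall into the same group).
df = pd.DataFrame([[1, 3], [2, 3]])
result = df.groupby(1).transform('mean')

# The group mean of 1 and 2 is 1.5, which an integer output dtype cannot hold.
print(result)
print(result.dtypes)
```

The grouping column is dropped from the output, so `result` has a single column labelled `0`.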
