Skip to content

Commit ce595b9

Browse files
committed
BUG: incorrect broadcasting that could casuse dtype coercion in a groupby-transform
closes #14457
1 parent 2e77536 commit ce595b9

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

doc/source/whatsnew/v0.19.1.txt

+4
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,13 @@ Bug Fixes
4343
- Bug in string indexing against data with ``object`` ``Index`` may raise ``AttributeError`` (:issue:`14424`)
4444
- Corrrecly raise ``ValueError`` on empty input to ``pd.eval()`` and ``df.query()`` (:issue:`13139`)
4545

46+
4647
- Bug in ``RangeIndex.intersection`` when result is a empty set (:issue:`14364`).
4748
- Bug in union of differences from a ``DatetimeIndex`; this is a regression in 0.19.0 from 0.18.1 (:issue:`14323`)
4849

50+
- Bug in groupby-transform broadcasting that could cause incorrect dtype coercion (:issue:`14457`)
51+
52+
4953
- Bug in ``Series.__setitem__`` which allowed mutating read-only arrays (:issue:`14359`).
5054

5155

pandas/core/groupby.py

+15-6
Original file line numberDiff line numberDiff line change
@@ -3454,7 +3454,6 @@ def _transform_general(self, func, *args, **kwargs):
34543454
from pandas.tools.merge import concat
34553455

34563456
applied = []
3457-
34583457
obj = self._obj_with_exclusions
34593458
gen = self.grouper.get_iterator(obj, axis=self.axis)
34603459
fast_path, slow_path = self._define_paths(func, *args, **kwargs)
@@ -3475,14 +3474,24 @@ def _transform_general(self, func, *args, **kwargs):
34753474
else:
34763475
res = path(group)
34773476

3478-
# broadcasting
34793477
if isinstance(res, Series):
3480-
if res.index.is_(obj.index):
3481-
group.T.values[:] = res
3478+
3479+
# we need to broadcast across the
3480+
# other dimension; this will preserve dtypes
3481+
# GH14457
3482+
if not np.prod(group.shape):
3483+
continue
3484+
elif res.index.is_(obj.index):
3485+
r = concat([res] * len(group.columns), axis=1)
3486+
r.columns = group.columns
3487+
r.index = group.index
34823488
else:
3483-
group.values[:] = res
3489+
r = DataFrame(
3490+
np.concatenate([res.values] * len(group.index)
3491+
).reshape(group.shape),
3492+
columns=group.columns, index=group.index)
34843493

3485-
applied.append(group)
3494+
applied.append(r)
34863495
else:
34873496
applied.append(res)
34883497

pandas/tests/test_groupby.py

+12
Original file line numberDiff line numberDiff line change
@@ -1336,6 +1336,18 @@ def nsum(x):
13361336
for result in results:
13371337
assert_series_equal(result, expected, check_names=False)
13381338

1339+
def test_transform_coercion(self):
1340+
1341+
# 14457
1342+
# when we are transforming be sure to not coerce
1343+
# via assignment
1344+
df = pd.DataFrame(dict(A=['a', 'a'], B=[0, 1]))
1345+
g = df.groupby('A')
1346+
1347+
expected = g.transform(np.mean)
1348+
result = g.transform(lambda x: np.mean(x))
1349+
assert_frame_equal(result, expected)
1350+
13391351
def test_with_na(self):
13401352
index = Index(np.arange(10))
13411353

0 commit comments

Comments
 (0)