Commit db29fda

jeetjitsu authored and jreback committed
BUG: GH16875
Fix inconsistency in groupby transformations
1 parent 3524edb commit db29fda

3 files changed: 15 additions and 2 deletions

doc/source/whatsnew/v0.21.0.txt

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ Groupby/Resample/Rolling
 - Bug in ``DataFrame.resample(...).size()`` where an empty ``DataFrame`` did not return a ``Series`` (:issue:`14962`)
 - Bug in :func:`infer_freq` causing indices with 2-day gaps during the working week to be wrongly inferred as business daily (:issue:`16624`)
 - Bug in ``.rolling(...).quantile()`` which incorrectly used different defaults than :func:`Series.quantile()` and :func:`DataFrame.quantile()` (:issue:`9413`, :issue:`16211`)
-
+- Bug in ``groupby.transform()`` that would coerce boolean dtypes back to float (:issue:`16875`)
 
 
 Sparse
 ^^^^^^

pandas/core/groupby.py

Lines changed: 1 addition & 1 deletion
@@ -3055,7 +3055,7 @@ def transform(self, func, *args, **kwargs):
         # we have a numeric dtype, as these are *always* udfs
         # the cython take a different path (and casting)
         dtype = self._selected_obj.dtype
-        if is_numeric_dtype(dtype):
+        if is_numeric_dtype(dtype) and not is_bool_dtype(result.dtype):
             result = maybe_downcast_to_dtype(result, dtype)
 
         result.name = self._selected_obj.name
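
The effect of the changed condition can be illustrated outside pandas internals. The following is a minimal sketch, not the library's implementation: `should_downcast` is a hypothetical helper that only mirrors the patched `if` test, using the public `pandas.api.types` predicates. It assumes the selected column is `float64` and the user function returned booleans, as in the accompanying test.

```python
import numpy as np
from pandas.api.types import is_bool_dtype, is_numeric_dtype


def should_downcast(selected_dtype, result_dtype):
    """Hypothetical helper mirroring the patched condition.

    Downcast the transform result to the selected column's dtype only when
    that dtype is numeric and the result itself is not boolean.
    """
    return bool(is_numeric_dtype(selected_dtype) and not is_bool_dtype(result_dtype))


float_dtype = np.dtype("float64")  # dtype of the selected column (assumed)
bool_dtype = np.dtype("bool")      # dtype of the transform result (assumed)

# Before the patch only the selected dtype was checked, so a boolean
# result would have been passed on to the downcast step.
old_condition = bool(is_numeric_dtype(float_dtype))

# After the patch a boolean result skips the downcast and keeps its dtype.
new_condition = should_downcast(float_dtype, bool_dtype)

# A non-boolean result on a numeric column is still eligible for downcasting.
numeric_case = should_downcast(np.dtype("int64"), np.dtype("float64"))
```

Under these assumptions `old_condition` is true while `new_condition` is false, which is why a boolean result from a user-defined function is no longer coerced back to the column's numeric dtype.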

pandas/tests/groupby/test_transform.py

Lines changed: 13 additions & 0 deletions
@@ -195,6 +195,19 @@ def test_transform_bug(self):
         expected = Series(np.arange(5, 0, step=-1), name='B')
         assert_series_equal(result, expected)
 
+    def test_transform_numeric_to_boolean(self):
+        # GH 16875
+        # inconsistency in transforming boolean values
+        expected = pd.Series([True, True], name='A')
+
+        df = pd.DataFrame({'A': [1.1, 2.2], 'B': [1, 2]})
+        result = df.groupby('B').A.transform(lambda x: True)
+        assert_series_equal(result, expected)
+
+        df = pd.DataFrame({'A': [1, 2], 'B': [1, 2]})
+        result = df.groupby('B').A.transform(lambda x: True)
+        assert_series_equal(result, expected)
+
     def test_transform_datetime_to_timedelta(self):
         # GH 15429
         # transforming a datetime to timedelta
