Skip to content

Commit bbbf48c

Browse files
jeetjitsu3553x
authored and committed
BUG: coercing of bools in groupby transform (#16895)
1 parent b0a4193 commit bbbf48c

File tree

4 files changed

+25
-5
lines changed

4 files changed

+25
-5
lines changed

doc/source/whatsnew/v0.21.0.txt

+1 -1
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ Groupby/Resample/Rolling
175175
- Bug in ``DataFrame.resample(...).size()`` where an empty ``DataFrame`` did not return a ``Series`` (:issue:`14962`)
176176
- Bug in :func:`infer_freq` causing indices with 2-day gaps during the working week to be wrongly inferred as business daily (:issue:`16624`)
177177
- Bug in ``.rolling(...).quantile()`` which incorrectly used different defaults than :func:`Series.quantile()` and :func:`DataFrame.quantile()` (:issue:`9413`, :issue:`16211`)
178-
178+
- Bug in ``groupby.transform()`` that would coerce boolean dtypes back to float (:issue:`16875`)
179179

180180
Sparse
181181
^^^^^^

pandas/core/dtypes/cast.py

+4 -3
Original file line number | Diff line number | Diff line change
@@ -110,9 +110,7 @@ def trans(x): # noqa
110110
np.prod(result.shape)):
111111
return result
112112

113-
if issubclass(dtype.type, np.floating):
114-
return result.astype(dtype)
115-
elif is_bool_dtype(dtype) or is_integer_dtype(dtype):
113+
if is_bool_dtype(dtype) or is_integer_dtype(dtype):
116114

117115
# if we don't have any elements, just astype it
118116
if not np.prod(result.shape):
@@ -144,6 +142,9 @@ def trans(x): # noqa
144142
# hit here
145143
if (new_result == result).all():
146144
return new_result
145+
elif (issubclass(dtype.type, np.floating) and
146+
not is_bool_dtype(result.dtype)):
147+
return result.astype(dtype)
147148

148149
# a datetimelike
149150
# GH12821, iNaT is casted to float

pandas/tests/dtypes/test_cast.py

+7 -1
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,7 @@
99
from datetime import datetime, timedelta, date
1010
import numpy as np
1111

12-
from pandas import Timedelta, Timestamp, DatetimeIndex, DataFrame, NaT
12+
from pandas import Timedelta, Timestamp, DatetimeIndex, DataFrame, NaT, Series
1313

1414
from pandas.core.dtypes.cast import (
1515
maybe_downcast_to_dtype,
@@ -45,6 +45,12 @@ def test_downcast_conv(self):
4545
expected = np.array([8, 8, 8, 8, 9])
4646
assert (np.array_equal(result, expected))
4747

48+
# GH16875 coercing of bools
49+
ser = Series([True, True, False])
50+
result = maybe_downcast_to_dtype(ser, np.dtype(np.float64))
51+
expected = ser
52+
tm.assert_series_equal(result, expected)
53+
4854
# conversions
4955

5056
expected = np.array([1, 2])

pandas/tests/groupby/test_transform.py

+13
Original file line number | Diff line number | Diff line change
@@ -195,6 +195,19 @@ def test_transform_bug(self):
195195
expected = Series(np.arange(5, 0, step=-1), name='B')
196196
assert_series_equal(result, expected)
197197

198+
def test_transform_numeric_to_boolean(self):
199+
# GH 16875
200+
# inconsistency in transforming boolean values
201+
expected = pd.Series([True, True], name='A')
202+
203+
df = pd.DataFrame({'A': [1.1, 2.2], 'B': [1, 2]})
204+
result = df.groupby('B').A.transform(lambda x: True)
205+
assert_series_equal(result, expected)
206+
207+
df = pd.DataFrame({'A': [1, 2], 'B': [1, 2]})
208+
result = df.groupby('B').A.transform(lambda x: True)
209+
assert_series_equal(result, expected)
210+
198211
def test_transform_datetime_to_timedelta(self):
199212
# GH 15429
200213
# transforming a datetime to timedelta

0 commit comments

Comments
 (0)