Skip to content

Commit 6ae710a

Browse files
authored
BUG: Don't cast nullable Boolean to float in groupby (#33089)
1 parent d1b1236 commit 6ae710a

File tree

4 files changed

+39
-9
lines changed

4 files changed

+39
-9
lines changed

doc/source/whatsnew/v1.1.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,8 @@ Groupby/resample/rolling
445445
- Bug in :meth:`DataFrameGroupby.transform` produces incorrect result with transformation functions (:issue:`30918`)
446446
- Bug in :meth:`GroupBy.count` causes segmentation fault when grouped-by column contains NaNs (:issue:`32841`)
447447
- Bug in :meth:`DataFrame.groupby` and :meth:`Series.groupby` produces inconsistent type when aggregating Boolean series (:issue:`32894`)
448+
- Bug in :meth:`SeriesGroupBy.first`, :meth:`SeriesGroupBy.last`, :meth:`SeriesGroupBy.min`, and :meth:`SeriesGroupBy.max` returning floats when applied to nullable Booleans (:issue:`33071`)
449+
- Bug in :meth:`DataFrameGroupBy.agg` with dictionary input losing ``ExtensionArray`` dtypes (:issue:`32194`)
448450
- Bug in :meth:`DataFrame.resample` where an ``AmbiguousTimeError`` would be raised when the resulting timezone aware :class:`DatetimeIndex` had a DST transition at midnight (:issue:`25758`)
449451

450452
Reshaping

pandas/core/dtypes/cast.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
ensure_str,
3434
is_bool,
3535
is_bool_dtype,
36+
is_categorical_dtype,
3637
is_complex,
3738
is_complex_dtype,
3839
is_datetime64_dtype,
@@ -279,12 +280,15 @@ def maybe_cast_result(result, obj: "Series", numeric_only: bool = False, how: st
279280
dtype = maybe_cast_result_dtype(dtype, how)
280281

281282
if not is_scalar(result):
282-
if is_extension_array_dtype(dtype) and dtype.kind != "M":
283-
# The result may be of any type, cast back to original
284-
# type if it's compatible.
285-
if len(result) and isinstance(result[0], dtype.type):
286-
cls = dtype.construct_array_type()
287-
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
283+
if (
284+
is_extension_array_dtype(dtype)
285+
and not is_categorical_dtype(dtype)
286+
and dtype.kind != "M"
287+
):
288+
# We have to special case categorical so as not to upcast
289+
# things like counts back to categorical
290+
cls = dtype.construct_array_type()
291+
result = maybe_cast_to_extension_array(cls, result, dtype=dtype)
288292

289293
elif numeric_only and is_numeric_dtype(dtype) or not numeric_only:
290294
result = maybe_downcast_to_dtype(result, dtype)

pandas/tests/groupby/test_nth.py

+26
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,32 @@ def test_first_last_tz_multi_column(method, ts, alpha):
384384
tm.assert_frame_equal(result, expected)
385385

386386

387+
@pytest.mark.parametrize(
388+
"values",
389+
[
390+
pd.array([True, False], dtype="boolean"),
391+
pd.array([1, 2], dtype="Int64"),
392+
pd.to_datetime(["2020-01-01", "2020-02-01"]),
393+
pd.to_timedelta([1, 2], unit="D"),
394+
],
395+
)
396+
@pytest.mark.parametrize("function", ["first", "last", "min", "max"])
397+
def test_first_last_extension_array_keeps_dtype(values, function):
398+
# https://github.com/pandas-dev/pandas/issues/33071
399+
# https://github.com/pandas-dev/pandas/issues/32194
400+
df = DataFrame({"a": [1, 2], "b": values})
401+
grouped = df.groupby("a")
402+
idx = Index([1, 2], name="a")
403+
expected_series = Series(values, name="b", index=idx)
404+
expected_frame = DataFrame({"b": values}, index=idx)
405+
406+
result_series = getattr(grouped["b"], function)()
407+
tm.assert_series_equal(result_series, expected_series)
408+
409+
result_frame = grouped.agg({"b": function})
410+
tm.assert_frame_equal(result_frame, expected_frame)
411+
412+
387413
def test_nth_multi_index_as_expected():
388414
# PR 9090, related to issue 8979
389415
# test nth on MultiIndex

pandas/tests/resample/test_datetime_index.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,7 @@ def test_resample_integerarray():
122122

123123
result = ts.resample("3T").mean()
124124
expected = Series(
125-
[1, 4, 7],
126-
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
127-
dtype="float64",
125+
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64",
128126
)
129127
tm.assert_series_equal(result, expected)
130128

0 commit comments

Comments
 (0)