Skip to content

Commit 8d543ba

Browse files
BUG/PERF: groupby.transform with unobserved categories (#58084)
1 parent ca55d77 commit 8d543ba

File tree

5 files changed

+120
-19
lines changed

5 files changed

+120
-19
lines changed

doc/source/whatsnew/v3.0.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,7 @@ Groupby/resample/rolling
457457
- Bug in :meth:`DataFrame.ewm` and :meth:`Series.ewm` when passed ``times`` and aggregation functions other than mean (:issue:`51695`)
458458
- Bug in :meth:`DataFrameGroupBy.apply` that was returning a completely empty DataFrame when all return values of ``func`` were ``None`` instead of returning an empty DataFrame with the original columns and dtypes. (:issue:`57775`)
459459
- Bug in :meth:`DataFrameGroupBy.apply` with ``as_index=False`` that was returning :class:`MultiIndex` instead of returning :class:`Index`. (:issue:`58291`)
460+
- Bug in :meth:`DataFrameGroupBy.transform` and :meth:`SeriesGroupBy.transform` with a reducer and ``observed=False`` that coerces dtype to float when there are unobserved categories. (:issue:`55326`)
460461

461462

462463
Reshaping

pandas/core/groupby/groupby.py

+32-16
Original file line numberDiff line numberDiff line change
@@ -1875,24 +1875,40 @@ def _transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
18751875

18761876
else:
18771877
# i.e. func in base.reduction_kernels
1878+
if self.observed:
1879+
return self._reduction_kernel_transform(
1880+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1881+
)
18781882

1879-
# GH#30918 Use _transform_fast only when we know func is an aggregation
1880-
# If func is a reduction, we need to broadcast the
1881-
# result to the whole group. Compute func result
1882-
# and deal with possible broadcasting below.
1883-
with com.temp_setattr(self, "as_index", True):
1884-
# GH#49834 - result needs groups in the index for
1885-
# _wrap_transform_fast_result
1886-
if func in ["idxmin", "idxmax"]:
1887-
func = cast(Literal["idxmin", "idxmax"], func)
1888-
result = self._idxmax_idxmin(func, True, *args, **kwargs)
1889-
else:
1890-
if engine is not None:
1891-
kwargs["engine"] = engine
1892-
kwargs["engine_kwargs"] = engine_kwargs
1893-
result = getattr(self, func)(*args, **kwargs)
1883+
with (
1884+
com.temp_setattr(self, "observed", True),
1885+
com.temp_setattr(self, "_grouper", self._grouper.observed_grouper),
1886+
):
1887+
return self._reduction_kernel_transform(
1888+
func, *args, engine=engine, engine_kwargs=engine_kwargs, **kwargs
1889+
)
1890+
1891+
@final
def _reduction_kernel_transform(
    self, func, *args, engine=None, engine_kwargs=None, **kwargs
):
    """
    Apply the reduction kernel named by ``func`` group-wise and broadcast
    the per-group result back over the original rows.

    Parameters
    ----------
    func : str
        Name of a reduction kernel (e.g. ``"sum"``, ``"min"``, ``"idxmax"``);
        dispatched via ``getattr(self, func)`` except for idxmin/idxmax.
    *args, **kwargs
        Forwarded to the reduction.
    engine, engine_kwargs
        Optional numba engine options, forwarded via ``kwargs`` when set.

    Returns
    -------
    Series or DataFrame with the reduced values broadcast to the shape of
    the original object (via ``_wrap_transform_fast_result``).
    """
    # GH#30918 Use _transform_fast only when we know func is an aggregation
    # If func is a reduction, we need to broadcast the
    # result to the whole group. Compute func result
    # and deal with possible broadcasting below.
    with com.temp_setattr(self, "as_index", True):
        # GH#49834 - result needs groups in the index for
        # _wrap_transform_fast_result
        if func in ["idxmin", "idxmax"]:
            # idxmin/idxmax have a dedicated path (second arg: ignore_unobserved)
            func = cast(Literal["idxmin", "idxmax"], func)
            result = self._idxmax_idxmin(func, True, *args, **kwargs)
        else:
            if engine is not None:
                # numba engine options travel through kwargs to the reducer
                kwargs["engine"] = engine
                kwargs["engine_kwargs"] = engine_kwargs
            result = getattr(self, func)(*args, **kwargs)

    return self._wrap_transform_fast_result(result)
18961912

18971913
@final
18981914
def _wrap_transform_fast_result(self, result: NDFrameT) -> NDFrameT:

pandas/core/groupby/grouper.py

+22
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,28 @@ def groups(self) -> dict[Hashable, Index]:
668668
cats = Categorical.from_codes(codes, uniques, validate=False)
669669
return self._index.groupby(cats)
670670

671+
@property
def observed_grouping(self) -> Grouping:
    """Return this Grouping with ``observed=True`` semantics."""
    # Already observed: nothing to convert, reuse self.
    return self if self._observed else self._observed_grouping
677+
678+
@cache_readonly
def _observed_grouping(self) -> Grouping:
    """Build (and cache) an ``observed=True`` twin of this Grouping."""
    return Grouping(
        self._index,
        self._orig_grouper,
        obj=self.obj,
        level=self.level,
        sort=self._sort,
        observed=True,
        in_axis=self.in_axis,
        dropna=self._dropna,
        uniques=self._uniques,
    )
692+
671693

672694
def get_grouper(
673695
obj: NDFrameT,

pandas/core/groupby/ops.py

+17
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,19 @@ def result_index_and_ids(self) -> tuple[Index, npt.NDArray[np.intp]]:
823823

824824
return result_index, ids
825825

826+
@property
def observed_grouper(self) -> BaseGrouper:
    """Return a BaseGrouper in which every grouping is observed."""
    if any(not ping._observed for ping in self.groupings):
        return self._observed_grouper
    # All groupings are already observed; no new grouper needed.
    return self
832+
833+
@cache_readonly
def _observed_grouper(self) -> BaseGrouper:
    """Build (and cache) the observed counterpart of this grouper."""
    observed_groupings = [ping.observed_grouping for ping in self.groupings]
    return BaseGrouper(
        self.axis, observed_groupings, sort=self._sort, dropna=self.dropna
    )
838+
826839
def _ob_index_and_ids(
827840
self,
828841
levels: list[Index],
@@ -1154,6 +1167,10 @@ def groupings(self) -> list[grouper.Grouping]:
11541167
)
11551168
return [ping]
11561169

1170+
@property
def observed_grouper(self) -> BinGrouper:
    # A BinGrouper groups by bins, not categories, so there are no
    # unobserved categories to materialize: it is its own observed variant.
    return self
1173+
11571174

11581175
def _is_indexed_like(obj, axes) -> bool:
11591176
if isinstance(obj, Series):

pandas/tests/groupby/transform/test_transform.py

+48-3
Original file line numberDiff line numberDiff line change
@@ -1232,9 +1232,9 @@ def test_categorical_and_not_categorical_key(observed):
12321232
tm.assert_frame_equal(result, expected_explicit)
12331233

12341234
# Series case
1235-
result = df_with_categorical.groupby(["A", "C"], observed=observed)["B"].transform(
1236-
"sum"
1237-
)
1235+
gb = df_with_categorical.groupby(["A", "C"], observed=observed)
1236+
gbp = gb["B"]
1237+
result = gbp.transform("sum")
12381238
expected = df_without_categorical.groupby(["A", "C"])["B"].transform("sum")
12391239
tm.assert_series_equal(result, expected)
12401240
expected_explicit = Series([4, 2, 4], name="B")
@@ -1535,3 +1535,48 @@ def test_transform_sum_one_column_with_matching_labels_and_missing_labels():
15351535
result = df.groupby(series, as_index=False).transform("sum")
15361536
expected = DataFrame({"X": [-93203.0, -93203.0, np.nan]})
15371537
tm.assert_frame_equal(result, expected)
1538+
1539+
1540+
@pytest.mark.parametrize("dtype", ["int32", "float32"])
def test_min_one_unobserved_category_no_type_coercion(dtype):
    # GH#58084 - a reduction transform must keep the column dtype even when
    # one category of the grouping key is unobserved
    df = DataFrame(
        {
            "A": Categorical([1, 1, 2], categories=[1, 2, 3]),
            "B": np.array([3, 4, 5], dtype=dtype),
        }
    )
    result = df.groupby("A", observed=False).transform("min")

    expected = DataFrame({"B": np.array([3, 3, 5], dtype=dtype)})
    tm.assert_frame_equal(expected, result)
1550+
1551+
1552+
def test_min_all_empty_data_no_type_coercion():
1553+
# GH#58084
1554+
df = DataFrame(
1555+
{
1556+
"X": Categorical(
1557+
[],
1558+
categories=[1, "randomcat", 100],
1559+
),
1560+
"Y": [],
1561+
}
1562+
)
1563+
df["Y"] = df["Y"].astype("int32")
1564+
1565+
gb = df.groupby("X", observed=False)
1566+
result = gb.transform("min")
1567+
1568+
expected = DataFrame({"Y": []}, dtype="int32")
1569+
tm.assert_frame_equal(expected, result)
1570+
1571+
1572+
def test_min_one_dim_no_type_coercion():
1573+
# GH#58084
1574+
df = DataFrame({"Y": [9435, -5465765, 5055, 0, 954960]})
1575+
df["Y"] = df["Y"].astype("int32")
1576+
categories = Categorical([1, 2, 2, 5, 1], categories=[1, 2, 3, 4, 5])
1577+
1578+
gb = df.groupby(categories, observed=False)
1579+
result = gb.transform("min")
1580+
1581+
expected = DataFrame({"Y": [9435, -5465765, -5465765, 0, 9435]}, dtype="int32")
1582+
tm.assert_frame_equal(expected, result)

0 commit comments

Comments
 (0)