Skip to content

Commit 8fd4cd3

Browse files
authored
BUG: transform with nunique should have dtype int64 (#35152)
1 parent bc901ab commit 8fd4cd3

File tree

4 files changed

+40
-20
lines changed

4 files changed

+40
-20
lines changed

doc/source/whatsnew/v1.1.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,7 @@ Groupby/resample/rolling
10891089
- Bug in :meth:`DataFrame.groupby` lost index, when one of the ``agg`` keys referenced an empty list (:issue:`32580`)
10901090
- Bug in :meth:`Rolling.apply` where ``center=True`` was ignored when ``engine='numba'`` was specified (:issue:`34784`)
10911091
- Bug in :meth:`DataFrame.ewm.cov` was throwing ``AssertionError`` for :class:`MultiIndex` inputs (:issue:`34440`)
1092+
- Bug in :meth:`core.groupby.DataFrameGroupBy.transform` when ``func='nunique'`` and columns are of type ``datetime64``, the result would also be of type ``datetime64`` instead of ``int64`` (:issue:`35109`)
10921093

10931094
Reshaping
10941095
^^^^^^^^^

pandas/core/common.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
"""
66

77
from collections import abc, defaultdict
8+
import contextlib
89
from datetime import datetime, timedelta
910
from functools import partial
1011
import inspect
11-
from typing import Any, Collection, Iterable, List, Union
12+
from typing import Any, Collection, Iterable, Iterator, List, Union
1213
import warnings
1314

1415
import numpy as np
@@ -502,3 +503,21 @@ def convert_to_list_like(
502503
return list(values)
503504

504505
return [values]
506+
507+
508+
@contextlib.contextmanager
509+
def temp_setattr(obj, attr: str, value) -> Iterator[None]:
510+
"""Temporarily set attribute on an object.
511+
512+
Args:
513+
obj: Object whose attribute will be modified.
514+
attr: Attribute to modify.
515+
value: Value to temporarily set attribute to.
516+
517+
Yields:
518+
obj with modified attribute.
519+
"""
520+
old_value = getattr(obj, attr)
521+
setattr(obj, attr, value)
522+
yield obj
523+
setattr(obj, attr, old_value)

pandas/core/groupby/generic.py

+11-19
Original file line numberDiff line numberDiff line change
@@ -500,8 +500,10 @@ def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
500500
# If func is a reduction, we need to broadcast the
501501
# result to the whole group. Compute func result
502502
# and deal with possible broadcasting below.
503-
result = getattr(self, func)(*args, **kwargs)
504-
return self._transform_fast(result, func)
503+
# Temporarily set observed for dealing with categoricals.
504+
with com.temp_setattr(self, "observed", True):
505+
result = getattr(self, func)(*args, **kwargs)
506+
return self._transform_fast(result)
505507

506508
def _transform_general(
507509
self, func, *args, engine="cython", engine_kwargs=None, **kwargs
@@ -554,17 +556,14 @@ def _transform_general(
554556
result.index = self._selected_obj.index
555557
return result
556558

557-
def _transform_fast(self, result, func_nm: str) -> Series:
559+
def _transform_fast(self, result) -> Series:
558560
"""
559561
fast version of transform, only applicable to
560562
builtin/cythonizable functions
561563
"""
562564
ids, _, ngroup = self.grouper.group_info
563565
result = result.reindex(self.grouper.result_index, copy=False)
564-
cast = self._transform_should_cast(func_nm)
565566
out = algorithms.take_1d(result._values, ids)
566-
if cast:
567-
out = maybe_cast_result(out, self.obj, how=func_nm)
568567
return self.obj._constructor(out, index=self.obj.index, name=self.obj.name)
569568

570569
def filter(self, func, dropna=True, *args, **kwargs):
@@ -1465,25 +1464,23 @@ def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs):
14651464
# If func is a reduction, we need to broadcast the
14661465
# result to the whole group. Compute func result
14671466
# and deal with possible broadcasting below.
1468-
result = getattr(self, func)(*args, **kwargs)
1467+
# Temporarily set observed for dealing with categoricals.
1468+
with com.temp_setattr(self, "observed", True):
1469+
result = getattr(self, func)(*args, **kwargs)
14691470

14701471
if isinstance(result, DataFrame) and result.columns.equals(
14711472
self._obj_with_exclusions.columns
14721473
):
1473-
return self._transform_fast(result, func)
1474+
return self._transform_fast(result)
14741475

14751476
return self._transform_general(
14761477
func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs
14771478
)
14781479

1479-
def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
1480+
def _transform_fast(self, result: DataFrame) -> DataFrame:
14801481
"""
14811482
Fast transform path for aggregations
14821483
"""
1483-
# if there were groups with no observations (Categorical only?)
1484-
# try casting data to original dtype
1485-
cast = self._transform_should_cast(func_nm)
1486-
14871484
obj = self._obj_with_exclusions
14881485

14891486
# for each col, reshape to to size of original frame
@@ -1492,12 +1489,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame:
14921489
result = result.reindex(self.grouper.result_index, copy=False)
14931490
output = []
14941491
for i, _ in enumerate(result.columns):
1495-
res = algorithms.take_1d(result.iloc[:, i].values, ids)
1496-
# TODO: we have no test cases that get here with EA dtypes;
1497-
# maybe_cast_result may not be needed if EAs never get here
1498-
if cast:
1499-
res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm)
1500-
output.append(res)
1492+
output.append(algorithms.take_1d(result.iloc[:, i].values, ids))
15011493

15021494
return self.obj._constructor._from_arrays(
15031495
output, columns=result.columns, index=obj.index

pandas/tests/groupby/test_nunique.py

+8
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,11 @@ def test_nunique_preserves_column_level_names():
167167
result = test.groupby([0, 0, 0]).nunique()
168168
expected = pd.DataFrame([2], columns=test.columns)
169169
tm.assert_frame_equal(result, expected)
170+
171+
172+
def test_nunique_transform_with_datetime():
173+
# GH 35109 - transform with nunique on datetimes results in integers
174+
df = pd.DataFrame(date_range("2008-12-31", "2009-01-02"), columns=["date"])
175+
result = df.groupby([0, 0, 1])["date"].transform("nunique")
176+
expected = pd.Series([2, 2, 1], name="date")
177+
tm.assert_series_equal(result, expected)

0 commit comments

Comments
 (0)