diff --git a/doc/source/whatsnew/v1.1.0.rst b/doc/source/whatsnew/v1.1.0.rst index 9bd4ddbb624d9..0e929ff062cff 100644 --- a/doc/source/whatsnew/v1.1.0.rst +++ b/doc/source/whatsnew/v1.1.0.rst @@ -1080,6 +1080,7 @@ Groupby/resample/rolling - Bug in :meth:`DataFrame.groupby` lost index, when one of the ``agg`` keys referenced an empty list (:issue:`32580`) - Bug in :meth:`Rolling.apply` where ``center=True`` was ignored when ``engine='numba'`` was specified (:issue:`34784`) - Bug in :meth:`DataFrame.ewm.cov` was throwing ``AssertionError`` for :class:`MultiIndex` inputs (:issue:`34440`) +- Bug in :meth:`core.groupby.DataFrameGroupBy.transform` when ``func='nunique'`` and columns are of type ``datetime64``, the result would also be of type ``datetime64`` instead of ``int64`` (:issue:`35109`) Reshaping ^^^^^^^^^ diff --git a/pandas/core/common.py b/pandas/core/common.py index b4f726f4e59a9..e7260a9923ee0 100644 --- a/pandas/core/common.py +++ b/pandas/core/common.py @@ -5,10 +5,11 @@ """ from collections import abc, defaultdict +import contextlib from datetime import datetime, timedelta from functools import partial import inspect -from typing import Any, Collection, Iterable, List, Union +from typing import Any, Collection, Iterable, Iterator, List, Union import warnings import numpy as np @@ -502,3 +503,21 @@ def convert_to_list_like( return list(values) return [values] + + +@contextlib.contextmanager +def temp_setattr(obj, attr: str, value) -> Iterator[None]: + """Temporarily set attribute on an object. + + Args: + obj: Object whose attribute will be modified. + attr: Attribute to modify. + value: Value to temporarily set attribute to. + + Yields: + obj with modified attribute. + """ + old_value = getattr(obj, attr) + setattr(obj, attr, value) + yield obj + setattr(obj, attr, old_value) diff --git a/pandas/core/groupby/generic.py b/pandas/core/groupby/generic.py index 6f956a3dcc9b6..61a739a009cf8 100644 --- a/pandas/core/groupby/generic.py +++ b/pandas/core/groupby/generic.py @@ -485,8 +485,10 @@ def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): # If func is a reduction, we need to broadcast the # result to the whole group. Compute func result # and deal with possible broadcasting below. - result = getattr(self, func)(*args, **kwargs) - return self._transform_fast(result, func) + # Temporarily set observed for dealing with categoricals. + with com.temp_setattr(self, "observed", True): + result = getattr(self, func)(*args, **kwargs) + return self._transform_fast(result) def _transform_general( self, func, *args, engine="cython", engine_kwargs=None, **kwargs @@ -539,17 +541,14 @@ def _transform_general( result.index = self._selected_obj.index return result - def _transform_fast(self, result, func_nm: str) -> Series: + def _transform_fast(self, result) -> Series: """ fast version of transform, only applicable to builtin/cythonizable functions """ ids, _, ngroup = self.grouper.group_info result = result.reindex(self.grouper.result_index, copy=False) - cast = self._transform_should_cast(func_nm) out = algorithms.take_1d(result._values, ids) - if cast: - out = maybe_cast_result(out, self.obj, how=func_nm) return self.obj._constructor(out, index=self.obj.index, name=self.obj.name) def filter(self, func, dropna=True, *args, **kwargs): @@ -1467,25 +1466,23 @@ def transform(self, func, *args, engine="cython", engine_kwargs=None, **kwargs): # If func is a reduction, we need to broadcast the # result to the whole group. Compute func result # and deal with possible broadcasting below. - result = getattr(self, func)(*args, **kwargs) + # Temporarily set observed for dealing with categoricals. + with com.temp_setattr(self, "observed", True): + result = getattr(self, func)(*args, **kwargs) if isinstance(result, DataFrame) and result.columns.equals( self._obj_with_exclusions.columns ): - return self._transform_fast(result, func) + return self._transform_fast(result) return self._transform_general( func, engine=engine, engine_kwargs=engine_kwargs, *args, **kwargs ) - def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: + def _transform_fast(self, result: DataFrame) -> DataFrame: """ Fast transform path for aggregations """ - # if there were groups with no observations (Categorical only?) - # try casting data to original dtype - cast = self._transform_should_cast(func_nm) - obj = self._obj_with_exclusions # for each col, reshape to to size of original frame @@ -1494,12 +1491,7 @@ def _transform_fast(self, result: DataFrame, func_nm: str) -> DataFrame: result = result.reindex(self.grouper.result_index, copy=False) output = [] for i, _ in enumerate(result.columns): - res = algorithms.take_1d(result.iloc[:, i].values, ids) - # TODO: we have no test cases that get here with EA dtypes; - # maybe_cast_result may not be needed if EAs never get here - if cast: - res = maybe_cast_result(res, obj.iloc[:, i], how=func_nm) - output.append(res) + output.append(algorithms.take_1d(result.iloc[:, i].values, ids)) return self.obj._constructor._from_arrays( output, columns=result.columns, index=obj.index diff --git a/pandas/tests/groupby/test_nunique.py b/pandas/tests/groupby/test_nunique.py index 1475b1ce2907c..c3347b7ae52f3 100644 --- a/pandas/tests/groupby/test_nunique.py +++ b/pandas/tests/groupby/test_nunique.py @@ -167,3 +167,11 @@ def test_nunique_preserves_column_level_names(): result = test.groupby([0, 0, 0]).nunique() expected = pd.DataFrame([2], columns=test.columns) tm.assert_frame_equal(result, expected) + + +def test_nunique_transform_with_datetime(): + # GH 35109 - transform with nunique on datetimes results in integers + df = pd.DataFrame(date_range("2008-12-31", "2009-01-02"), columns=["date"]) + result = df.groupby([0, 0, 1])["date"].transform("nunique") + expected = pd.Series([2, 2, 1], name="date") + tm.assert_series_equal(result, expected)