Skip to content

Commit 7833fdf

Browse files
authored
BUG: groupby.agg/transform casts UDF results (#40790)
1 parent a0c7028 commit 7833fdf

19 files changed

+221
-57
lines changed

doc/source/user_guide/gotchas.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ To test for membership in the values, use the method :meth:`~pandas.Series.isin`
178178
For ``DataFrames``, likewise, ``in`` applies to the column axis,
179179
testing for membership in the list of column names.
180180

181-
.. _udf-mutation:
181+
.. _gotchas.udf-mutation:
182182

183183
Mutating with User Defined Function (UDF) methods
184184
-------------------------------------------------

doc/source/user_guide/groupby.rst

+29-2
Original file line numberDiff line numberDiff line change
@@ -739,6 +739,26 @@ optimized Cython implementations:
739739
Of course ``sum`` and ``mean`` are implemented on pandas objects, so the above
740740
code would work even without the special versions via dispatching (see below).
741741

742+
.. _groupby.aggregate.udfs:
743+
744+
Aggregations with User-Defined Functions
745+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
746+
747+
Users can also provide their own functions for custom aggregations. When aggregating
748+
with a User-Defined Function (UDF), the UDF should not mutate the provided ``Series``, see
749+
:ref:`gotchas.udf-mutation` for more information.
750+
751+
.. ipython:: python
752+
753+
animals.groupby("kind")[["height"]].agg(lambda x: set(x))
754+
755+
The resulting dtype will reflect that of the aggregating function. If the results from different groups have
756+
different dtypes, then a common dtype will be determined in the same way as ``DataFrame`` construction.
757+
758+
.. ipython:: python
759+
760+
animals.groupby("kind")[["height"]].agg(lambda x: x.astype(int).sum())
761+
742762
.. _groupby.transform:
743763

744764
Transformation
@@ -759,7 +779,11 @@ as the one being grouped. The transform function must:
759779
* (Optionally) operates on the entire group chunk. If this is supported, a
760780
fast path is used starting from the *second* chunk.
761781

762-
For example, suppose we wished to standardize the data within each group:
782+
Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
783+
transformation function. If the results from different groups have different dtypes, then
784+
a common dtype will be determined in the same way as ``DataFrame`` construction.
785+
786+
Suppose we wished to standardize the data within each group:
763787

764788
.. ipython:: python
765789
@@ -1065,13 +1089,16 @@ that is itself a series, and possibly upcast the result to a DataFrame:
10651089
s
10661090
s.apply(f)
10671091
1068-
10691092
.. note::
10701093

10711094
``apply`` can act as a reducer, transformer, *or* filter function, depending on exactly what is passed to it.
10721095
So depending on the path taken, and exactly what you are grouping. Thus the grouped columns(s) may be included in
10731096
the output as well as set the indices.
10741097

1098+
Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
1099+
apply function. If the results from different groups have different dtypes, then
1100+
a common dtype will be determined in the same way as ``DataFrame`` construction.
1101+
10751102

10761103
Numba Accelerated Routines
10771104
--------------------------

doc/source/whatsnew/v1.3.0.rst

+30
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,36 @@ Preserve dtypes in :meth:`~pandas.DataFrame.combine_first`
298298
299299
combined.dtypes
300300
301+
Group by methods agg and transform no longer changes return dtype for callables
302+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
303+
304+
Previously the methods :meth:`.DataFrameGroupBy.aggregate`,
305+
:meth:`.SeriesGroupBy.aggregate`, :meth:`.DataFrameGroupBy.transform`, and
306+
:meth:`.SeriesGroupBy.transform` might cast the result dtype when the argument ``func``
307+
is callable, possibly leading to undesirable results (:issue:`21240`). The cast would
308+
occur if the result is numeric and casting back to the input dtype does not change any
309+
values as measured by ``np.allclose``. Now no such casting occurs.
310+
311+
.. ipython:: python
312+
313+
df = pd.DataFrame({'key': [1, 1], 'a': [True, False], 'b': [True, True]})
314+
df
315+
316+
*pandas 1.2.x*
317+
318+
.. code-block:: ipython
319+
320+
In [5]: df.groupby('key').agg(lambda x: x.sum())
321+
Out[5]:
322+
a b
323+
key
324+
1 True 2
325+
326+
*pandas 1.3.0*
327+
328+
.. ipython:: python
329+
330+
df.groupby('key').agg(lambda x: x.sum())
301331
302332
Try operating inplace when setting values with ``loc`` and ``iloc``
303333
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

pandas/core/frame.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8552,7 +8552,7 @@ def apply(
85528552
Notes
85538553
-----
85548554
Functions that mutate the passed object can produce unexpected
8555-
behavior or errors and are not supported. See :ref:`udf-mutation`
8555+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
85568556
for more details.
85578557
85588558
Examples

pandas/core/groupby/generic.py

+25-16
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@
4444
doc,
4545
)
4646

47-
from pandas.core.dtypes.cast import (
48-
find_common_type,
49-
maybe_downcast_numeric,
50-
)
5147
from pandas.core.dtypes.common import (
5248
ensure_int64,
5349
is_bool,
@@ -226,7 +222,16 @@ def _selection_name(self):
226222
... )
227223
minimum maximum
228224
1 1 2
229-
2 3 4"""
225+
2 3 4
226+
227+
.. versionchanged:: 1.3.0
228+
229+
The resulting dtype will reflect the return value of the aggregating function.
230+
231+
>>> s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())
232+
1 1.0
233+
2 3.0
234+
dtype: float64"""
230235
)
231236

232237
@Appender(
@@ -566,8 +571,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
566571

567572
def _transform_general(self, func, *args, **kwargs):
568573
"""
569-
Transform with a non-str `func`.
574+
Transform with a callable func`.
570575
"""
576+
assert callable(func)
571577
klass = type(self._selected_obj)
572578

573579
results = []
@@ -589,13 +595,6 @@ def _transform_general(self, func, *args, **kwargs):
589595
result = self._set_result_index_ordered(concatenated)
590596
else:
591597
result = self.obj._constructor(dtype=np.float64)
592-
# we will only try to coerce the result type if
593-
# we have a numeric dtype, as these are *always* user-defined funcs
594-
# the cython take a different path (and casting)
595-
if is_numeric_dtype(result.dtype):
596-
common_dtype = find_common_type([self._selected_obj.dtype, result.dtype])
597-
if common_dtype is result.dtype:
598-
result = maybe_downcast_numeric(result, self._selected_obj.dtype)
599598

600599
result.name = self._selected_obj.name
601600
return result
@@ -625,7 +624,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
625624
Notes
626625
-----
627626
Functions that mutate the passed object can produce unexpected
628-
behavior or errors and are not supported. See :ref:`udf-mutation`
627+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
629628
for more details.
630629
631630
Examples
@@ -1006,7 +1005,17 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
10061005
``['column', 'aggfunc']`` to make it clearer what the arguments are.
10071006
As usual, the aggregation can be a callable or a string alias.
10081007
1009-
See :ref:`groupby.aggregate.named` for more."""
1008+
See :ref:`groupby.aggregate.named` for more.
1009+
1010+
.. versionchanged:: 1.3.0
1011+
1012+
The resulting dtype will reflect the return value of the aggregating function.
1013+
1014+
>>> df.groupby("A")[["B"]].agg(lambda x: x.astype(float).min())
1015+
B
1016+
A
1017+
1 1.0
1018+
2 3.0"""
10101019
)
10111020

10121021
@doc(_agg_template, examples=_agg_examples_doc, klass="DataFrame")
@@ -1533,7 +1542,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
15331542
which group you are working on.
15341543
15351544
Functions that mutate the passed object can produce unexpected
1536-
behavior or errors and are not supported. See :ref:`udf-mutation`
1545+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
15371546
for more details.
15381547
15391548
Examples

pandas/core/groupby/groupby.py

+50-17
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,19 @@ class providing the base-class of operations.
158158
side-effects, as they will take effect twice for the first
159159
group.
160160
161+
.. versionchanged:: 1.3.0
162+
163+
The resulting dtype will reflect the return value of the passed ``func``,
164+
see the examples below.
165+
161166
Examples
162167
--------
163168
{examples}
164169
""",
165170
"dataframe_examples": """
166171
>>> df = pd.DataFrame({'A': 'a a b'.split(),
167172
... 'B': [1,2,3],
168-
... 'C': [4,6, 5]})
173+
... 'C': [4,6,5]})
169174
>>> g = df.groupby('A')
170175
171176
Notice that ``g`` has two groups, ``a`` and ``b``.
@@ -183,13 +188,17 @@ class providing the base-class of operations.
183188
184189
Example 2: The function passed to `apply` takes a DataFrame as
185190
its argument and returns a Series. `apply` combines the result for
186-
each group together into a new DataFrame:
191+
each group together into a new DataFrame.
192+
193+
.. versionchanged:: 1.3.0
187194
188-
>>> g[['B', 'C']].apply(lambda x: x.max() - x.min())
189-
B C
195+
The resulting dtype will reflect the return value of the passed ``func``.
196+
197+
>>> g[['B', 'C']].apply(lambda x: x.astype(float).max() - x.min())
198+
B C
190199
A
191-
a 1 2
192-
b 0 0
200+
a 1.0 2.0
201+
b 0.0 0.0
193202
194203
Example 3: The function passed to `apply` takes a DataFrame as
195204
its argument and returns a scalar. `apply` combines the result for
@@ -210,12 +219,16 @@ class providing the base-class of operations.
210219
211220
Example 1: The function passed to `apply` takes a Series as
212221
its argument and returns a Series. `apply` combines the result for
213-
each group together into a new Series:
222+
each group together into a new Series.
223+
224+
.. versionchanged:: 1.3.0
214225
215-
>>> g.apply(lambda x: x*2 if x.name == 'b' else x/2)
226+
The resulting dtype will reflect the return value of the passed ``func``.
227+
228+
>>> g.apply(lambda x: x*2 if x.name == 'a' else x/2)
216229
a 0.0
217-
a 0.5
218-
b 4.0
230+
a 2.0
231+
b 1.0
219232
dtype: float64
220233
221234
Example 2: The function passed to `apply` takes a Series as
@@ -367,12 +380,17 @@ class providing the base-class of operations.
367380
in the subframe. If f also supports application to the entire subframe,
368381
then a fast path is used starting from the second chunk.
369382
* f must not mutate groups. Mutation is not supported and may
370-
produce unexpected results. See :ref:`udf-mutation` for more details.
383+
produce unexpected results. See :ref:`gotchas.udf-mutation` for more details.
371384
372385
When using ``engine='numba'``, there will be no "fall back" behavior internally.
373386
The group data and group index will be passed as numpy arrays to the JITed
374387
user defined function, and no alternative execution attempts will be tried.
375388
389+
.. versionchanged:: 1.3.0
390+
391+
The resulting dtype will reflect the return value of the passed ``func``,
392+
see the examples below.
393+
376394
Examples
377395
--------
378396
@@ -402,6 +420,20 @@ class providing the base-class of operations.
402420
3 3 8.0
403421
4 4 6.0
404422
5 3 8.0
423+
424+
.. versionchanged:: 1.3.0
425+
426+
The resulting dtype will reflect the return value of the passed ``func``,
427+
for example:
428+
429+
>>> grouped[['C', 'D']].transform(lambda x: x.astype(int).max())
430+
C D
431+
0 5 8
432+
1 5 9
433+
2 5 8
434+
3 5 9
435+
4 5 8
436+
5 5 9
405437
"""
406438

407439
_agg_template = """
@@ -469,12 +501,16 @@ class providing the base-class of operations.
469501
When using ``engine='numba'``, there will be no "fall back" behavior internally.
470502
The group data and group index will be passed as numpy arrays to the JITed
471503
user defined function, and no alternative execution attempts will be tried.
472-
{examples}
473504
474505
Functions that mutate the passed object can produce unexpected
475-
behavior or errors and are not supported. See :ref:`udf-mutation`
506+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
476507
for more details.
477-
"""
508+
509+
.. versionchanged:: 1.3.0
510+
511+
The resulting dtype will reflect the return value of the passed ``func``,
512+
see the examples below.
513+
{examples}"""
478514

479515

480516
@final
@@ -1232,9 +1268,6 @@ def _python_agg_general(self, func, *args, **kwargs):
12321268
assert result is not None
12331269
key = base.OutputKey(label=name, position=idx)
12341270

1235-
if is_numeric_dtype(obj.dtype):
1236-
result = maybe_downcast_numeric(result, obj.dtype)
1237-
12381271
if self.grouper._filter_empty_groups:
12391272
mask = counts.ravel() > 0
12401273

pandas/core/series.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4190,7 +4190,7 @@ def apply(
41904190
Notes
41914191
-----
41924192
Functions that mutate the passed object can produce unexpected
4193-
behavior or errors and are not supported. See :ref:`udf-mutation`
4193+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
41944194
for more details.
41954195
41964196
Examples

pandas/core/shared_docs.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
`agg` is an alias for `aggregate`. Use the alias.
4343
4444
Functions that mutate the passed object can produce unexpected
45-
behavior or errors and are not supported. See :ref:`udf-mutation`
45+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
4646
for more details.
4747
4848
A passed user-defined-function will be passed a Series for evaluation.
@@ -303,7 +303,7 @@
303303
Notes
304304
-----
305305
Functions that mutate the passed object can produce unexpected
306-
behavior or errors and are not supported. See :ref:`udf-mutation`
306+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
307307
for more details.
308308
309309
Examples

0 commit comments

Comments
 (0)