Commit 1ac3aad

rhshadrach authored and JulianWgs committed
BUG: groupby.agg/transform casts UDF results (pandas-dev#40790)
1 parent d24109c commit 1ac3aad

19 files changed, +221 -57 lines changed

doc/source/user_guide/gotchas.rst

+1-1
Original file line number | Diff line number | Diff line change
@@ -178,7 +178,7 @@ To test for membership in the values, use the method :meth:`~pandas.Series.isin`
178178
For ``DataFrames``, likewise, ``in`` applies to the column axis,
179179
testing for membership in the list of column names.
180180

181-
.. _udf-mutation:
181+
.. _gotchas.udf-mutation:
182182

183183
Mutating with User Defined Function (UDF) methods
184184
-------------------------------------------------

doc/source/user_guide/groupby.rst

+29-2
@@ -739,6 +739,26 @@ optimized Cython implementations:
739739
Of course ``sum`` and ``mean`` are implemented on pandas objects, so the above
740740
code would work even without the special versions via dispatching (see below).
741741

742+
.. _groupby.aggregate.udfs:
743+
744+
Aggregations with User-Defined Functions
745+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
746+
747+
Users can also provide their own functions for custom aggregations. When aggregating
748+
with a User-Defined Function (UDF), the UDF should not mutate the provided ``Series``, see
749+
:ref:`gotchas.udf-mutation` for more information.
750+
751+
.. ipython:: python
752+
753+
animals.groupby("kind")[["height"]].agg(lambda x: set(x))
754+
755+
The resulting dtype will reflect that of the aggregating function. If the results from different groups have
756+
different dtypes, then a common dtype will be determined in the same way as ``DataFrame`` construction.
757+
758+
.. ipython:: python
759+
760+
animals.groupby("kind")[["height"]].agg(lambda x: x.astype(int).sum())
761+
742762
.. _groupby.transform:
743763

744764
Transformation
@@ -759,7 +779,11 @@ as the one being grouped. The transform function must:
759779
* (Optionally) operates on the entire group chunk. If this is supported, a
760780
fast path is used starting from the *second* chunk.
761781

762-
For example, suppose we wished to standardize the data within each group:
782+
Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
783+
transformation function. If the results from different groups have different dtypes, then
784+
a common dtype will be determined in the same way as ``DataFrame`` construction.
785+
786+
Suppose we wished to standardize the data within each group:
763787

764788
.. ipython:: python
765789
@@ -1065,13 +1089,16 @@ that is itself a series, and possibly upcast the result to a DataFrame:
10651089
s
10661090
s.apply(f)
10671091
1068-
10691092
.. note::
10701093

10711094
``apply`` can act as a reducer, transformer, *or* filter function, depending on exactly what is passed to it.
10721095
The result depends on the path taken and on exactly what you are grouping. Thus the grouped column(s) may be included in
10731096
the output as well as set the indices.
10741097

1098+
Similar to :ref:`groupby.aggregate.udfs`, the resulting dtype will reflect that of the
1099+
apply function. If the results from different groups have different dtypes, then
1100+
a common dtype will be determined in the same way as ``DataFrame`` construction.
1101+
10751102

10761103
Numba Accelerated Routines
10771104
--------------------------
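
The dtype rule added to groupby.rst above can be sketched as follows. This is a minimal illustration assuming pandas 1.3.0 or later. The ``animals`` frame here is a hypothetical stand-in, since the user guide's own data is not part of this diff.

```python
import pandas as pd

# Hypothetical stand-in for the ``animals`` frame used in the user guide.
animals = pd.DataFrame(
    {
        "kind": ["cat", "dog", "cat", "dog"],
        "height": [9.1, 6.0, 9.5, 34.0],
    }
)

# A UDF that returns a Python set yields an object-dtype column.
as_sets = animals.groupby("kind")[["height"]].agg(lambda x: set(x))

# A UDF that returns an integer keeps an integer dtype, even though the
# input column is float64; the result is no longer cast back to the input dtype.
as_ints = animals.groupby("kind")[["height"]].agg(lambda x: x.astype(int).sum())

print(as_sets["height"].dtype)
print(as_ints["height"].dtype)
```

On pandas 1.2.x the second result would likely have been cast back to ``float64``, which is the behavior this commit removes.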

doc/source/whatsnew/v1.3.0.rst

+30
@@ -298,6 +298,36 @@ Preserve dtypes in :meth:`~pandas.DataFrame.combine_first`
298298
299299
combined.dtypes
300300
301+
Group by methods agg and transform no longer change return dtype for callables
302+
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
303+
304+
Previously the methods :meth:`.DataFrameGroupBy.aggregate`,
305+
:meth:`.SeriesGroupBy.aggregate`, :meth:`.DataFrameGroupBy.transform`, and
306+
:meth:`.SeriesGroupBy.transform` might cast the result dtype when the argument ``func``
307+
is callable, possibly leading to undesirable results (:issue:`21240`). The cast would
308+
occur if the result is numeric and casting back to the input dtype does not change any
309+
values as measured by ``np.allclose``. Now no such casting occurs.
310+
311+
.. ipython:: python
312+
313+
df = pd.DataFrame({'key': [1, 1], 'a': [True, False], 'b': [True, True]})
314+
df
315+
316+
*pandas 1.2.x*
317+
318+
.. code-block:: ipython
319+
320+
In [5]: df.groupby('key').agg(lambda x: x.sum())
321+
Out[5]:
322+
a b
323+
key
324+
1 True 2
325+
326+
*pandas 1.3.0*
327+
328+
.. ipython:: python
329+
330+
df.groupby('key').agg(lambda x: x.sum())
301331
302332
Try operating inplace when setting values with ``loc`` and ``iloc``
303333
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
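
The old casting rule described in the whatsnew entry above (numeric result, and casting back to the input dtype changes no values per ``np.allclose``) can be approximated with a small sketch. ``would_cast_back`` is a hypothetical helper written for illustration only; it is not pandas' ``maybe_downcast_numeric`` and may differ from it in edge cases.

```python
import numpy as np


def would_cast_back(result, input_dtype):
    """Hypothetical approximation of the pre-1.3 cast-back condition."""
    result = np.asarray(result)
    # Only numeric results were candidates for casting.
    if not np.issubdtype(result.dtype, np.number):
        return False
    try:
        round_trip = result.astype(input_dtype)
    except (TypeError, ValueError):
        return False
    # Compare as floats so that bool/int round trips are handled uniformly.
    return bool(np.allclose(round_trip.astype(float), result.astype(float)))


# Column ``a`` sums to 1: casting 1 to bool gives True, which compares equal to 1.
print(would_cast_back(np.array([1]), np.bool_))
# Column ``b`` sums to 2: casting 2 to bool gives True, which differs from 2.
print(would_cast_back(np.array([2]), np.bool_))
```

Under this approximation the first call returns ``True`` and the second ``False``, which is consistent with the pandas 1.2.x output shown above, where ``a`` became ``True`` while ``b`` stayed ``2``.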

pandas/core/frame.py

+1-1
@@ -8552,7 +8552,7 @@ def apply(
85528552
Notes
85538553
-----
85548554
Functions that mutate the passed object can produce unexpected
8555-
behavior or errors and are not supported. See :ref:`udf-mutation`
8555+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
85568556
for more details.
85578557
85588558
Examples

pandas/core/groupby/generic.py

+25-16
@@ -44,10 +44,6 @@
4444
doc,
4545
)
4646

47-
from pandas.core.dtypes.cast import (
48-
find_common_type,
49-
maybe_downcast_numeric,
50-
)
5147
from pandas.core.dtypes.common import (
5248
ensure_int64,
5349
is_bool,
@@ -226,7 +222,16 @@ def _selection_name(self):
226222
... )
227223
minimum maximum
228224
1 1 2
229-
2 3 4"""
225+
2 3 4
226+
227+
.. versionchanged:: 1.3.0
228+
229+
The resulting dtype will reflect the return value of the aggregating function.
230+
231+
>>> s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())
232+
1 1.0
233+
2 3.0
234+
dtype: float64"""
230235
)
231236

232237
@Appender(
@@ -566,8 +571,9 @@ def transform(self, func, *args, engine=None, engine_kwargs=None, **kwargs):
566571

567572
def _transform_general(self, func, *args, **kwargs):
568573
"""
569-
Transform with a non-str `func`.
574+
Transform with a callable `func`.
570575
"""
576+
assert callable(func)
571577
klass = type(self._selected_obj)
572578

573579
results = []
@@ -589,13 +595,6 @@ def _transform_general(self, func, *args, **kwargs):
589595
result = self._set_result_index_ordered(concatenated)
590596
else:
591597
result = self.obj._constructor(dtype=np.float64)
592-
# we will only try to coerce the result type if
593-
# we have a numeric dtype, as these are *always* user-defined funcs
594-
# the cython take a different path (and casting)
595-
if is_numeric_dtype(result.dtype):
596-
common_dtype = find_common_type([self._selected_obj.dtype, result.dtype])
597-
if common_dtype is result.dtype:
598-
result = maybe_downcast_numeric(result, self._selected_obj.dtype)
599598

600599
result.name = self._selected_obj.name
601600
return result
@@ -625,7 +624,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
625624
Notes
626625
-----
627626
Functions that mutate the passed object can produce unexpected
628-
behavior or errors and are not supported. See :ref:`udf-mutation`
627+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
629628
for more details.
630629
631630
Examples
@@ -1006,7 +1005,17 @@ class DataFrameGroupBy(GroupBy[DataFrame]):
10061005
``['column', 'aggfunc']`` to make it clearer what the arguments are.
10071006
As usual, the aggregation can be a callable or a string alias.
10081007
1009-
See :ref:`groupby.aggregate.named` for more."""
1008+
See :ref:`groupby.aggregate.named` for more.
1009+
1010+
.. versionchanged:: 1.3.0
1011+
1012+
The resulting dtype will reflect the return value of the aggregating function.
1013+
1014+
>>> df.groupby("A")[["B"]].agg(lambda x: x.astype(float).min())
1015+
B
1016+
A
1017+
1 1.0
1018+
2 3.0"""
10101019
)
10111020

10121021
@doc(_agg_template, examples=_agg_examples_doc, klass="DataFrame")
@@ -1533,7 +1542,7 @@ def filter(self, func, dropna=True, *args, **kwargs):
15331542
which group you are working on.
15341543
15351544
Functions that mutate the passed object can produce unexpected
1536-
behavior or errors and are not supported. See :ref:`udf-mutation`
1545+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
15371546
for more details.
15381547
15391548
Examples
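
The ``SeriesGroupBy.agg`` docstring example added in this file can be run standalone as a quick check. This sketch assumes pandas 1.3.0 or later, and assumes the docstring's ``s`` is ``pd.Series([1, 2, 3, 4])``, which is consistent with the output shown in the diff but is not itself part of it.

```python
import pandas as pd

# Assumed definition of ``s``; it is not shown in the hunk above.
s = pd.Series([1, 2, 3, 4])

# The UDF converts each group to float before taking the minimum, so the
# result stays float64 instead of being downcast to the int64 input dtype.
result = s.groupby([1, 1, 2, 2]).agg(lambda x: x.astype(float).min())

print(result)
print(result.dtype)
```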

pandas/core/groupby/groupby.py

+50-17
@@ -160,14 +160,19 @@ class providing the base-class of operations.
160160
side-effects, as they will take effect twice for the first
161161
group.
162162
163+
.. versionchanged:: 1.3.0
164+
165+
The resulting dtype will reflect the return value of the passed ``func``,
166+
see the examples below.
167+
163168
Examples
164169
--------
165170
{examples}
166171
""",
167172
"dataframe_examples": """
168173
>>> df = pd.DataFrame({'A': 'a a b'.split(),
169174
... 'B': [1,2,3],
170-
... 'C': [4,6, 5]})
175+
... 'C': [4,6,5]})
171176
>>> g = df.groupby('A')
172177
173178
Notice that ``g`` has two groups, ``a`` and ``b``.
@@ -185,13 +190,17 @@ class providing the base-class of operations.
185190
186191
Example 2: The function passed to `apply` takes a DataFrame as
187192
its argument and returns a Series. `apply` combines the result for
188-
each group together into a new DataFrame:
193+
each group together into a new DataFrame.
194+
195+
.. versionchanged:: 1.3.0
189196
190-
>>> g[['B', 'C']].apply(lambda x: x.max() - x.min())
191-
B C
197+
The resulting dtype will reflect the return value of the passed ``func``.
198+
199+
>>> g[['B', 'C']].apply(lambda x: x.astype(float).max() - x.min())
200+
B C
192201
A
193-
a 1 2
194-
b 0 0
202+
a 1.0 2.0
203+
b 0.0 0.0
195204
196205
Example 3: The function passed to `apply` takes a DataFrame as
197206
its argument and returns a scalar. `apply` combines the result for
@@ -212,12 +221,16 @@ class providing the base-class of operations.
212221
213222
Example 1: The function passed to `apply` takes a Series as
214223
its argument and returns a Series. `apply` combines the result for
215-
each group together into a new Series:
224+
each group together into a new Series.
225+
226+
.. versionchanged:: 1.3.0
216227
217-
>>> g.apply(lambda x: x*2 if x.name == 'b' else x/2)
228+
The resulting dtype will reflect the return value of the passed ``func``.
229+
230+
>>> g.apply(lambda x: x*2 if x.name == 'a' else x/2)
218231
a 0.0
219-
a 0.5
220-
b 4.0
232+
a 2.0
233+
b 1.0
221234
dtype: float64
222235
223236
Example 2: The function passed to `apply` takes a Series as
@@ -369,12 +382,17 @@ class providing the base-class of operations.
369382
in the subframe. If f also supports application to the entire subframe,
370383
then a fast path is used starting from the second chunk.
371384
* f must not mutate groups. Mutation is not supported and may
372-
produce unexpected results. See :ref:`udf-mutation` for more details.
385+
produce unexpected results. See :ref:`gotchas.udf-mutation` for more details.
373386
374387
When using ``engine='numba'``, there will be no "fall back" behavior internally.
375388
The group data and group index will be passed as numpy arrays to the JITed
376389
user defined function, and no alternative execution attempts will be tried.
377390
391+
.. versionchanged:: 1.3.0
392+
393+
The resulting dtype will reflect the return value of the passed ``func``,
394+
see the examples below.
395+
378396
Examples
379397
--------
380398
@@ -404,6 +422,20 @@ class providing the base-class of operations.
404422
3 3 8.0
405423
4 4 6.0
406424
5 3 8.0
425+
426+
.. versionchanged:: 1.3.0
427+
428+
The resulting dtype will reflect the return value of the passed ``func``,
429+
for example:
430+
431+
>>> grouped[['C', 'D']].transform(lambda x: x.astype(int).max())
432+
C D
433+
0 5 8
434+
1 5 9
435+
2 5 8
436+
3 5 9
437+
4 5 8
438+
5 5 9
407439
"""
408440

409441
_agg_template = """
@@ -471,12 +503,16 @@ class providing the base-class of operations.
471503
When using ``engine='numba'``, there will be no "fall back" behavior internally.
472504
The group data and group index will be passed as numpy arrays to the JITed
473505
user defined function, and no alternative execution attempts will be tried.
474-
{examples}
475506
476507
Functions that mutate the passed object can produce unexpected
477-
behavior or errors and are not supported. See :ref:`udf-mutation`
508+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
478509
for more details.
479-
"""
510+
511+
.. versionchanged:: 1.3.0
512+
513+
The resulting dtype will reflect the return value of the passed ``func``,
514+
see the examples below.
515+
{examples}"""
480516

481517

482518
@final
@@ -1237,9 +1273,6 @@ def _python_agg_general(self, func, *args, **kwargs):
12371273
assert result is not None
12381274
key = base.OutputKey(label=name, position=idx)
12391275

1240-
if is_numeric_dtype(obj.dtype):
1241-
result = maybe_downcast_numeric(result, obj.dtype)
1242-
12431276
if self.grouper._filter_empty_groups:
12441277
mask = counts.ravel() > 0
12451278
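
The ``transform`` note added to the docstring template above can be illustrated with a small sketch, assuming pandas 1.3.0 or later. The frame below is hypothetical and smaller than the one in the docstring.

```python
import pandas as pd

# Hypothetical data: ``C`` is integer, ``D`` is float.
df = pd.DataFrame(
    {
        "A": ["foo", "bar", "foo", "bar"],
        "C": [1, 5, 5, 2],
        "D": [2.0, 5.0, 8.0, 1.0],
    }
)

# Each group's maximum is computed on integer-cast data and broadcast back
# to the group's rows; the integer dtype returned by the UDF is kept.
out = df.groupby("A")[["C", "D"]].transform(lambda x: x.astype(int).max())

print(out)
print(out.dtypes)
```

Here the float column ``D`` comes back with an integer dtype because the UDF returned integers; before this commit the result could have been cast back toward the input dtype.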

pandas/core/series.py

+1-1
@@ -4190,7 +4190,7 @@ def apply(
41904190
Notes
41914191
-----
41924192
Functions that mutate the passed object can produce unexpected
4193-
behavior or errors and are not supported. See :ref:`udf-mutation`
4193+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
41944194
for more details.
41954195
41964196
Examples

pandas/core/shared_docs.py

+2-2
@@ -42,7 +42,7 @@
4242
`agg` is an alias for `aggregate`. Use the alias.
4343
4444
Functions that mutate the passed object can produce unexpected
45-
behavior or errors and are not supported. See :ref:`udf-mutation`
45+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
4646
for more details.
4747
4848
A passed user-defined-function will be passed a Series for evaluation.
@@ -303,7 +303,7 @@
303303
Notes
304304
-----
305305
Functions that mutate the passed object can produce unexpected
306-
behavior or errors and are not supported. See :ref:`udf-mutation`
306+
behavior or errors and are not supported. See :ref:`gotchas.udf-mutation`
307307
for more details.
308308
309309
Examples
