Skip to content

Commit 9eec5bf

Browse files
authored
API: don't infer dtype for object-dtype groupby reductions (#51205)
* API: don't infer dtype for object-dtype groupby reductions
* GH ref
1 parent f33105f commit 9eec5bf

File tree

7 files changed

+30
-9
lines changed

7 files changed

+30
-9
lines changed

doc/source/whatsnew/v2.0.0.rst

+2
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,9 @@ Other API changes
790790
- The levels of the index of the :class:`Series` returned from ``Series.sparse.from_coo`` now always have dtype ``int32``. Previously they had dtype ``int64`` (:issue:`50926`)
791791
- :func:`to_datetime` with ``unit`` of either "Y" or "M" will now raise if a sequence contains a non-round ``float`` value, matching the ``Timestamp`` behavior (:issue:`50301`)
792792
- The methods :meth:`Series.round`, :meth:`DataFrame.__invert__`, :meth:`Series.__invert__`, :meth:`DataFrame.swapaxes`, :meth:`DataFrame.first`, :meth:`DataFrame.last`, :meth:`Series.first`, :meth:`Series.last` and :meth:`DataFrame.align` will now always return new objects (:issue:`51032`)
793+
- :class:`DataFrameGroupBy` aggregations (e.g. "sum") with object-dtype columns no longer infer non-object dtypes for their results; explicitly call ``result.infer_objects(copy=False)`` on the result to obtain the old behavior (:issue:`51205`)
793794
- Added :func:`pandas.api.types.is_any_real_numeric_dtype` to check for real numeric dtypes (:issue:`51152`)
795+
-
794796

795797
.. ---------------------------------------------------------------------------
796798
.. _whatsnew_200.deprecations:

pandas/core/groupby/groupby.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,9 @@ def _agg_py_fallback(
14951495
# TODO: if we ever get "rank" working, exclude it here.
14961496
res_values = type(values)._from_sequence(res_values, dtype=values.dtype)
14971497

1498+
elif ser.dtype == object:
1499+
res_values = res_values.astype(object, copy=False)
1500+
14981501
# If we are DataFrameGroupBy and went through a SeriesGroupByPath
14991502
# then we need to reshape
15001503
# GH#32223 includes case with IntegerArray values, ndarray res_values
@@ -1537,8 +1540,7 @@ def array_func(values: ArrayLike) -> ArrayLike:
15371540
new_mgr = data.grouped_reduce(array_func)
15381541
res = self._wrap_agged_manager(new_mgr)
15391542
out = self._wrap_aggregated_output(res)
1540-
if data.ndim == 2:
1541-
# TODO: don't special-case DataFrame vs Series
1543+
if self.axis == 1:
15421544
out = out.infer_objects(copy=False)
15431545
return out
15441546

pandas/tests/groupby/aggregate/test_aggregate.py

+2
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def test_multiindex_groupby_mixed_cols_axis1(func, expected, dtype, result_dtype
258258
expected = DataFrame([expected] * 3, columns=["i", "j", "k"]).astype(
259259
result_dtype_dict
260260
)
261+
261262
tm.assert_frame_equal(result, expected)
262263

263264

@@ -675,6 +676,7 @@ def test_agg_split_object_part_datetime():
675676
"F": [1],
676677
},
677678
index=np.array([0]),
679+
dtype=object,
678680
)
679681
tm.assert_frame_equal(result, expected)
680682

pandas/tests/groupby/aggregate/test_other.py

+1
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,7 @@ def test_sum_uint64_overflow():
517517
expected = DataFrame(
518518
{1: [9223372036854775809, 9223372036854775811, 9223372036854775813]},
519519
index=index,
520+
dtype=object,
520521
)
521522

522523
expected.index.name = 0

pandas/tests/groupby/test_function.py

+6
Original file line numberDiff line numberDiff line change
@@ -1509,6 +1509,12 @@ def test_deprecate_numeric_only_series(dtype, groupby_func, request):
15091509
"sum",
15101510
"diff",
15111511
"pct_change",
1512+
"var",
1513+
"mean",
1514+
"median",
1515+
"min",
1516+
"max",
1517+
"prod",
15121518
)
15131519

15141520
# Test default behavior; kernels that fail may be enabled in the future but kernels

pandas/tests/groupby/test_groupby.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2380,7 +2380,9 @@ def test_groupby_duplicate_columns():
23802380
).astype(object)
23812381
df.columns = ["A", "B", "B"]
23822382
result = df.groupby([0, 0, 0, 0]).min()
2383-
expected = DataFrame([["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"])
2383+
expected = DataFrame(
2384+
[["e", "a", 1]], index=np.array([0]), columns=["A", "B", "B"], dtype=object
2385+
)
23842386
tm.assert_frame_equal(result, expected)
23852387

23862388

pandas/tests/groupby/test_min_max.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -148,19 +148,25 @@ def test_aggregate_numeric_object_dtype():
148148
{"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": [np.nan] * 4},
149149
).astype(object)
150150
result = df.groupby("key").min()
151-
expected = DataFrame(
152-
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]}
153-
).set_index("key")
151+
expected = (
152+
DataFrame(
153+
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [np.nan, np.nan]},
154+
)
155+
.set_index("key")
156+
.astype(object)
157+
)
154158
tm.assert_frame_equal(result, expected)
155159

156160
# same but with numbers
157161
df = DataFrame(
158162
{"key": ["A", "A", "B", "B"], "col1": list("abcd"), "col2": range(4)},
159163
).astype(object)
160164
result = df.groupby("key").min()
161-
expected = DataFrame(
162-
{"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]}
163-
).set_index("key")
165+
expected = (
166+
DataFrame({"key": ["A", "B"], "col1": ["a", "c"], "col2": [0, 2]})
167+
.set_index("key")
168+
.astype(object)
169+
)
164170
tm.assert_frame_equal(result, expected)
165171

166172

0 commit comments

Comments
 (0)