Skip to content

Commit 5bfa653

Browse files
ENH: use correct dtype in groupby cython ops when it is known (without try/except) (#38291)
Co-authored-by: Joris Van den Bossche <[email protected]>
1 parent 37f7bdc commit 5bfa653

File tree

6 files changed

+102
-10
lines changed

6 files changed

+102
-10
lines changed

pandas/core/dtypes/cast.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -357,12 +357,18 @@ def maybe_cast_result_dtype(dtype: DtypeObj, how: str) -> DtypeObj:
357357
The desired dtype of the result.
358358
"""
359359
from pandas.core.arrays.boolean import BooleanDtype
360-
from pandas.core.arrays.integer import Int64Dtype
361-
362-
if how in ["add", "cumsum", "sum"] and (dtype == np.dtype(bool)):
363-
return np.dtype(np.int64)
364-
elif how in ["add", "cumsum", "sum"] and isinstance(dtype, BooleanDtype):
365-
return Int64Dtype()
360+
from pandas.core.arrays.floating import Float64Dtype
361+
from pandas.core.arrays.integer import Int64Dtype, _IntegerDtype
362+
363+
if how in ["add", "cumsum", "sum", "prod"]:
364+
if dtype == np.dtype(bool):
365+
return np.dtype(np.int64)
366+
elif isinstance(dtype, (BooleanDtype, _IntegerDtype)):
367+
return Int64Dtype()
368+
elif how in ["mean", "median", "var"] and isinstance(
369+
dtype, (BooleanDtype, _IntegerDtype)
370+
):
371+
return Float64Dtype()
366372
return dtype
367373

368374

pandas/core/groupby/ops.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
is_datetime64_any_dtype,
4646
is_datetime64tz_dtype,
4747
is_extension_array_dtype,
48+
is_float_dtype,
4849
is_integer_dtype,
4950
is_numeric_dtype,
5051
is_period_dtype,
@@ -521,7 +522,19 @@ def _ea_wrap_cython_operation(
521522
res_values = self._cython_operation(
522523
kind, values, how, axis, min_count, **kwargs
523524
)
524-
result = maybe_cast_result(result=res_values, obj=orig_values, how=how)
525+
dtype = maybe_cast_result_dtype(orig_values.dtype, how)
526+
if is_extension_array_dtype(dtype):
527+
cls = dtype.construct_array_type()
528+
return cls._from_sequence(res_values, dtype=dtype)
529+
return res_values
530+
531+
elif is_float_dtype(values.dtype):
532+
# FloatingArray
533+
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan)
534+
res_values = self._cython_operation(
535+
kind, values, how, axis, min_count, **kwargs
536+
)
537+
result = type(orig_values)._from_sequence(res_values)
525538
return result
526539

527540
raise NotImplementedError(values.dtype)

pandas/tests/arrays/integer/test_arithmetic.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,10 @@ def test_reduce_to_float(op):
277277
result = getattr(df.groupby("A"), op)()
278278

279279
expected = pd.DataFrame(
280-
{"B": np.array([1.0, 3.0]), "C": integer_array([1, 3], dtype="Int64")},
280+
{
281+
"B": np.array([1.0, 3.0]),
282+
"C": pd.array([1, 3], dtype="Float64"),
283+
},
281284
index=pd.Index(["a", "b"], name="A"),
282285
)
283286
tm.assert_frame_equal(result, expected)

pandas/tests/groupby/aggregate/test_cython.py

+68
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import numpy as np
66
import pytest
77

8+
from pandas.core.dtypes.common import is_float_dtype
9+
810
import pandas as pd
911
from pandas import DataFrame, Index, NaT, Series, Timedelta, Timestamp, bdate_range
1012
import pandas._testing as tm
@@ -312,3 +314,69 @@ def test_cython_agg_nullable_int(op_name):
312314
# so for now just checking the values by casting to float
313315
result = result.astype("float64")
314316
tm.assert_series_equal(result, expected)
317+
318+
319+
@pytest.mark.parametrize("with_na", [True, False])
320+
@pytest.mark.parametrize(
321+
"op_name, action",
322+
[
323+
# ("count", "always_int"),
324+
("sum", "large_int"),
325+
# ("std", "always_float"),
326+
("var", "always_float"),
327+
# ("sem", "always_float"),
328+
("mean", "always_float"),
329+
("median", "always_float"),
330+
("prod", "large_int"),
331+
("min", "preserve"),
332+
("max", "preserve"),
333+
("first", "preserve"),
334+
("last", "preserve"),
335+
],
336+
)
337+
@pytest.mark.parametrize(
338+
"data",
339+
[
340+
pd.array([1, 2, 3, 4], dtype="Int64"),
341+
pd.array([1, 2, 3, 4], dtype="Int8"),
342+
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"),
343+
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"),
344+
pd.array([True, True, False, False], dtype="boolean"),
345+
],
346+
)
347+
def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na):
348+
if with_na:
349+
data[3] = pd.NA
350+
351+
df = DataFrame({"key": ["a", "a", "b", "b"], "col": data})
352+
grouped = df.groupby("key")
353+
354+
if action == "always_int":
355+
# always Int64
356+
expected_dtype = pd.Int64Dtype()
357+
elif action == "large_int":
358+
# for any int/bool use Int64, for float preserve dtype
359+
if is_float_dtype(data.dtype):
360+
expected_dtype = data.dtype
361+
else:
362+
expected_dtype = pd.Int64Dtype()
363+
elif action == "always_float":
364+
# for any int/bool use Float64, for float preserve dtype
365+
if is_float_dtype(data.dtype):
366+
expected_dtype = data.dtype
367+
else:
368+
expected_dtype = pd.Float64Dtype()
369+
elif action == "preserve":
370+
expected_dtype = data.dtype
371+
372+
result = getattr(grouped, op_name)()
373+
assert result["col"].dtype == expected_dtype
374+
375+
result = grouped.aggregate(op_name)
376+
assert result["col"].dtype == expected_dtype
377+
378+
result = getattr(grouped["col"], op_name)()
379+
assert result.dtype == expected_dtype
380+
381+
result = grouped["col"].aggregate(op_name)
382+
assert result.dtype == expected_dtype

pandas/tests/groupby/test_function.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1093,7 +1093,7 @@ def test_apply_to_nullable_integer_returns_float(values, function):
10931093
output = 0.5 if function == "var" else 1.5
10941094
arr = np.array([output] * 3, dtype=float)
10951095
idx = Index([1, 2, 3], dtype=object, name="a")
1096-
expected = DataFrame({"b": arr}, index=idx)
1096+
expected = DataFrame({"b": arr}, index=idx).astype("Float64")
10971097

10981098
groups = DataFrame(values, dtype="Int64").groupby("a")
10991099

pandas/tests/resample/test_datetime_index.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,9 @@ def test_resample_integerarray():
124124

125125
result = ts.resample("3T").mean()
126126
expected = Series(
127-
[1, 4, 7], index=pd.date_range("1/1/2000", periods=3, freq="3T"), dtype="Int64"
127+
[1, 4, 7],
128+
index=pd.date_range("1/1/2000", periods=3, freq="3T"),
129+
dtype="Float64",
128130
)
129131
tm.assert_series_equal(result, expected)
130132

0 commit comments

Comments
 (0)