-
-
Notifications
You must be signed in to change notification settings - Fork 18.4k
ENH: use correct dtype in groupby cython ops when it is known (without try/except) #38291
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
279c4d1
e6dc529
ea79027
97fcd22
b04d91f
202bee8
e888b3e
2566ec4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,6 +45,7 @@ | |
is_datetime64_any_dtype, | ||
is_datetime64tz_dtype, | ||
is_extension_array_dtype, | ||
is_float_dtype, | ||
is_integer_dtype, | ||
is_numeric_dtype, | ||
is_period_dtype, | ||
|
@@ -521,7 +522,19 @@ def _ea_wrap_cython_operation( | |
res_values = self._cython_operation( | ||
kind, values, how, axis, min_count, **kwargs | ||
) | ||
result = maybe_cast_result(result=res_values, obj=orig_values, how=how) | ||
dtype = maybe_cast_result_dtype(orig_values.dtype, how) | ||
if is_extension_array_dtype(dtype): | ||
cls = dtype.construct_array_type() | ||
return cls._from_sequence(res_values, dtype=dtype) | ||
return res_values | ||
|
||
elif is_float_dtype(values.dtype): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so would really like to move this entire wrapping to a method on EA / generic casting. We do this in multiple places (e.g. also on _reduce operatiosn), and this is likely leading to missing functionaility in various places. |
||
# FloatingArray | ||
values = values.to_numpy(values.dtype.numpy_dtype, na_value=np.nan) | ||
res_values = self._cython_operation( | ||
kind, values, how, axis, min_count, **kwargs | ||
) | ||
result = type(orig_values)._from_sequence(res_values) | ||
return result | ||
|
||
raise NotImplementedError(values.dtype) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,8 @@ | |
import numpy as np | ||
import pytest | ||
|
||
from pandas.core.dtypes.common import is_float_dtype | ||
|
||
import pandas as pd | ||
from pandas import DataFrame, Index, NaT, Series, Timedelta, Timestamp, bdate_range | ||
import pandas._testing as tm | ||
|
@@ -312,3 +314,69 @@ def test_cython_agg_nullable_int(op_name): | |
# so for now just checking the values by casting to float | ||
result = result.astype("float64") | ||
tm.assert_series_equal(result, expected) | ||
|
||
|
||
@pytest.mark.parametrize("with_na", [True, False]) | ||
@pytest.mark.parametrize( | ||
"op_name, action", | ||
[ | ||
# ("count", "always_int"), | ||
("sum", "large_int"), | ||
# ("std", "always_float"), | ||
("var", "always_float"), | ||
# ("sem", "always_float"), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BTW, the reason those are commented out, it because of And |
||
("mean", "always_float"), | ||
("median", "always_float"), | ||
("prod", "large_int"), | ||
("min", "preserve"), | ||
("max", "preserve"), | ||
("first", "preserve"), | ||
("last", "preserve"), | ||
], | ||
) | ||
@pytest.mark.parametrize( | ||
"data", | ||
[ | ||
pd.array([1, 2, 3, 4], dtype="Int64"), | ||
pd.array([1, 2, 3, 4], dtype="Int8"), | ||
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float32"), | ||
pd.array([0.1, 0.2, 0.3, 0.4], dtype="Float64"), | ||
pd.array([True, True, False, False], dtype="boolean"), | ||
], | ||
) | ||
def test_cython_agg_EA_known_dtypes(data, op_name, action, with_na): | ||
if with_na: | ||
data[3] = pd.NA | ||
|
||
df = DataFrame({"key": ["a", "a", "b", "b"], "col": data}) | ||
grouped = df.groupby("key") | ||
|
||
if action == "always_int": | ||
# always Int64 | ||
expected_dtype = pd.Int64Dtype() | ||
elif action == "large_int": | ||
# for any int/bool use Int64, for float preserve dtype | ||
if is_float_dtype(data.dtype): | ||
expected_dtype = data.dtype | ||
else: | ||
expected_dtype = pd.Int64Dtype() | ||
elif action == "always_float": | ||
# for any int/bool use Float64, for float preserve dtype | ||
if is_float_dtype(data.dtype): | ||
expected_dtype = data.dtype | ||
else: | ||
expected_dtype = pd.Float64Dtype() | ||
elif action == "preserve": | ||
expected_dtype = data.dtype | ||
|
||
result = getattr(grouped, op_name)() | ||
assert result["col"].dtype == expected_dtype | ||
|
||
result = grouped.aggregate(op_name) | ||
assert result["col"].dtype == expected_dtype | ||
|
||
result = getattr(grouped["col"], op_name)() | ||
assert result.dtype == expected_dtype | ||
|
||
result = grouped["col"].aggregate(op_name) | ||
assert result.dtype == expected_dtype |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i think the linter is going to complain about _IntegerDtype. we can either find a non-private thing to import or add it to the whitelist in scripts._validate_unwanted_patterns
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Apparantly it's not complaining at the moment, but indeed something we can de-privatize internally