Skip to content

REGR: Avoid overflow with groupby sum #48059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pandas/_libs/algos.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def ensure_int8(arr: object, copy=...) -> npt.NDArray[np.int8]: ...
def ensure_int16(arr: object, copy=...) -> npt.NDArray[np.int16]: ...
def ensure_int32(arr: object, copy=...) -> npt.NDArray[np.int32]: ...
def ensure_int64(arr: object, copy=...) -> npt.NDArray[np.int64]: ...
def ensure_uint64(arr: object, copy=...) -> npt.NDArray[np.uint64]: ...
def take_1d_int8_int8(
values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=...
) -> None: ...
Expand Down
4 changes: 2 additions & 2 deletions pandas/_libs/algos_common_helper.pxi.in
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,12 @@ dtypes = [('float64', 'FLOAT64', 'float64'),
('int16', 'INT16', 'int16'),
('int32', 'INT32', 'int32'),
('int64', 'INT64', 'int64'),
('uint64', 'UINT64', 'uint64'),
# Disabling uint and complex dtypes because we do not use them
# (and compiling them increases wheel size)
# (and compiling them increases wheel size) (except uint64)
# ('uint8', 'UINT8', 'uint8'),
# ('uint16', 'UINT16', 'uint16'),
# ('uint32', 'UINT32', 'uint32'),
# ('uint64', 'UINT64', 'uint64'),
# ('complex64', 'COMPLEX64', 'complex64'),
# ('complex128', 'COMPLEX128', 'complex128')
]
Expand Down
7 changes: 0 additions & 7 deletions pandas/_libs/groupby.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -513,14 +513,7 @@ ctypedef fused mean_t:

ctypedef fused sum_t:
mean_t
int8_t
int16_t
int32_t
int64_t

uint8_t
uint16_t
uint32_t
uint64_t
object

Expand Down
1 change: 1 addition & 0 deletions pandas/core/dtypes/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def ensure_float(arr):
ensure_int8 = algos.ensure_int8
ensure_platform_int = algos.ensure_platform_int
ensure_object = algos.ensure_object
ensure_uint64 = algos.ensure_uint64


def ensure_str(value: bytes | Any) -> str:
Expand Down
8 changes: 8 additions & 0 deletions pandas/core/groupby/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
ensure_float64,
ensure_int64,
ensure_platform_int,
ensure_uint64,
is_1d_only_ea_dtype,
is_bool_dtype,
is_complex_dtype,
Expand Down Expand Up @@ -224,6 +225,13 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
# result may still include NaN, so we have to cast
values = ensure_float64(values)

elif how == "sum":
# Avoid overflow during group op
if values.dtype.kind == "i":
values = ensure_int64(values)
else:
values = ensure_uint64(values)

return values

# TODO: general case implementation overridable by EAs.
Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/groupby/test_groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -2829,3 +2829,16 @@ def test_groupby_sum_support_mask(any_numeric_ea_dtype):
dtype=any_numeric_ea_dtype,
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")])
def test_groupby_sum_overflow(val, dtype):
    # GH#37493
    # Inputs are stored as 8-bit integers; doubling either parametrized value
    # exceeds the 8-bit range, so the grouped sum is expected in 64 bits.
    narrow_dtype = f"{dtype}8"
    wide_dtype = f"{dtype}64"

    frame = DataFrame({"a": 1, "b": [val, val]}, dtype=narrow_dtype)
    summed = frame.groupby("a").sum()

    # Single group keyed by a == 1 holding the sum of both rows.
    group_index = Index([1], name="a", dtype=wide_dtype)
    wanted = DataFrame({"b": [val * 2]}, index=group_index, dtype=wide_dtype)

    tm.assert_frame_equal(summed, wanted)