Skip to content

Commit 1893b10

Browse files
phofl authored and noatamir committed
REGR: Avoid overflow with groupby sum (pandas-dev#48059)
* REGR: Avoid overflow with groupby sum
* Add comment
1 parent 15f2f78 commit 1893b10

File tree

6 files changed

+25
-9
lines changed

6 files changed

+25
-9
lines changed

pandas/_libs/algos.pyi

+1
Original file line number | Diff line number | Diff line change
@@ -132,6 +132,7 @@ def ensure_int8(arr: object, copy=...) -> npt.NDArray[np.int8]: ...
132132
def ensure_int16(arr: object, copy=...) -> npt.NDArray[np.int16]: ...
133133
def ensure_int32(arr: object, copy=...) -> npt.NDArray[np.int32]: ...
134134
def ensure_int64(arr: object, copy=...) -> npt.NDArray[np.int64]: ...
135+
def ensure_uint64(arr: object, copy=...) -> npt.NDArray[np.uint64]: ...
135136
def take_1d_int8_int8(
136137
values: np.ndarray, indexer: npt.NDArray[np.intp], out: np.ndarray, fill_value=...
137138
) -> None: ...

pandas/_libs/algos_common_helper.pxi.in

+2-2
Original file line number | Diff line number | Diff line change
@@ -41,12 +41,12 @@ dtypes = [('float64', 'FLOAT64', 'float64'),
4141
('int16', 'INT16', 'int16'),
4242
('int32', 'INT32', 'int32'),
4343
('int64', 'INT64', 'int64'),
44+
('uint64', 'UINT64', 'uint64'),
4445
# Disabling uint and complex dtypes because we do not use them
45-
# (and compiling them increases wheel size)
46+
# (and compiling them increases wheel size) (except uint64)
4647
# ('uint8', 'UINT8', 'uint8'),
4748
# ('uint16', 'UINT16', 'uint16'),
4849
# ('uint32', 'UINT32', 'uint32'),
49-
# ('uint64', 'UINT64', 'uint64'),
5050
# ('complex64', 'COMPLEX64', 'complex64'),
5151
# ('complex128', 'COMPLEX128', 'complex128')
5252
]

pandas/_libs/groupby.pyx

-7
Original file line number | Diff line number | Diff line change
@@ -513,14 +513,7 @@ ctypedef fused mean_t:
513513

514514
ctypedef fused sum_t:
515515
mean_t
516-
int8_t
517-
int16_t
518-
int32_t
519516
int64_t
520-
521-
uint8_t
522-
uint16_t
523-
uint32_t
524517
uint64_t
525518
object
526519

pandas/core/dtypes/common.py

+1
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ def ensure_float(arr):
100100
ensure_int8 = algos.ensure_int8
101101
ensure_platform_int = algos.ensure_platform_int
102102
ensure_object = algos.ensure_object
103+
ensure_uint64 = algos.ensure_uint64
103104

104105

105106
def ensure_str(value: bytes | Any) -> str:

pandas/core/groupby/ops.py

+8
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,7 @@
4646
ensure_float64,
4747
ensure_int64,
4848
ensure_platform_int,
49+
ensure_uint64,
4950
is_1d_only_ea_dtype,
5051
is_bool_dtype,
5152
is_complex_dtype,
@@ -224,6 +225,13 @@ def _get_cython_vals(self, values: np.ndarray) -> np.ndarray:
224225
# result may still include NaN, so we have to cast
225226
values = ensure_float64(values)
226227

228+
elif how == "sum":
229+
# Avoid overflow during group op
230+
if values.dtype.kind == "i":
231+
values = ensure_int64(values)
232+
else:
233+
values = ensure_uint64(values)
234+
227235
return values
228236

229237
# TODO: general case implementation overridable by EAs.

pandas/tests/groupby/test_groupby.py

+13
Original file line number | Diff line number | Diff line change
@@ -2829,3 +2829,16 @@ def test_groupby_sum_support_mask(any_numeric_ea_dtype):
28292829
dtype=any_numeric_ea_dtype,
28302830
)
28312831
tm.assert_frame_equal(result, expected)
2832+
2833+
2834+
@pytest.mark.parametrize("val, dtype", [(111, "int"), (222, "uint")])
2835+
def test_groupby_sum_overflow(val, dtype):
2836+
# GH#37493
2837+
df = DataFrame({"a": 1, "b": [val, val]}, dtype=f"{dtype}8")
2838+
result = df.groupby("a").sum()
2839+
expected = DataFrame(
2840+
{"b": [val * 2]},
2841+
index=Index([1], name="a", dtype=f"{dtype}64"),
2842+
dtype=f"{dtype}64",
2843+
)
2844+
tm.assert_frame_equal(result, expected)

0 commit comments

Comments
 (0)