
Commit 6f4c382

BUG: Groupby min/max with nullable dtypes (#42567)

1 parent b1f7e58 · commit 6f4c382

6 files changed: +105 -8 lines changed

doc/source/whatsnew/v1.4.0.rst (+1)

@@ -394,6 +394,7 @@ Groupby/resample/rolling
 ^^^^^^^^^^^^^^^^^^^^^^^^
 - Fixed bug in :meth:`SeriesGroupBy.apply` where passing an unrecognized string argument failed to raise ``TypeError`` when the underlying ``Series`` is empty (:issue:`42021`)
 - Bug in :meth:`Series.rolling.apply`, :meth:`DataFrame.rolling.apply`, :meth:`Series.expanding.apply` and :meth:`DataFrame.expanding.apply` with ``engine="numba"`` where ``*args`` were being cached with the user passed function (:issue:`42287`)
+- Bug in :meth:`GroupBy.max` and :meth:`GroupBy.min` with nullable integer dtypes losing precision (:issue:`41743`)
 - Bug in :meth:`DataFrame.groupby.rolling.var` would calculate the rolling variance only on the first group (:issue:`42442`)
 - Bug in :meth:`GroupBy.shift` that would return the grouping columns if ``fill_value`` was not None (:issue:`41556`)
 - Bug in :meth:`SeriesGroupBy.nlargest` and :meth:`SeriesGroupBy.nsmallest` would have an inconsistent index when the input Series was sorted and ``n`` was greater than or equal to all group sizes (:issue:`15272`, :issue:`16345`, :issue:`29129`)
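The whatsnew entry refers to GH#41743: before this commit, nullable integer columns were routed through a float64 conversion during groupby min/max (so that missing entries could be held as NaN), and float64 cannot represent large int64 values exactly. A minimal sketch of that precision issue in plain NumPy, not pandas internals:

import numpy as np

ts = 1618556707013635762            # same value the new test below uses
roundtrip = int(np.float64(ts))     # what a float64 round-trip yields
print(ts == roundtrip)              # False: this int64 cannot survive float64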

pandas/_libs/groupby.pyi (+4)

@@ -123,13 +123,17 @@ def group_max(
     values: np.ndarray,  # ndarray[groupby_t, ndim=2]
     labels: np.ndarray,  # const int64_t[:]
     min_count: int = ...,
+    mask: np.ndarray | None = ...,
+    result_mask: np.ndarray | None = ...,
 ) -> None: ...
 def group_min(
     out: np.ndarray,  # groupby_t[:, ::1]
     counts: np.ndarray,  # int64_t[::1]
     values: np.ndarray,  # ndarray[groupby_t, ndim=2]
     labels: np.ndarray,  # const int64_t[:]
     min_count: int = ...,
+    mask: np.ndarray | None = ...,
+    result_mask: np.ndarray | None = ...,
 ) -> None: ...
 def group_cummin(
     out: np.ndarray,  # groupby_t[:, ::1]

pandas/_libs/groupby.pyx (+31 -5)

@@ -1182,7 +1182,9 @@ cdef group_min_max(groupby_t[:, ::1] out,
                    const intp_t[::1] labels,
                    Py_ssize_t min_count=-1,
                    bint is_datetimelike=False,
-                   bint compute_max=True):
+                   bint compute_max=True,
+                   const uint8_t[:, ::1] mask=None,
+                   uint8_t[:, ::1] result_mask=None):
     """
     Compute minimum/maximum of columns of `values`, in row groups `labels`.
 
@@ -1203,6 +1205,12 @@ cdef group_min_max(groupby_t[:, ::1] out,
         True if `values` contains datetime-like entries.
     compute_max : bint, default True
         True to compute group-wise max, False to compute min
+    mask : ndarray[bool, ndim=2], optional
+        If not None, indices represent missing values,
+        otherwise the mask will not be used
+    result_mask : ndarray[bool, ndim=2], optional
+        If not None, these specify locations in the output that are NA.
+        Modified in-place.
 
     Notes
     -----
@@ -1215,6 +1223,8 @@ cdef group_min_max(groupby_t[:, ::1] out,
         ndarray[groupby_t, ndim=2] group_min_or_max
         bint runtime_error = False
         int64_t[:, ::1] nobs
+        bint uses_mask = mask is not None
+        bint isna_entry
 
     # TODO(cython 3.0):
     # Instead of `labels.shape[0]` use `len(labels)`
@@ -1249,7 +1259,12 @@ cdef group_min_max(groupby_t[:, ::1] out,
             for j in range(K):
                 val = values[i, j]
 
-                if not _treat_as_na(val, is_datetimelike):
+                if uses_mask:
+                    isna_entry = mask[i, j]
+                else:
+                    isna_entry = _treat_as_na(val, is_datetimelike)
+
+                if not isna_entry:
                     nobs[lab, j] += 1
                     if compute_max:
                         if val > group_min_or_max[lab, j]:
@@ -1265,7 +1280,10 @@ cdef group_min_max(groupby_t[:, ::1] out,
                         runtime_error = True
                         break
                     else:
-                        out[i, j] = nan_val
+                        if uses_mask:
+                            result_mask[i, j] = True
+                        else:
+                            out[i, j] = nan_val
                 else:
                     out[i, j] = group_min_or_max[i, j]
 
@@ -1282,7 +1300,9 @@ def group_max(groupby_t[:, ::1] out,
               ndarray[groupby_t, ndim=2] values,
               const intp_t[::1] labels,
              Py_ssize_t min_count=-1,
-              bint is_datetimelike=False) -> None:
+              bint is_datetimelike=False,
+              const uint8_t[:, ::1] mask=None,
+              uint8_t[:, ::1] result_mask=None) -> None:
     """See group_min_max.__doc__"""
     group_min_max(
         out,
@@ -1292,6 +1312,8 @@ def group_max(groupby_t[:, ::1] out,
         min_count=min_count,
         is_datetimelike=is_datetimelike,
         compute_max=True,
+        mask=mask,
+        result_mask=result_mask,
     )
 
 
@@ -1302,7 +1324,9 @@ def group_min(groupby_t[:, ::1] out,
               ndarray[groupby_t, ndim=2] values,
              const intp_t[::1] labels,
              Py_ssize_t min_count=-1,
-              bint is_datetimelike=False) -> None:
+              bint is_datetimelike=False,
+              const uint8_t[:, ::1] mask=None,
+              uint8_t[:, ::1] result_mask=None) -> None:
     """See group_min_max.__doc__"""
     group_min_max(
         out,
@@ -1312,6 +1336,8 @@ def group_min(groupby_t[:, ::1] out,
         min_count=min_count,
         is_datetimelike=is_datetimelike,
         compute_max=False,
+        mask=mask,
+        result_mask=result_mask,
    )
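In words: when a mask is supplied, NA entries are read from the mask instead of being inferred from sentinel values, and a group with no valid entries sets result_mask rather than writing nan_val into out. A rough pure-NumPy sketch of that logic, assuming 1-D integer values (the helper name masked_group_max is illustrative only, not pandas API):

import numpy as np

def masked_group_max(values, labels, ngroups, mask):
    # values: 1-D integer array; labels: group id per row; mask: True where NA
    out = np.empty(ngroups, dtype=values.dtype)
    result_mask = np.zeros(ngroups, dtype=bool)
    nobs = np.zeros(ngroups, dtype=np.int64)
    running = np.full(ngroups, np.iinfo(values.dtype).min, dtype=values.dtype)

    for i, lab in enumerate(labels):
        if lab < 0 or mask[i]:        # skip unassigned rows and NA entries
            continue
        nobs[lab] += 1
        if values[i] > running[lab]:
            running[lab] = values[i]

    for g in range(ngroups):
        if nobs[g] == 0:
            result_mask[g] = True     # empty group: flag NA instead of writing a sentinel
        else:
            out[g] = running[g]
    return out, result_mask

Because the reduction never leaves the integer dtype, the large timestamps from GH#41743 survive intact; all NA bookkeeping lives in the boolean masks.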

pandas/core/arrays/masked.py (+2)

@@ -123,6 +123,8 @@ def __init__(self, values: np.ndarray, mask: np.ndarray, copy: bool = False):
             raise ValueError("values must be a 1D array")
         if mask.ndim != 1:
             raise ValueError("mask must be a 1D array")
+        if values.shape != mask.shape:
+            raise ValueError("values and mask must have same shape")
 
         if copy:
             values = values.copy()
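A small usage sketch of the new guard, assuming this commit is applied (IntegerArray subclasses the BaseMaskedArray edited here, so it inherits the check):

import numpy as np
from pandas.arrays import IntegerArray

values = np.array([1, 2, 3], dtype="int64")
mask = np.array([False, True], dtype=bool)   # deliberately the wrong length

try:
    IntegerArray(values, mask)
except ValueError as err:
    print(err)   # values and mask must have same shape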

pandas/core/groupby/ops.py (+19 -3)

@@ -138,7 +138,7 @@ def __init__(self, kind: str, how: str):
         },
     }
 
-    _MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax"}
+    _MASKED_CYTHON_FUNCTIONS = {"cummin", "cummax", "min", "max"}
 
     _cython_arity = {"ohlc": 4}  # OHLC
 
@@ -404,6 +404,7 @@ def _masked_ea_wrap_cython_operation(
 
         # Copy to ensure input and result masks don't end up shared
         mask = values._mask.copy()
+        result_mask = np.zeros(ngroups, dtype=bool)
         arr = values._data
 
         res_values = self._cython_op_ndim_compat(
@@ -412,13 +413,18 @@ def _masked_ea_wrap_cython_operation(
             ngroups=ngroups,
             comp_ids=comp_ids,
             mask=mask,
+            result_mask=result_mask,
             **kwargs,
         )
+
         dtype = self._get_result_dtype(orig_values.dtype)
         assert isinstance(dtype, BaseMaskedDtype)
         cls = dtype.construct_array_type()
 
-        return cls(res_values.astype(dtype.type, copy=False), mask)
+        if self.kind != "aggregate":
+            return cls(res_values.astype(dtype.type, copy=False), mask)
+        else:
+            return cls(res_values.astype(dtype.type, copy=False), result_mask)
 
     @final
     def _cython_op_ndim_compat(
@@ -428,20 +434,24 @@ def _cython_op_ndim_compat(
         min_count: int,
         ngroups: int,
         comp_ids: np.ndarray,
-        mask: np.ndarray | None,
+        mask: np.ndarray | None = None,
+        result_mask: np.ndarray | None = None,
         **kwargs,
     ) -> np.ndarray:
         if values.ndim == 1:
             # expand to 2d, dispatch, then squeeze if appropriate
             values2d = values[None, :]
             if mask is not None:
                 mask = mask[None, :]
+            if result_mask is not None:
+                result_mask = result_mask[None, :]
             res = self._call_cython_op(
                 values2d,
                 min_count=min_count,
                 ngroups=ngroups,
                 comp_ids=comp_ids,
                 mask=mask,
+                result_mask=result_mask,
                 **kwargs,
             )
             if res.shape[0] == 1:
@@ -456,6 +466,7 @@ def _cython_op_ndim_compat(
             ngroups=ngroups,
             comp_ids=comp_ids,
             mask=mask,
+            result_mask=result_mask,
             **kwargs,
         )
 
@@ -468,6 +479,7 @@ def _call_cython_op(
         ngroups: int,
         comp_ids: np.ndarray,
         mask: np.ndarray | None,
+        result_mask: np.ndarray | None,
         **kwargs,
     ) -> np.ndarray:  # np.ndarray[ndim=2]
         orig_values = values
@@ -493,6 +505,8 @@ def _call_cython_op(
         values = values.T
         if mask is not None:
             mask = mask.T
+        if result_mask is not None:
+            result_mask = result_mask.T
 
         out_shape = self._get_output_shape(ngroups, values)
         func, values = self.get_cython_func_and_vals(values, is_numeric)
@@ -508,6 +522,8 @@ def _call_cython_op(
                     values,
                     comp_ids,
                     min_count,
+                    mask=mask,
+                    result_mask=result_mask,
                     is_datetimelike=is_datetimelike,
                 )
             else:
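The substance of these changes: min/max now go through the masked (nullable-aware) Cython path, a per-group result_mask is allocated up front and threaded through to the kernel, and aggregations pair the output values with that result_mask while transforms such as cummin/cummax keep the row-aligned input mask. A condensed, hypothetical sketch of the wrap-up step only (not the real method body):

def wrap_masked_result(cls, res_values, dtype, kind, mask, result_mask):
    # cls is the masked array type (e.g. IntegerArray) for the result dtype
    values = res_values.astype(dtype.type, copy=False)
    if kind == "aggregate":
        # one row per group: use the result_mask the Cython kernel filled in
        return cls(values, result_mask)
    # transforms stay row-aligned, so the copied input mask is still correct
    return cls(values, mask)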

pandas/tests/groupby/test_min_max.py (+48)

@@ -177,3 +177,51 @@ def test_aggregate_categorical_lost_index(func: str):
     expected["B"] = expected["B"].astype(ds.dtype)
 
     tm.assert_frame_equal(result, expected)
+
+
+@pytest.mark.parametrize("dtype", ["Int64", "Int32", "Float64", "Float32", "boolean"])
+def test_groupby_min_max_nullable(dtype):
+    if dtype == "Int64":
+        # GH#41743 avoid precision loss
+        ts = 1618556707013635762
+    elif dtype == "boolean":
+        ts = 0
+    else:
+        ts = 4.0
+
+    df = DataFrame({"id": [2, 2], "ts": [ts, ts + 1]})
+    df["ts"] = df["ts"].astype(dtype)
+
+    gb = df.groupby("id")
+
+    result = gb.min()
+    expected = df.iloc[:1].set_index("id")
+    tm.assert_frame_equal(result, expected)
+
+    res_max = gb.max()
+    expected_max = df.iloc[1:].set_index("id")
+    tm.assert_frame_equal(res_max, expected_max)
+
+    result2 = gb.min(min_count=3)
+    expected2 = DataFrame({"ts": [pd.NA]}, index=expected.index, dtype=dtype)
+    tm.assert_frame_equal(result2, expected2)
+
+    res_max2 = gb.max(min_count=3)
+    tm.assert_frame_equal(res_max2, expected2)
+
+    # Case with NA values
+    df2 = DataFrame({"id": [2, 2, 2], "ts": [ts, pd.NA, ts + 1]})
+    df2["ts"] = df2["ts"].astype(dtype)
+    gb2 = df2.groupby("id")
+
+    result3 = gb2.min()
+    tm.assert_frame_equal(result3, expected)
+
+    res_max3 = gb2.max()
+    tm.assert_frame_equal(res_max3, expected_max)
+
+    result4 = gb2.min(min_count=100)
+    tm.assert_frame_equal(result4, expected2)
+
+    res_max4 = gb2.max(min_count=100)
+    tm.assert_frame_equal(res_max4, expected2)
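A usage sketch of the behaviour this test pins down, assuming the commit is applied (same values as the Int64 case above):

import pandas as pd

df = pd.DataFrame(
    {
        "id": [2, 2],
        "ts": pd.array([1618556707013635762, 1618556707013635763], dtype="Int64"),
    }
)
gb = df.groupby("id")

print(gb.max())              # exact Int64 result, no float64 round-trip
print(gb.min(min_count=3))   # <NA> (still Int64) because the group has only 2 values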
