Skip to content

Commit b44db5c

Browse files
authored
ENH: Support mask properly in GroupBy.quantile (#48496)
* ENH: Support mask properly in GroupBy.quantile * Fix mypy * Fix typo * Add type hint
1 parent bed6b61 commit b44db5c

File tree

5 files changed

+131
-16
lines changed

5 files changed

+131
-16
lines changed

doc/source/whatsnew/v1.6.0.rst

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ enhancement2
2828

2929
Other enhancements
3030
^^^^^^^^^^^^^^^^^^
31+
- :meth:`.GroupBy.quantile` now preserving nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
3132
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
3233
- :func:`assert_frame_equal` now shows the first element where the DataFrames differ, analogously to ``pytest``'s output (:issue:`47910`)
3334
-

pandas/_libs/groupby.pyi

+1
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ def group_quantile(
110110
sort_indexer: npt.NDArray[np.intp], # const
111111
qs: npt.NDArray[np.float64], # const
112112
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
113+
result_mask: np.ndarray | None = ...,
113114
) -> None: ...
114115
def group_last(
115116
out: np.ndarray, # rank_t[:, ::1]

pandas/_libs/groupby.pyx

+6-1
Original file line numberDiff line numberDiff line change
@@ -1076,6 +1076,7 @@ def group_quantile(
10761076
const intp_t[:] sort_indexer,
10771077
const float64_t[:] qs,
10781078
str interpolation,
1079+
uint8_t[:, ::1] result_mask=None,
10791080
) -> None:
10801081
"""
10811082
Calculate the quantile per group.
@@ -1106,6 +1107,7 @@ def group_quantile(
11061107
InterpolationEnumType interp
11071108
float64_t q_val, q_idx, frac, val, next_val
11081109
int64_t[::1] counts, non_na_counts
1110+
bint uses_result_mask = result_mask is not None
11091111

11101112
assert values.shape[0] == N
11111113

@@ -1148,7 +1150,10 @@ def group_quantile(
11481150

11491151
if non_na_sz == 0:
11501152
for k in range(nqs):
1151-
out[i, k] = NaN
1153+
if uses_result_mask:
1154+
result_mask[i, k] = 1
1155+
else:
1156+
out[i, k] = NaN
11521157
else:
11531158
for k in range(nqs):
11541159
q_val = qs[k]

pandas/core/groupby/groupby.py

+54-9
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class providing the base-class of operations.
4646
import pandas._libs.groupby as libgroupby
4747
from pandas._typing import (
4848
ArrayLike,
49+
Dtype,
4950
IndexLabel,
5051
NDFrameT,
5152
PositionalIndexer,
@@ -92,6 +93,7 @@ class providing the base-class of operations.
9293
BooleanArray,
9394
Categorical,
9495
ExtensionArray,
96+
FloatingArray,
9597
)
9698
from pandas.core.base import (
9799
PandasObject,
@@ -3247,14 +3249,17 @@ def quantile(
32473249
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
32483250
)
32493251

3250-
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:
3252+
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]:
32513253
if is_object_dtype(vals):
32523254
raise TypeError(
32533255
"'quantile' cannot be performed against 'object' dtypes!"
32543256
)
32553257

3256-
inference: np.dtype | None = None
3257-
if is_integer_dtype(vals.dtype):
3258+
inference: Dtype | None = None
3259+
if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype):
3260+
out = vals.to_numpy(dtype=float, na_value=np.nan)
3261+
inference = vals.dtype
3262+
elif is_integer_dtype(vals.dtype):
32583263
if isinstance(vals, ExtensionArray):
32593264
out = vals.to_numpy(dtype=float, na_value=np.nan)
32603265
else:
@@ -3276,14 +3281,38 @@ def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:
32763281

32773282
return out, inference
32783283

3279-
def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
3284+
def post_processor(
3285+
vals: np.ndarray,
3286+
inference: Dtype | None,
3287+
result_mask: np.ndarray | None,
3288+
orig_vals: ArrayLike,
3289+
) -> ArrayLike:
32803290
if inference:
32813291
# Check for edge case
3282-
if not (
3292+
if isinstance(orig_vals, BaseMaskedArray):
3293+
assert result_mask is not None # for mypy
3294+
3295+
if interpolation in {"linear", "midpoint"} and not is_float_dtype(
3296+
orig_vals
3297+
):
3298+
return FloatingArray(vals, result_mask)
3299+
else:
3300+
# Item "ExtensionDtype" of "Union[ExtensionDtype, str,
3301+
# dtype[Any], Type[object]]" has no attribute "numpy_dtype"
3302+
# [union-attr]
3303+
return type(orig_vals)(
3304+
vals.astype(
3305+
inference.numpy_dtype # type: ignore[union-attr]
3306+
),
3307+
result_mask,
3308+
)
3309+
3310+
elif not (
32833311
is_integer_dtype(inference)
32843312
and interpolation in {"linear", "midpoint"}
32853313
):
3286-
vals = vals.astype(inference)
3314+
assert isinstance(inference, np.dtype) # for mypy
3315+
return vals.astype(inference)
32873316

32883317
return vals
32893318

@@ -3306,7 +3335,14 @@ def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
33063335
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)
33073336

33083337
def blk_func(values: ArrayLike) -> ArrayLike:
3309-
mask = isna(values)
3338+
orig_vals = values
3339+
if isinstance(values, BaseMaskedArray):
3340+
mask = values._mask
3341+
result_mask = np.zeros((ngroups, nqs), dtype=np.bool_)
3342+
else:
3343+
mask = isna(values)
3344+
result_mask = None
3345+
33103346
vals, inference = pre_processor(values)
33113347

33123348
ncols = 1
@@ -3325,16 +3361,25 @@ def blk_func(values: ArrayLike) -> ArrayLike:
33253361
sort_arr = np.lexsort(order).astype(np.intp, copy=False)
33263362

33273363
if vals.ndim == 1:
3328-
func(out[0], values=vals, mask=mask, sort_indexer=sort_arr)
3364+
# Ea is always 1d
3365+
func(
3366+
out[0],
3367+
values=vals,
3368+
mask=mask,
3369+
sort_indexer=sort_arr,
3370+
result_mask=result_mask,
3371+
)
33293372
else:
33303373
for i in range(ncols):
33313374
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])
33323375

33333376
if vals.ndim == 1:
33343377
out = out.ravel("K")
3378+
if result_mask is not None:
3379+
result_mask = result_mask.ravel("K")
33353380
else:
33363381
out = out.reshape(ncols, ngroups * nqs)
3337-
return post_processor(out, inference)
3382+
return post_processor(out, inference, result_mask, orig_vals)
33383383

33393384
obj = self._obj_with_exclusions
33403385
is_ser = obj.ndim == 1

pandas/tests/groupby/test_quantile.py

+69-6
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def test_groupby_quantile_nullable_array(values, q):
237237
idx = Index(["x", "y"], name="a")
238238
true_quantiles = [0.5]
239239

240-
expected = pd.Series(true_quantiles * 2, index=idx, name="b")
240+
expected = pd.Series(true_quantiles * 2, index=idx, name="b", dtype="Float64")
241241
tm.assert_series_equal(result, expected)
242242

243243

@@ -266,14 +266,21 @@ def test_groupby_quantile_NA_float(any_float_dtype):
266266
df = DataFrame({"x": [1, 1], "y": [0.2, np.nan]}, dtype=any_float_dtype)
267267
result = df.groupby("x")["y"].quantile(0.5)
268268
exp_index = Index([1.0], dtype=any_float_dtype, name="x")
269-
expected = pd.Series([0.2], dtype=float, index=exp_index, name="y")
270-
tm.assert_series_equal(expected, result)
269+
270+
if any_float_dtype in ["Float32", "Float64"]:
271+
expected_dtype = any_float_dtype
272+
else:
273+
expected_dtype = None
274+
275+
expected = pd.Series([0.2], dtype=expected_dtype, index=exp_index, name="y")
276+
tm.assert_series_equal(result, expected)
271277

272278
result = df.groupby("x")["y"].quantile([0.5, 0.75])
273279
expected = pd.Series(
274280
[0.2] * 2,
275281
index=pd.MultiIndex.from_product((exp_index, [0.5, 0.75]), names=["x", None]),
276282
name="y",
283+
dtype=expected_dtype,
277284
)
278285
tm.assert_series_equal(result, expected)
279286

@@ -283,12 +290,68 @@ def test_groupby_quantile_NA_int(any_int_ea_dtype):
283290
df = DataFrame({"x": [1, 1], "y": [2, 5]}, dtype=any_int_ea_dtype)
284291
result = df.groupby("x")["y"].quantile(0.5)
285292
expected = pd.Series(
286-
[3.5], dtype=float, index=Index([1], name="x", dtype=any_int_ea_dtype), name="y"
293+
[3.5],
294+
dtype="Float64",
295+
index=Index([1], name="x", dtype=any_int_ea_dtype),
296+
name="y",
287297
)
288298
tm.assert_series_equal(expected, result)
289299

290300
result = df.groupby("x").quantile(0.5)
291-
expected = DataFrame({"y": 3.5}, index=Index([1], name="x", dtype=any_int_ea_dtype))
301+
expected = DataFrame(
302+
{"y": 3.5}, dtype="Float64", index=Index([1], name="x", dtype=any_int_ea_dtype)
303+
)
304+
tm.assert_frame_equal(result, expected)
305+
306+
307+
@pytest.mark.parametrize(
308+
"interpolation, val1, val2", [("lower", 2, 2), ("higher", 2, 3), ("nearest", 2, 2)]
309+
)
310+
def test_groupby_quantile_all_na_group_masked(
311+
interpolation, val1, val2, any_numeric_ea_dtype
312+
):
313+
# GH#37493
314+
df = DataFrame(
315+
{"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype
316+
)
317+
result = df.groupby("a").quantile(q=[0.5, 0.7], interpolation=interpolation)
318+
expected = DataFrame(
319+
{"b": [val1, val2, pd.NA, pd.NA]},
320+
dtype=any_numeric_ea_dtype,
321+
index=pd.MultiIndex.from_arrays(
322+
[pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype), [0.5, 0.7, 0.5, 0.7]],
323+
names=["a", None],
324+
),
325+
)
326+
tm.assert_frame_equal(result, expected)
327+
328+
329+
@pytest.mark.parametrize("interpolation", ["midpoint", "linear"])
330+
def test_groupby_quantile_all_na_group_masked_interp(
331+
interpolation, any_numeric_ea_dtype
332+
):
333+
# GH#37493
334+
df = DataFrame(
335+
{"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype
336+
)
337+
result = df.groupby("a").quantile(q=[0.5, 0.75], interpolation=interpolation)
338+
339+
if any_numeric_ea_dtype == "Float32":
340+
expected_dtype = any_numeric_ea_dtype
341+
else:
342+
expected_dtype = "Float64"
343+
344+
expected = DataFrame(
345+
{"b": [2.0, 2.5, pd.NA, pd.NA]},
346+
dtype=expected_dtype,
347+
index=pd.MultiIndex.from_arrays(
348+
[
349+
pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype),
350+
[0.5, 0.75, 0.5, 0.75],
351+
],
352+
names=["a", None],
353+
),
354+
)
292355
tm.assert_frame_equal(result, expected)
293356

294357

@@ -298,7 +361,7 @@ def test_groupby_quantile_allNA_column(dtype):
298361
df = DataFrame({"x": [1, 1], "y": [pd.NA] * 2}, dtype=dtype)
299362
result = df.groupby("x")["y"].quantile(0.5)
300363
expected = pd.Series(
301-
[np.nan], dtype=float, index=Index([1.0], dtype=dtype), name="y"
364+
[np.nan], dtype=dtype, index=Index([1.0], dtype=dtype), name="y"
302365
)
303366
expected.index.name = "x"
304367
tm.assert_series_equal(expected, result)

0 commit comments

Comments
 (0)