
ENH: Support mask properly in GroupBy.quantile #48496

Merged
merged 5 commits on Sep 12, 2022
1 change: 1 addition & 0 deletions doc/source/whatsnew/v1.6.0.rst
@@ -28,6 +28,7 @@ enhancement2

Other enhancements
^^^^^^^^^^^^^^^^^^
- :meth:`.GroupBy.quantile` now preserves nullable dtypes instead of casting to numpy dtypes (:issue:`37493`)
- :meth:`Series.add_suffix`, :meth:`DataFrame.add_suffix`, :meth:`Series.add_prefix` and :meth:`DataFrame.add_prefix` support an ``axis`` argument. If ``axis`` is set, the default behaviour of which axis to consider can be overwritten (:issue:`47819`)
-

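For context (not part of the diff): with this change, calling quantile on a grouped nullable column keeps a nullable result dtype instead of falling back to a plain numpy dtype. A minimal illustrative snippet, assuming pandas with this PR applied:

import pandas as pd

df = pd.DataFrame({"a": ["x", "x", "y", "y"], "b": [1, 2, 3, 4]})
df["b"] = df["b"].astype("Int64")  # nullable integer dtype

result = df.groupby("a")["b"].quantile(0.5)
# With this change the result dtype is the nullable Float64,
# not a plain numpy float64 as before.
print(result.dtype)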
1 change: 1 addition & 0 deletions pandas/_libs/groupby.pyi
@@ -110,6 +110,7 @@ def group_quantile(
sort_indexer: npt.NDArray[np.intp], # const
qs: npt.NDArray[np.float64], # const
interpolation: Literal["linear", "lower", "higher", "nearest", "midpoint"],
result_mask: np.ndarray | None = ...,
) -> None: ...
def group_last(
out: np.ndarray, # rank_t[:, ::1]
7 changes: 6 additions & 1 deletion pandas/_libs/groupby.pyx
@@ -1076,6 +1076,7 @@ def group_quantile(
const intp_t[:] sort_indexer,
const float64_t[:] qs,
str interpolation,
uint8_t[:, ::1] result_mask=None,
) -> None:
"""
Calculate the quantile per group.
@@ -1106,6 +1107,7 @@ def group_quantile(
InterpolationEnumType interp
float64_t q_val, q_idx, frac, val, next_val
int64_t[::1] counts, non_na_counts
bint uses_result_mask = result_mask is not None

assert values.shape[0] == N

@@ -1148,7 +1150,10 @@

if non_na_sz == 0:
for k in range(nqs):
out[i, k] = NaN
if uses_result_mask:
result_mask[i, k] = 1
else:
out[i, k] = NaN
else:
for k in range(nqs):
q_val = qs[k]
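The branch added above is the core of the change in the Cython kernel: when a group has no non-NA values and a result_mask is supplied, the mask is set instead of writing NaN into the output buffer. A pure-Python sketch of that logic (a hypothetical helper, not the actual Cython code):

import numpy as np

def mark_all_na_groups(out, result_mask, non_na_counts):
    # out: (ngroups, nqs) float64 result buffer
    # result_mask: (ngroups, nqs) uint8 buffer, or None for non-masked dtypes
    # non_na_counts: number of non-NA values in each group
    uses_result_mask = result_mask is not None
    ngroups, nqs = out.shape
    for i in range(ngroups):
        if non_na_counts[i] == 0:
            for k in range(nqs):
                if uses_result_mask:
                    result_mask[i, k] = 1  # later surfaced as pd.NA
                else:
                    out[i, k] = np.nan  # previous behaviour, kept for numpy dtypes

out = np.zeros((2, 3))
mask = np.zeros((2, 3), dtype=np.uint8)
mark_all_na_groups(out, mask, non_na_counts=[4, 0])
print(mask)  # the second (all-NA) group is fully masked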
63 changes: 54 additions & 9 deletions pandas/core/groupby/groupby.py
@@ -46,6 +46,7 @@ class providing the base-class of operations.
import pandas._libs.groupby as libgroupby
from pandas._typing import (
ArrayLike,
Dtype,
IndexLabel,
NDFrameT,
PositionalIndexer,
@@ -92,6 +93,7 @@ class providing the base-class of operations.
BooleanArray,
Categorical,
ExtensionArray,
FloatingArray,
)
from pandas.core.base import (
PandasObject,
@@ -3247,14 +3249,17 @@ def quantile(
f"numeric_only={numeric_only} and dtype {self.obj.dtype}"
)

def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:
def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]:
if is_object_dtype(vals):
raise TypeError(
"'quantile' cannot be performed against 'object' dtypes!"
)

inference: np.dtype | None = None
if is_integer_dtype(vals.dtype):
inference: Dtype | None = None
if isinstance(vals, BaseMaskedArray) and is_numeric_dtype(vals.dtype):
out = vals.to_numpy(dtype=float, na_value=np.nan)
inference = vals.dtype
elif is_integer_dtype(vals.dtype):
if isinstance(vals, ExtensionArray):
out = vals.to_numpy(dtype=float, na_value=np.nan)
else:
@@ -3276,14 +3281,38 @@ def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, np.dtype | None]:

return out, inference

def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
def post_processor(
vals: np.ndarray,
inference: Dtype | None,
result_mask: np.ndarray | None,
orig_vals: ArrayLike,
) -> ArrayLike:
if inference:
# Check for edge case
if not (
if isinstance(orig_vals, BaseMaskedArray):
assert result_mask is not None # for mypy

if interpolation in {"linear", "midpoint"} and not is_float_dtype(
orig_vals
):
return FloatingArray(vals, result_mask)
else:
# Item "ExtensionDtype" of "Union[ExtensionDtype, str,
# dtype[Any], Type[object]]" has no attribute "numpy_dtype"
# [union-attr]
return type(orig_vals)(
vals.astype(
inference.numpy_dtype # type: ignore[union-attr]
),
result_mask,
)

elif not (
is_integer_dtype(inference)
and interpolation in {"linear", "midpoint"}
):
vals = vals.astype(inference)
assert isinstance(inference, np.dtype) # for mypy
return vals.astype(inference)

return vals

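For masked inputs, post_processor now decides between two result types: a FloatingArray when linear or midpoint interpolation on a non-float dtype can produce fractional values, and the original masked type otherwise. A hypothetical helper (not in the PR) summarizing the resulting dtype, consistent with the tests below:

def expected_quantile_dtype(orig_dtype: str, interpolation: str) -> str:
    # orig_dtype is a nullable dtype name such as "Int64", "UInt8" or "Float32"
    if interpolation in {"linear", "midpoint"} and not orig_dtype.startswith("Float"):
        # fractional results are possible, so a FloatingArray (Float64) is returned
        return "Float64"
    # otherwise the values are cast back to the original nullable dtype
    return orig_dtype

assert expected_quantile_dtype("Int64", "linear") == "Float64"
assert expected_quantile_dtype("Int64", "lower") == "Int64"
assert expected_quantile_dtype("Float32", "linear") == "Float32"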
@@ -3306,7 +3335,14 @@ def post_processor(vals: np.ndarray, inference: np.dtype | None) -> np.ndarray:
labels_for_lexsort = np.where(ids == -1, na_label_for_sorting, ids)

def blk_func(values: ArrayLike) -> ArrayLike:
mask = isna(values)
orig_vals = values
if isinstance(values, BaseMaskedArray):
mask = values._mask
result_mask = np.zeros((ngroups, nqs), dtype=np.bool_)
else:
mask = isna(values)
result_mask = None

vals, inference = pre_processor(values)

ncols = 1
@@ -3325,16 +3361,25 @@ def blk_func(values: ArrayLike) -> ArrayLike:
sort_arr = np.lexsort(order).astype(np.intp, copy=False)

if vals.ndim == 1:
func(out[0], values=vals, mask=mask, sort_indexer=sort_arr)
# EA is always 1d
func(
out[0],
values=vals,
mask=mask,
sort_indexer=sort_arr,
result_mask=result_mask,
)
else:
for i in range(ncols):
func(out[i], values=vals[i], mask=mask[i], sort_indexer=sort_arr[i])

if vals.ndim == 1:
out = out.ravel("K")
if result_mask is not None:
result_mask = result_mask.ravel("K")
else:
out = out.reshape(ncols, ngroups * nqs)
return post_processor(out, inference)
return post_processor(out, inference, result_mask, orig_vals)

obj = self._obj_with_exclusions
is_ser = obj.ndim == 1
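End to end, the groupby.py changes mean an all-NA group now comes back as pd.NA inside a nullable result instead of a plain numpy float. A usage sketch mirroring the tests added below (not part of the diff itself):

import pandas as pd

df = pd.DataFrame({"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype="Int64")

# Non-interpolating method: the result keeps the Int64 dtype,
# and the all-NA group for a == 2 is pd.NA.
lower = df.groupby("a")["b"].quantile(0.5, interpolation="lower")

# Default "linear" interpolation: the result is Float64; the all-NA group is still pd.NA.
linear = df.groupby("a")["b"].quantile(0.5)

The updated expectations in the tests below show the previous plain-float results being replaced by nullable dtypes.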
75 changes: 69 additions & 6 deletions pandas/tests/groupby/test_quantile.py
@@ -237,7 +237,7 @@ def test_groupby_quantile_nullable_array(values, q):
idx = Index(["x", "y"], name="a")
true_quantiles = [0.5]

expected = pd.Series(true_quantiles * 2, index=idx, name="b")
expected = pd.Series(true_quantiles * 2, index=idx, name="b", dtype="Float64")
tm.assert_series_equal(result, expected)


@@ -266,14 +266,21 @@ def test_groupby_quantile_NA_float(any_float_dtype):
df = DataFrame({"x": [1, 1], "y": [0.2, np.nan]}, dtype=any_float_dtype)
result = df.groupby("x")["y"].quantile(0.5)
exp_index = Index([1.0], dtype=any_float_dtype, name="x")
expected = pd.Series([0.2], dtype=float, index=exp_index, name="y")
tm.assert_series_equal(expected, result)

if any_float_dtype in ["Float32", "Float64"]:
expected_dtype = any_float_dtype
else:
expected_dtype = None

expected = pd.Series([0.2], dtype=expected_dtype, index=exp_index, name="y")
tm.assert_series_equal(result, expected)

result = df.groupby("x")["y"].quantile([0.5, 0.75])
expected = pd.Series(
[0.2] * 2,
index=pd.MultiIndex.from_product((exp_index, [0.5, 0.75]), names=["x", None]),
name="y",
dtype=expected_dtype,
)
tm.assert_series_equal(result, expected)

@@ -283,12 +290,68 @@ def test_groupby_quantile_NA_int(any_int_ea_dtype):
df = DataFrame({"x": [1, 1], "y": [2, 5]}, dtype=any_int_ea_dtype)
result = df.groupby("x")["y"].quantile(0.5)
expected = pd.Series(
[3.5], dtype=float, index=Index([1], name="x", dtype=any_int_ea_dtype), name="y"
[3.5],
dtype="Float64",
index=Index([1], name="x", dtype=any_int_ea_dtype),
name="y",
)
tm.assert_series_equal(expected, result)

result = df.groupby("x").quantile(0.5)
expected = DataFrame({"y": 3.5}, index=Index([1], name="x", dtype=any_int_ea_dtype))
expected = DataFrame(
{"y": 3.5}, dtype="Float64", index=Index([1], name="x", dtype=any_int_ea_dtype)
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize(
"interpolation, val1, val2", [("lower", 2, 2), ("higher", 2, 3), ("nearest", 2, 2)]
)
def test_groupby_quantile_all_na_group_masked(
interpolation, val1, val2, any_numeric_ea_dtype
):
# GH#37493
df = DataFrame(
{"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype
)
result = df.groupby("a").quantile(q=[0.5, 0.7], interpolation=interpolation)
expected = DataFrame(
{"b": [val1, val2, pd.NA, pd.NA]},
dtype=any_numeric_ea_dtype,
index=pd.MultiIndex.from_arrays(
[pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype), [0.5, 0.7, 0.5, 0.7]],
names=["a", None],
),
)
tm.assert_frame_equal(result, expected)


@pytest.mark.parametrize("interpolation", ["midpoint", "linear"])
def test_groupby_quantile_all_na_group_masked_interp(
interpolation, any_numeric_ea_dtype
):
# GH#37493
df = DataFrame(
{"a": [1, 1, 1, 2], "b": [1, 2, 3, pd.NA]}, dtype=any_numeric_ea_dtype
)
result = df.groupby("a").quantile(q=[0.5, 0.75], interpolation=interpolation)

if any_numeric_ea_dtype == "Float32":
expected_dtype = any_numeric_ea_dtype
else:
expected_dtype = "Float64"

expected = DataFrame(
{"b": [2.0, 2.5, pd.NA, pd.NA]},
dtype=expected_dtype,
index=pd.MultiIndex.from_arrays(
[
pd.Series([1, 1, 2, 2], dtype=any_numeric_ea_dtype),
[0.5, 0.75, 0.5, 0.75],
],
names=["a", None],
),
)
tm.assert_frame_equal(result, expected)


@@ -298,7 +361,7 @@ def test_groupby_quantile_allNA_column(dtype):
df = DataFrame({"x": [1, 1], "y": [pd.NA] * 2}, dtype=dtype)
result = df.groupby("x")["y"].quantile(0.5)
expected = pd.Series(
[np.nan], dtype=float, index=Index([1.0], dtype=dtype), name="y"
[np.nan], dtype=dtype, index=Index([1.0], dtype=dtype), name="y"
)
expected.index.name = "x"
tm.assert_series_equal(expected, result)