From 1d2c2ea2b57b250d0c1ab70975fe2d3a4aefa121 Mon Sep 17 00:00:00 2001 From: Mabel Villalba Date: Mon, 25 May 2020 23:58:06 +0200 Subject: [PATCH] Backport PR #33644 on branch 1.0.x (BUG: Groupby quantiles incorrect bins) --- doc/source/whatsnew/v1.0.4.rst | 1 + pandas/_libs/groupby.pyx | 8 +++++++- pandas/tests/groupby/test_function.py | 29 +++++++++++++++++++++------ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/doc/source/whatsnew/v1.0.4.rst b/doc/source/whatsnew/v1.0.4.rst index 132681c00d79c..1f2d4e2dba370 100644 --- a/doc/source/whatsnew/v1.0.4.rst +++ b/doc/source/whatsnew/v1.0.4.rst @@ -41,6 +41,7 @@ Bug fixes - Bug in :meth:`~DataFrame.to_csv` was silently failing when writing to an invalid s3 bucket. (:issue:`32486`) - Bug in :meth:`read_parquet` was raising a ``FileNotFoundError`` when passed an s3 directory path. (:issue:`26388`) - Bug in :meth:`~DataFrame.to_parquet` was throwing an ``AttributeError`` when writing a partitioned parquet file to s3 (:issue:`27596`) +- Bug in :meth:`GroupBy.quantile` causes the quantiles to be shifted when the ``by`` axis contains ``NaN`` (:issue:`33200`, :issue:`33569`) - Contributors diff --git a/pandas/_libs/groupby.pyx b/pandas/_libs/groupby.pyx index 53c37c8cc8190..68f1057aa7959 100644 --- a/pandas/_libs/groupby.pyx +++ b/pandas/_libs/groupby.pyx @@ -780,7 +780,13 @@ def group_quantile(ndarray[float64_t] out, non_na_counts[lab] += 1 # Get an index of values sorted by labels and then values - order = (values, labels) + if labels.any(): + # Put '-1' (NaN) labels as the last group so it does not interfere + # with the calculations. + labels_for_lexsort = np.where(labels == -1, labels.max() + 1, labels) + else: + labels_for_lexsort = labels + order = (values, labels_for_lexsort) sort_arr = np.lexsort(order).astype(np.int64, copy=False) with nogil: diff --git a/pandas/tests/groupby/test_function.py b/pandas/tests/groupby/test_function.py index 11a9b476e67cd..16aec6e52c7d0 100644 --- a/pandas/tests/groupby/test_function.py +++ b/pandas/tests/groupby/test_function.py @@ -1473,15 +1473,32 @@ def test_quantile_missing_group_values_no_segfaults(): grp.quantile() -def test_quantile_missing_group_values_correct_results(): - # GH 28662 - data = np.array([1.0, np.nan, 3.0, np.nan]) - df = pd.DataFrame(dict(key=data, val=range(4))) +@pytest.mark.parametrize( + "key, val, expected_key, expected_val", + [ + ([1.0, np.nan, 3.0, np.nan], range(4), [1.0, 3.0], [0.0, 2.0]), + ([1.0, np.nan, 2.0, 2.0], range(4), [1.0, 2.0], [0.0, 2.5]), + (["a", "b", "b", np.nan], range(4), ["a", "b"], [0, 1.5]), + ([0], [42], [0], [42.0]), + ([], [], np.array([], dtype="float64"), np.array([], dtype="float64")), + ], +) +def test_quantile_missing_group_values_correct_results( + key, val, expected_key, expected_val +): + # GH 28662, GH 33200, GH 33569 + df = pd.DataFrame({"key": key, "val": val}) - result = df.groupby("key").quantile() expected = pd.DataFrame( - [1.0, 3.0], index=pd.Index([1.0, 3.0], name="key"), columns=["val"] + expected_val, index=pd.Index(expected_key, name="key"), columns=["val"] ) + + grp = df.groupby("key") + + result = grp.quantile(0.5) + tm.assert_frame_equal(result, expected) + + result = grp.quantile() tm.assert_frame_equal(result, expected)