From 7a1776a6368a6cea2ac4a7eec83d0cf1920116e0 Mon Sep 17 00:00:00 2001 From: Brock Date: Wed, 25 Jan 2023 15:17:13 -0800 Subject: [PATCH] BUG/ENH: groupby.quantile support non-nano --- pandas/core/groupby/groupby.py | 4 ++-- pandas/tests/groupby/test_quantile.py | 24 ++++++++++++++++++------ pandas/tests/resample/test_timedelta.py | 8 +++++--- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/pandas/core/groupby/groupby.py b/pandas/core/groupby/groupby.py index c15948ce877a8..a70ad73366cc8 100644 --- a/pandas/core/groupby/groupby.py +++ b/pandas/core/groupby/groupby.py @@ -3196,10 +3196,10 @@ def pre_processor(vals: ArrayLike) -> tuple[np.ndarray, Dtype | None]: elif is_bool_dtype(vals.dtype) and isinstance(vals, ExtensionArray): out = vals.to_numpy(dtype=float, na_value=np.nan) elif is_datetime64_dtype(vals.dtype): - inference = np.dtype("datetime64[ns]") + inference = vals.dtype out = np.asarray(vals).astype(float) elif is_timedelta64_dtype(vals.dtype): - inference = np.dtype("timedelta64[ns]") + inference = vals.dtype out = np.asarray(vals).astype(float) elif isinstance(vals, ExtensionArray) and is_float_dtype(vals): inference = np.dtype(np.float64) diff --git a/pandas/tests/groupby/test_quantile.py b/pandas/tests/groupby/test_quantile.py index 62978596ff4fb..8cba3a8afdfae 100644 --- a/pandas/tests/groupby/test_quantile.py +++ b/pandas/tests/groupby/test_quantile.py @@ -26,8 +26,12 @@ ([np.nan, 4.0, np.nan, 2.0, np.nan], [np.nan, 4.0, np.nan, 2.0, np.nan]), # Timestamps ( - list(pd.date_range("1/1/18", freq="D", periods=5)), - list(pd.date_range("1/1/18", freq="D", periods=5))[::-1], + pd.date_range("1/1/18", freq="D", periods=5), + pd.date_range("1/1/18", freq="D", periods=5)[::-1], + ), + ( + pd.date_range("1/1/18", freq="D", periods=5).as_unit("s"), + pd.date_range("1/1/18", freq="D", periods=5)[::-1].as_unit("s"), ), # All NA ([np.nan] * 5, [np.nan] * 5), @@ -35,24 +39,32 @@ ) @pytest.mark.parametrize("q", [0, 0.25, 0.5, 0.75, 1]) def test_quantile(interpolation, a_vals, b_vals, q, request): - if interpolation == "nearest" and q == 0.5 and b_vals == [4, 3, 2, 1]: + if ( + interpolation == "nearest" + and q == 0.5 + and isinstance(b_vals, list) + and b_vals == [4, 3, 2, 1] + ): request.node.add_marker( pytest.mark.xfail( reason="Unclear numpy expectation for nearest " "result with equidistant data" ) ) + all_vals = pd.concat([pd.Series(a_vals), pd.Series(b_vals)]) a_expected = pd.Series(a_vals).quantile(q, interpolation=interpolation) b_expected = pd.Series(b_vals).quantile(q, interpolation=interpolation) - df = DataFrame( - {"key": ["a"] * len(a_vals) + ["b"] * len(b_vals), "val": a_vals + b_vals} - ) + df = DataFrame({"key": ["a"] * len(a_vals) + ["b"] * len(b_vals), "val": all_vals}) expected = DataFrame( [a_expected, b_expected], columns=["val"], index=Index(["a", "b"], name="key") ) + if all_vals.dtype.kind == "M" and expected.dtypes.values[0].kind == "M": + # TODO(non-nano): this should be unnecessary once array_to_datetime + # correctly infers non-nano from Timestamp.unit + expected = expected.astype(all_vals.dtype) result = df.groupby("key").quantile(q, interpolation=interpolation) tm.assert_frame_equal(result, expected) diff --git a/pandas/tests/resample/test_timedelta.py b/pandas/tests/resample/test_timedelta.py index ad1c361373189..5c9f61a4adc28 100644 --- a/pandas/tests/resample/test_timedelta.py +++ b/pandas/tests/resample/test_timedelta.py @@ -174,10 +174,12 @@ def test_resample_with_timedelta_yields_no_empty_groups(duplicates): tm.assert_frame_equal(result, expected) -def test_resample_quantile_timedelta(): +@pytest.mark.parametrize("unit", ["s", "ms", "us", "ns"]) +def test_resample_quantile_timedelta(unit): # GH: 29485 + dtype = np.dtype(f"m8[{unit}]") df = DataFrame( - {"value": pd.to_timedelta(np.arange(4), unit="s")}, + {"value": pd.to_timedelta(np.arange(4), unit="s").astype(dtype)}, index=pd.date_range("20200101", periods=4, tz="UTC"), ) result = df.resample("2D").quantile(0.99) @@ -189,7 +191,7 @@ def test_resample_quantile_timedelta(): ] }, index=pd.date_range("20200101", periods=2, tz="UTC", freq="2D"), - ) + ).astype(dtype) tm.assert_frame_equal(result, expected)