From a9a4de5f472a046d1323ad9476fdeb48786e3876 Mon Sep 17 00:00:00 2001
From: Brock
Date: Sun, 27 Sep 2020 17:39:33 -0700
Subject: [PATCH 1/2] ENH: DTA/PA median

---
 pandas/core/arrays/datetimelike.py            | 45 +++++++++++
 pandas/core/arrays/timedeltas.py              | 13 ----
 pandas/tests/arrays/test_datetimes.py         | 76 ++++++++++++++++++-
 .../tests/reductions/test_stat_reductions.py  |  2 +-
 4 files changed, 119 insertions(+), 17 deletions(-)

diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py
index c90610bdd920c..0547ec658f2dd 100644
--- a/pandas/core/arrays/datetimelike.py
+++ b/pandas/core/arrays/datetimelike.py
@@ -1554,6 +1554,51 @@ def mean(self, skipna=True):
         # Don't have to worry about NA `result`, since no NA went in.
         return self._box_func(result)
 
+    def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwargs):
+        """
+        Return the median of the array's values, optionally along one axis.
+
+        Parameters
+        ----------
+        axis : int, optional
+            Axis to reduce over; must satisfy ``-ndim <= axis < ndim``.
+            None (the default) reduces over all elements.
+        skipna : bool, default True
+            Forwarded to ``nanops.nanmedian`` along with the ``isna()`` mask.
+
+        Returns
+        -------
+        scalar or array of the same type
+            A scalar if ``axis`` is None or the array is 1-dimensional (NaT
+            when the array is empty); otherwise the reduced values wrapped
+            with ``_from_backing_data``.
+
+        Raises
+        ------
+        ValueError
+            If ``axis`` is out of bounds.
+        """
+        nv.validate_median(args, kwargs)
+
+        # Accept every in-bounds axis, including -ndim (e.g. axis=-1 when 1-D).
+        if axis is not None and not -self.ndim <= axis < self.ndim:
+            raise ValueError("abs(axis) must be less than ndim")
+
+        if self.size == 0:
+            if self.ndim == 1 or axis is None:
+                return NaT
+            shape = list(self.shape)
+            del shape[axis]
+            shape = [1 if x == 0 else x for x in shape]
+            result = np.empty(shape, dtype="i8")
+            result.fill(iNaT)
+            return self._from_backing_data(result)
+
+        mask = self.isna()
+        result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask)
+        if axis is None or self.ndim == 1:
+            return self._box_func(result)
+        return self._from_backing_data(result.astype("i8"))
+
 
 DatetimeLikeArrayMixin._add_comparison_ops()
 
diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py
index 145380ecce9fd..4509bc2e8a394 100644
--- a/pandas/core/arrays/timedeltas.py
+++ b/pandas/core/arrays/timedeltas.py
@@ -370,19 +370,6 @@ def std(
         result = nanops.nanstd(self._data, axis=axis, skipna=skipna, ddof=ddof)
         return Timedelta(result)
 
-    def median(
-        self,
-        axis=None,
-        out=None,
-        overwrite_input: bool = False,
-        keepdims: bool = False,
-        skipna: bool = True,
-    ):
-        nv.validate_median(
-            (), dict(out=out, overwrite_input=overwrite_input, keepdims=keepdims)
-        )
-        return nanops.nanmedian(self._data, axis=axis, skipna=skipna)
-
     # ----------------------------------------------------------------
     # Rendering Methods
 
diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py
index 53f26de09f94e..c0b9b43dabd13 100644
--- a/pandas/tests/arrays/test_datetimes.py
+++ b/pandas/tests/arrays/test_datetimes.py
@@ -445,8 +445,8 @@ def test_tz_dtype_matches(self):
 
 
 class TestReductions:
-    @pytest.mark.parametrize("tz", [None, "US/Central"])
-    def test_min_max(self, tz):
+    @pytest.fixture
+    def arr1d(self, tz_naive_fixture):
         arr = DatetimeArray._from_sequence(
             [
                 "2000-01-03",
@@ -456,8 +456,13 @@ def test_min_max(self, tz):
                 "2000-01-05",
                 "2000-01-04",
             ],
-            tz=tz,
+            tz=tz_naive_fixture,
         )
+        return arr
+
+    def test_min_max(self, arr1d):
+        arr = arr1d
+        tz = arr.tz
 
         result = arr.min()
         expected = pd.Timestamp("2000-01-02", tz=tz)
@@ -482,3 +487,68 @@ def test_min_max_empty(self, skipna, tz):
 
         result = arr.max(skipna=skipna)
         assert result is pd.NaT
+
+    @pytest.mark.parametrize("tz", [None, "US/Central"])
+    @pytest.mark.parametrize("skipna", [True, False])
+    def test_median_empty(self, skipna, tz):
+        arr = DatetimeArray._from_sequence([], tz=tz)
+        result = arr.median(skipna=skipna)
+        assert result is pd.NaT
+
+        arr = arr.reshape(0, 3)
+        result = arr.median(axis=0, skipna=skipna)
+        expected = type(arr)._from_sequence([pd.NaT, pd.NaT, pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
+
+        result = arr.median(axis=1, skipna=skipna)
+        expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
+
+    def test_median(self, arr1d):
+        arr = arr1d
+
+        result = arr.median()
+        assert result == arr[0]
+        result = arr.median(skipna=False)
+        assert result is pd.NaT
+
+        result = arr.dropna().median(skipna=False)
+        assert result == arr[0]
+
+        result = arr.median(axis=0)
+        assert result == arr[0]
+
+    def test_median_axis(self, arr1d):
+        arr = arr1d
+        assert arr.median(axis=0) == arr.median()
+        assert arr.median(axis=0, skipna=False) is pd.NaT
+
+        msg = r"abs\(axis\) must be less than ndim"
+        with pytest.raises(ValueError, match=msg):
+            arr.median(axis=1)
+
+    def test_median_2d(self, arr1d):
+        arr = arr1d.reshape(1, -1)
+
+        # axis = None
+        assert arr.median() == arr1d.median()
+        assert arr.median(skipna=False) is pd.NaT
+
+        # axis = 0
+        result = arr.median(axis=0)
+        expected = arr1d
+        tm.assert_equal(result, expected)
+
+        # Since column 3 is all-NaT, we get NaT there with or without skipna
+        result = arr.median(axis=0, skipna=False)
+        expected = arr1d
+        tm.assert_equal(result, expected)
+
+        # axis = 1
+        result = arr.median(axis=1)
+        expected = type(arr)._from_sequence([arr1d.median()])
+        tm.assert_equal(result, expected)
+
+        result = arr.median(axis=1, skipna=False)
+        expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
diff --git a/pandas/tests/reductions/test_stat_reductions.py b/pandas/tests/reductions/test_stat_reductions.py
index 59dbcb9ab9fa0..fd2746672a0eb 100644
--- a/pandas/tests/reductions/test_stat_reductions.py
+++ b/pandas/tests/reductions/test_stat_reductions.py
@@ -96,7 +96,7 @@ def _check_stat_op(
         string_series_[5:15] = np.NaN
 
         # mean, idxmax, idxmin, min, and max are valid for dates
-        if name not in ["max", "min", "mean"]:
+        if name not in ["max", "min", "mean", "median"]:
             ds = Series(pd.date_range("1/1/2001", periods=10))
             with pytest.raises(TypeError):
                 f(ds)

From 781760096062d93b58f2a74cd4bcac65b9ebe879 Mon Sep 17 00:00:00 2001
From: Brock
Date: Mon, 28 Sep 2020 15:27:27 -0700
Subject: [PATCH 2/2] suppress warning

---
 pandas/tests/arrays/test_datetimes.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py
index c0b9b43dabd13..e7e36005c8319 100644
--- a/pandas/tests/arrays/test_datetimes.py
+++ b/pandas/tests/arrays/test_datetimes.py
@@ -527,6 +527,7 @@ def test_median_axis(self, arr1d):
         with pytest.raises(ValueError, match=msg):
             arr.median(axis=1)
 
+    @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
     def test_median_2d(self, arr1d):
         arr = arr1d.reshape(1, -1)
 