diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py
index 83a9c0ba61c2d..1f11ca6d4cc8b 100644
--- a/pandas/core/arrays/datetimelike.py
+++ b/pandas/core/arrays/datetimelike.py
@@ -1556,6 +1556,56 @@ def mean(self, skipna=True):
         # Don't have to worry about NA `result`, since no NA went in.
         return self._box_func(result)
 
+    def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwargs):
+        """
+        Return the median of the array, optionally along one axis.
+
+        Parameters
+        ----------
+        axis : int, optional
+            Axis to reduce over. With None (the default) the result is a
+            scalar.
+        skipna : bool, default True
+            Forwarded to ``nanops.nanmedian`` along with the ``isna`` mask.
+        *args, **kwargs
+            Accepted only so ``nv.validate_median`` can check them.
+
+        Returns
+        -------
+        scalar or same type as self
+            A scalar (``NaT`` when the array is empty) if ``axis`` is None
+            or the array is 1-D; otherwise an array built from the reduced
+            i8 values.
+
+        Raises
+        ------
+        ValueError
+            If ``abs(axis) >= self.ndim``.
+        """
+        nv.validate_median(args, kwargs)
+
+        if axis is not None and abs(axis) >= self.ndim:
+            raise ValueError("abs(axis) must be less than ndim")
+
+        if self.size == 0:
+            if self.ndim == 1 or axis is None:
+                return NaT
+            # Result shape is "everything but `axis`", filled with NaT.
+            shape = list(self.shape)
+            del shape[axis]
+            # Zero-length remaining dims are widened to 1, so e.g. reducing a
+            # (0, 3) array along axis=1 gives a single NaT, not an empty array.
+            shape = [1 if x == 0 else x for x in shape]
+            result = np.empty(shape, dtype="i8")
+            result.fill(iNaT)
+            return self._from_backing_data(result)
+
+        mask = self.isna()
+        result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask)
+        if axis is None or self.ndim == 1:
+            return self._box_func(result)
+        return self._from_backing_data(result.astype("i8"))
+
 
 DatetimeLikeArrayMixin._add_comparison_ops()
 
diff --git a/pandas/core/arrays/timedeltas.py b/pandas/core/arrays/timedeltas.py
index c97c7da375fd4..f85e3e716bbf9 100644
--- a/pandas/core/arrays/timedeltas.py
+++ b/pandas/core/arrays/timedeltas.py
@@ -393,19 +393,6 @@ def std(
         result = nanops.nanstd(self._data, axis=axis, skipna=skipna, ddof=ddof)
         return Timedelta(result)
 
-    def median(
-        self,
-        axis=None,
-        out=None,
-        overwrite_input: bool = False,
-        keepdims: bool = False,
-        skipna: bool = True,
-    ):
-        nv.validate_median(
-            (), dict(out=out, overwrite_input=overwrite_input, keepdims=keepdims)
-        )
-        return nanops.nanmedian(self._data, axis=axis, skipna=skipna)
-
     # ----------------------------------------------------------------
     # Rendering Methods
 
diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py
index e7605125e7420..9f136b4979bb7 100644
--- a/pandas/tests/arrays/test_datetimes.py
+++ b/pandas/tests/arrays/test_datetimes.py
@@ -454,8 +454,9 @@ def test_tz_dtype_matches(self):
 
 
 class TestReductions:
-    @pytest.mark.parametrize("tz", [None, "US/Central"])
-    def test_min_max(self, tz):
+    @pytest.fixture
+    def arr1d(self, tz_naive_fixture):
+        tz = tz_naive_fixture
         dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
         arr = DatetimeArray._from_sequence(
             [
@@ -468,6 +469,11 @@ def test_min_max(self, tz):
             ],
             dtype=dtype,
         )
+        return arr
+
+    def test_min_max(self, arr1d):
+        arr = arr1d
+        tz = arr.tz
 
         result = arr.min()
         expected = pd.Timestamp("2000-01-02", tz=tz)
@@ -493,3 +499,70 @@ def test_min_max_empty(self, skipna, tz):
 
         result = arr.max(skipna=skipna)
         assert result is pd.NaT
+
+    @pytest.mark.parametrize("tz", [None, "US/Central"])
+    @pytest.mark.parametrize("skipna", [True, False])
+    def test_median_empty(self, skipna, tz):
+        dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
+        arr = DatetimeArray._from_sequence([], dtype=dtype)
+        result = arr.median(skipna=skipna)
+        assert result is pd.NaT
+
+        arr = arr.reshape(0, 3)
+        result = arr.median(axis=0, skipna=skipna)
+        expected = type(arr)._from_sequence([pd.NaT, pd.NaT, pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
+
+        result = arr.median(axis=1, skipna=skipna)
+        expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
+
+    def test_median(self, arr1d):
+        arr = arr1d
+
+        result = arr.median()
+        assert result == arr[0]
+        result = arr.median(skipna=False)
+        assert result is pd.NaT
+
+        result = arr.dropna().median(skipna=False)
+        assert result == arr[0]
+
+        result = arr.median(axis=0)
+        assert result == arr[0]
+
+    def test_median_axis(self, arr1d):
+        arr = arr1d
+        assert arr.median(axis=0) == arr.median()
+        assert arr.median(axis=0, skipna=False) is pd.NaT
+
+        msg = r"abs\(axis\) must be less than ndim"
+        with pytest.raises(ValueError, match=msg):
+            arr.median(axis=1)
+
+    @pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
+    def test_median_2d(self, arr1d):
+        arr = arr1d.reshape(1, -1)
+
+        # axis = None
+        assert arr.median() == arr1d.median()
+        assert arr.median(skipna=False) is pd.NaT
+
+        # axis = 0
+        result = arr.median(axis=0)
+        expected = arr1d
+        tm.assert_equal(result, expected)
+
+        # Since column 3 is all-NaT, we get NaT there with or without skipna
+        result = arr.median(axis=0, skipna=False)
+        expected = arr1d
+        tm.assert_equal(result, expected)
+
+        # axis = 1
+        result = arr.median(axis=1)
+        expected = type(arr)._from_sequence([arr1d.median()])
+        tm.assert_equal(result, expected)
+
+        result = arr.median(axis=1, skipna=False)
+        expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
+        tm.assert_equal(result, expected)
diff --git a/pandas/tests/reductions/test_stat_reductions.py b/pandas/tests/reductions/test_stat_reductions.py
index 59dbcb9ab9fa0..fd2746672a0eb 100644
--- a/pandas/tests/reductions/test_stat_reductions.py
+++ b/pandas/tests/reductions/test_stat_reductions.py
@@ -96,7 +96,7 @@ def _check_stat_op(
             string_series_[5:15] = np.NaN
 
-            # mean, idxmax, idxmin, min, and max are valid for dates
-            if name not in ["max", "min", "mean"]:
+            # mean, median, idxmax, idxmin, min, and max are valid for dates
+            if name not in ["max", "min", "mean", "median"]:
                 ds = Series(pd.date_range("1/1/2001", periods=10))
                 with pytest.raises(TypeError):
                     f(ds)