diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index cd312b09ab6c1..c460e4fe23248 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1356,21 +1356,20 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg if axis is not None and abs(axis) >= self.ndim: raise ValueError("abs(axis) must be less than ndim") - if self.size == 0: - if self.ndim == 1 or axis is None: - return NaT - shape = list(self.shape) - del shape[axis] - shape = [1 if x == 0 else x for x in shape] - result = np.empty(shape, dtype="i8") - result.fill(iNaT) + if is_period_dtype(self.dtype): + # pass datetime64 values to nanops to get correct NaT semantics + result = nanops.nanmedian( + self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna + ) + result = result.view("i8") + if axis is None or self.ndim == 1: + return self._box_func(result) return self._from_backing_data(result) - mask = self.isna() - result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask) + result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna) if axis is None or self.ndim == 1: return self._box_func(result) - return self._from_backing_data(result.astype("i8")) + return self._from_backing_data(result) class DatelikeOps(DatetimeLikeArrayMixin): diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index 46ff4a0e2f612..fdf27797db3ab 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -2,6 +2,7 @@ import itertools import operator from typing import Any, Optional, Tuple, Union, cast +import warnings import numpy as np @@ -645,7 +646,11 @@ def get_median(x): mask = notna(x) if not skipna and not mask.all(): return np.nan - return np.nanmedian(x[mask]) + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings("ignore", "All-NaN slice encountered") + res = np.nanmedian(x[mask]) + return res values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask) if not is_float_dtype(values.dtype): @@ -673,7 +678,11 @@ def get_median(x): ) # fastpath for the skipna case - return _wrap_results(np.nanmedian(values, axis), dtype) + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings("ignore", "All-NaN slice encountered") + res = np.nanmedian(values, axis) + return _wrap_results(res, dtype) # must return the correct shape, but median is not defined for the # empty set so return nans of shape "everything but the passed axis" diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 463196eaa36bf..f621479e4f311 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -4,7 +4,7 @@ import pytest import pytz -from pandas._libs import OutOfBoundsDatetime, Timestamp +from pandas._libs import NaT, OutOfBoundsDatetime, Timestamp from pandas.compat.numpy import np_version_under1p18 import pandas as pd @@ -456,6 +456,54 @@ def test_shift_fill_int_deprecated(self): expected[1:] = arr[:-1] tm.assert_equal(result, expected) + def test_median(self, arr1d): + arr = arr1d + if len(arr) % 2 == 0: + # make it easier to define `expected` + arr = arr[:-1] + + expected = arr[len(arr) // 2] + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + arr[len(arr) // 2] = NaT + if not isinstance(expected, Period): + expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean() + + assert arr.median(skipna=False) is NaT + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + assert arr[:0].median() is NaT + assert arr[:0].median(skipna=False) is NaT + + # 2d Case + arr2 = arr.reshape(-1, 1) + + result = arr2.median(axis=None) + assert type(result) is type(expected) + assert result == expected + + assert arr2.median(axis=None, skipna=False) is NaT + + result = arr2.median(axis=0) + expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=0, skipna=False) + expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=1) + tm.assert_equal(result, arr) + + result = arr2.median(axis=1, skipna=False) + tm.assert_equal(result, arr) + class TestDatetimeArray(SharedTests): index_cls = pd.DatetimeIndex @@ -465,7 +513,7 @@ class TestDatetimeArray(SharedTests): @pytest.fixture def arr1d(self, tz_naive_fixture, freqstr): tz = tz_naive_fixture - dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq=freqstr, tz=tz) + dti = pd.date_range("2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz) dta = dti._data return dta diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 9245eda2a71fe..66a92dd6f1cff 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -515,7 +515,7 @@ def test_median_empty(self, skipna, tz): tm.assert_equal(result, expected) result = arr.median(axis=1, skipna=skipna) - expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype) + expected = type(arr)._from_sequence([], dtype=arr.dtype) tm.assert_equal(result, expected) def test_median(self, arr1d):