Skip to content

ENH: DatetimeArray/PeriodArray median #36694

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Oct 7, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions pandas/core/arrays/datetimelike.py
Original file line number Diff line number Diff line change
Expand Up @@ -1554,6 +1554,28 @@ def mean(self, skipna=True):
# Don't have to worry about NA `result`, since no NA went in.
return self._box_func(result)

def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwargs):
    """
    Compute the median of the array, optionally along one axis.

    Parameters
    ----------
    axis : int, optional
        Axis to reduce over. ``None`` reduces over every element.
    skipna : bool, default True
        Forwarded unchanged to ``nanops.nanmedian``.
    *args, **kwargs
        Extra arguments, checked by ``nv.validate_median``.

    Returns
    -------
    scalar or array
        When ``axis`` is None or the array is 1-dimensional: ``NaT`` for an
        empty array, otherwise the value produced by ``self._box_func``.
        In every other case an array built with
        ``self._from_backing_data``.

    Raises
    ------
    ValueError
        If ``abs(axis) >= self.ndim``. Note that this guard also rejects
        ``axis == -self.ndim``.
    """
    nv.validate_median(args, kwargs)

    if axis is not None and abs(axis) >= self.ndim:
        raise ValueError("abs(axis) must be less than ndim")

    # A full reduction and a reduction of a 1-D array both give a scalar.
    reduces_to_scalar = axis is None or self.ndim == 1

    if self.size == 0:
        if reduces_to_scalar:
            return NaT
        # Drop the reduced axis; any remaining zero-length dimension is
        # widened to length 1 so the result holds a NaT entry.
        remaining = list(self.shape)
        remaining.pop(axis)
        out_shape = [dim or 1 for dim in remaining]
        all_nat = np.full(out_shape, iNaT, dtype="i8")
        return self._from_backing_data(all_nat)

    na_mask = self.isna()
    medians = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=na_mask)
    if reduces_to_scalar:
        return self._box_func(medians)
    # Cast back to int64 before rebuilding an array of our own type.
    return self._from_backing_data(medians.astype("i8"))


# NOTE(review): presumably installs the comparison operators on the mixin when
# the module is imported -- confirm against the `_add_comparison_ops` definition.
DatetimeLikeArrayMixin._add_comparison_ops()

Expand Down
13 changes: 0 additions & 13 deletions pandas/core/arrays/timedeltas.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,19 +370,6 @@ def std(
result = nanops.nanstd(self._data, axis=axis, skipna=skipna, ddof=ddof)
return Timedelta(result)

def median(
    self,
    axis=None,
    out=None,
    overwrite_input: bool = False,
    keepdims: bool = False,
    skipna: bool = True,
):
    """
    Return the median of the underlying data via ``nanops.nanmedian``.

    ``out``, ``overwrite_input`` and ``keepdims`` are not used in the
    computation; they are only handed to ``nv.validate_median`` for
    checking. ``axis`` and ``skipna`` are forwarded to
    ``nanops.nanmedian`` unchanged.
    """
    compat_kwargs = {
        "out": out,
        "overwrite_input": overwrite_input,
        "keepdims": keepdims,
    }
    nv.validate_median((), compat_kwargs)
    return nanops.nanmedian(self._data, axis=axis, skipna=skipna)

# ----------------------------------------------------------------
# Rendering Methods

Expand Down
77 changes: 74 additions & 3 deletions pandas/tests/arrays/test_datetimes.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ def test_tz_dtype_matches(self):


class TestReductions:
@pytest.mark.parametrize("tz", [None, "US/Central"])
def test_min_max(self, tz):
@pytest.fixture
def arr1d(self, tz_naive_fixture):
arr = DatetimeArray._from_sequence(
[
"2000-01-03",
Expand All @@ -456,8 +456,13 @@ def test_min_max(self, tz):
"2000-01-05",
"2000-01-04",
],
tz=tz,
tz=tz_naive_fixture,
)
return arr

def test_min_max(self, arr1d):
arr = arr1d
tz = arr.tz

result = arr.min()
expected = pd.Timestamp("2000-01-02", tz=tz)
Expand All @@ -482,3 +487,69 @@ def test_min_max_empty(self, skipna, tz):

result = arr.max(skipna=skipna)
assert result is pd.NaT

@pytest.mark.parametrize("tz", [None, "US/Central"])
@pytest.mark.parametrize("skipna", [True, False])
def test_median_empty(self, skipna, tz):
    """Median of an empty array is NaT (1-D) or an all-NaT array (2-D)."""
    empty_1d = DatetimeArray._from_sequence([], tz=tz)
    assert empty_1d.median(skipna=skipna) is pd.NaT

    empty_2d = empty_1d.reshape(0, 3)

    # Reducing the zero-length axis leaves one NaT per column; reducing the
    # length-3 axis yields a single NaT.
    for axis, n_expected in [(0, 3), (1, 1)]:
        result = empty_2d.median(axis=axis, skipna=skipna)
        expected = type(empty_2d)._from_sequence(
            [pd.NaT] * n_expected, dtype=empty_2d.dtype
        )
        tm.assert_equal(result, expected)

def test_median(self, arr1d):
    """Check the 1-D median with and without skipping missing values."""
    dta = arr1d
    first = dta[0]

    # With the default skipna=True the median equals the first element.
    assert dta.median() == first
    # Without skipping, the result is NaT.
    assert dta.median(skipna=False) is pd.NaT

    # Once missing values are dropped, skipna no longer matters.
    assert dta.dropna().median(skipna=False) == first

    # An explicit axis=0 on a 1-D array matches the default.
    assert dta.median(axis=0) == first

def test_median_axis(self, arr1d):
    """Check axis handling for a 1-D array: axis=0 works, axis=1 raises."""
    dta = arr1d

    along_axis0 = dta.median(axis=0)
    assert along_axis0 == dta.median()
    assert dta.median(axis=0, skipna=False) is pd.NaT

    # An axis outside the array's dimensions must be rejected.
    with pytest.raises(ValueError, match=r"abs\(axis\) must be less than ndim"):
        dta.median(axis=1)

@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
def test_median_2d(self, arr1d):
    """Check median on a (1, N) array for axis=None, axis=0 and axis=1."""
    arr = arr1d.reshape(1, -1)

    # axis = None
    assert arr.median() == arr1d.median()
    assert arr.median(skipna=False) is pd.NaT

    # axis = 0: each column holds a single value, so the column-wise median
    # reproduces the original 1-D array.
    result = arr.median(axis=0)
    expected = arr1d
    tm.assert_equal(result, expected)

    # Since column 3 is all-NaT, we get NaT there with or without skipna
    result = arr.median(axis=0, skipna=False)
    expected = arr1d
    tm.assert_equal(result, expected)

    # axis = 1: the single row reduces to a length-1 array.
    # Pass dtype explicitly, as the other expected values in this class do,
    # so the comparison does not depend on dtype inference from the scalar.
    result = arr.median(axis=1)
    expected = type(arr)._from_sequence([arr1d.median()], dtype=arr.dtype)
    tm.assert_equal(result, expected)

    result = arr.median(axis=1, skipna=False)
    expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
    tm.assert_equal(result, expected)
2 changes: 1 addition & 1 deletion pandas/tests/reductions/test_stat_reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _check_stat_op(
string_series_[5:15] = np.NaN

# mean, median, idxmax, idxmin, min, and max are valid for dates
if name not in ["max", "min", "mean"]:
if name not in ["max", "min", "mean", "median"]:
ds = Series(pd.date_range("1/1/2001", periods=10))
with pytest.raises(TypeError):
f(ds)
Expand Down