Skip to content

Commit ff55413

Browse files
authored
ENH: DatetimeArray/PeriodArray median (#36694)
1 parent 4f674a1 commit ff55413

File tree

4 files changed

+98
-16
lines changed

4 files changed

+98
-16
lines changed

pandas/core/arrays/datetimelike.py

+22
Original file line numberDiff line numberDiff line change
@@ -1559,6 +1559,28 @@ def mean(self, skipna=True):
15591559
# Don't have to worry about NA `result`, since no NA went in.
15601560
return self._box_func(result)
15611561

1562+
def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwargs):
1563+
nv.validate_median(args, kwargs)
1564+
1565+
if axis is not None and abs(axis) >= self.ndim:
1566+
raise ValueError("abs(axis) must be less than ndim")
1567+
1568+
if self.size == 0:
1569+
if self.ndim == 1 or axis is None:
1570+
return NaT
1571+
shape = list(self.shape)
1572+
del shape[axis]
1573+
shape = [1 if x == 0 else x for x in shape]
1574+
result = np.empty(shape, dtype="i8")
1575+
result.fill(iNaT)
1576+
return self._from_backing_data(result)
1577+
1578+
mask = self.isna()
1579+
result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask)
1580+
if axis is None or self.ndim == 1:
1581+
return self._box_func(result)
1582+
return self._from_backing_data(result.astype("i8"))
1583+
15621584

15631585
DatetimeLikeArrayMixin._add_comparison_ops()
15641586

pandas/core/arrays/timedeltas.py

-13
Original file line numberDiff line numberDiff line change
@@ -393,19 +393,6 @@ def std(
393393
result = nanops.nanstd(self._data, axis=axis, skipna=skipna, ddof=ddof)
394394
return Timedelta(result)
395395

396-
def median(
397-
self,
398-
axis=None,
399-
out=None,
400-
overwrite_input: bool = False,
401-
keepdims: bool = False,
402-
skipna: bool = True,
403-
):
404-
nv.validate_median(
405-
(), dict(out=out, overwrite_input=overwrite_input, keepdims=keepdims)
406-
)
407-
return nanops.nanmedian(self._data, axis=axis, skipna=skipna)
408-
409396
# ----------------------------------------------------------------
410397
# Rendering Methods
411398

pandas/tests/arrays/test_datetimes.py

+75-2
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,9 @@ def test_tz_dtype_matches(self):
454454

455455

456456
class TestReductions:
457-
@pytest.mark.parametrize("tz", [None, "US/Central"])
458-
def test_min_max(self, tz):
457+
@pytest.fixture
458+
def arr1d(self, tz_naive_fixture):
459+
tz = tz_naive_fixture
459460
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
460461
arr = DatetimeArray._from_sequence(
461462
[
@@ -468,6 +469,11 @@ def test_min_max(self, tz):
468469
],
469470
dtype=dtype,
470471
)
472+
return arr
473+
474+
def test_min_max(self, arr1d):
475+
arr = arr1d
476+
tz = arr.tz
471477

472478
result = arr.min()
473479
expected = pd.Timestamp("2000-01-02", tz=tz)
@@ -493,3 +499,70 @@ def test_min_max_empty(self, skipna, tz):
493499

494500
result = arr.max(skipna=skipna)
495501
assert result is pd.NaT
502+
503+
@pytest.mark.parametrize("tz", [None, "US/Central"])
504+
@pytest.mark.parametrize("skipna", [True, False])
505+
def test_median_empty(self, skipna, tz):
506+
dtype = DatetimeTZDtype(tz=tz) if tz is not None else np.dtype("M8[ns]")
507+
arr = DatetimeArray._from_sequence([], dtype=dtype)
508+
result = arr.median(skipna=skipna)
509+
assert result is pd.NaT
510+
511+
arr = arr.reshape(0, 3)
512+
result = arr.median(axis=0, skipna=skipna)
513+
expected = type(arr)._from_sequence([pd.NaT, pd.NaT, pd.NaT], dtype=arr.dtype)
514+
tm.assert_equal(result, expected)
515+
516+
result = arr.median(axis=1, skipna=skipna)
517+
expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
518+
tm.assert_equal(result, expected)
519+
520+
def test_median(self, arr1d):
521+
arr = arr1d
522+
523+
result = arr.median()
524+
assert result == arr[0]
525+
result = arr.median(skipna=False)
526+
assert result is pd.NaT
527+
528+
result = arr.dropna().median(skipna=False)
529+
assert result == arr[0]
530+
531+
result = arr.median(axis=0)
532+
assert result == arr[0]
533+
534+
def test_median_axis(self, arr1d):
535+
arr = arr1d
536+
assert arr.median(axis=0) == arr.median()
537+
assert arr.median(axis=0, skipna=False) is pd.NaT
538+
539+
msg = r"abs\(axis\) must be less than ndim"
540+
with pytest.raises(ValueError, match=msg):
541+
arr.median(axis=1)
542+
543+
@pytest.mark.filterwarnings("ignore:All-NaN slice encountered:RuntimeWarning")
544+
def test_median_2d(self, arr1d):
545+
arr = arr1d.reshape(1, -1)
546+
547+
# axis = None
548+
assert arr.median() == arr1d.median()
549+
assert arr.median(skipna=False) is pd.NaT
550+
551+
# axis = 0
552+
result = arr.median(axis=0)
553+
expected = arr1d
554+
tm.assert_equal(result, expected)
555+
556+
# Since column 3 is all-NaT, we get NaT there with or without skipna
557+
result = arr.median(axis=0, skipna=False)
558+
expected = arr1d
559+
tm.assert_equal(result, expected)
560+
561+
# axis = 1
562+
result = arr.median(axis=1)
563+
expected = type(arr)._from_sequence([arr1d.median()])
564+
tm.assert_equal(result, expected)
565+
566+
result = arr.median(axis=1, skipna=False)
567+
expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
568+
tm.assert_equal(result, expected)

pandas/tests/reductions/test_stat_reductions.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _check_stat_op(
9696
string_series_[5:15] = np.NaN
9797

9898
# mean, idxmax, idxmin, min, and max are valid for dates
99-
if name not in ["max", "min", "mean"]:
99+
if name not in ["max", "min", "mean", "median"]:
100100
ds = Series(pd.date_range("1/1/2001", periods=10))
101101
with pytest.raises(TypeError):
102102
f(ds)

0 commit comments

Comments
 (0)