Skip to content

Commit 86ee235

Browse files
authored
REF: avoid special case in DTA/TDA.median, flesh out tests (#37423)
1 parent dcde1f4 commit 86ee235

File tree

4 files changed: +72 -16 lines changed

pandas/core/arrays/datetimelike.py

+10 -11
Original file line number | Diff line number | Diff line change
@@ -1356,21 +1356,20 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg
13561356
if axis is not None and abs(axis) >= self.ndim:
13571357
raise ValueError("abs(axis) must be less than ndim")
13581358

1359-
if self.size == 0:
1360-
if self.ndim == 1 or axis is None:
1361-
return NaT
1362-
shape = list(self.shape)
1363-
del shape[axis]
1364-
shape = [1 if x == 0 else x for x in shape]
1365-
result = np.empty(shape, dtype="i8")
1366-
result.fill(iNaT)
1359+
if is_period_dtype(self.dtype):
1360+
# pass datetime64 values to nanops to get correct NaT semantics
1361+
result = nanops.nanmedian(
1362+
self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna
1363+
)
1364+
result = result.view("i8")
1365+
if axis is None or self.ndim == 1:
1366+
return self._box_func(result)
13671367
return self._from_backing_data(result)
13681368

1369-
mask = self.isna()
1370-
result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask)
1369+
result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
13711370
if axis is None or self.ndim == 1:
13721371
return self._box_func(result)
1373-
return self._from_backing_data(result.astype("i8"))
1372+
return self._from_backing_data(result)
13741373

13751374

13761375
class DatelikeOps(DatetimeLikeArrayMixin):

pandas/core/nanops.py

+11 -2
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import itertools
33
import operator
44
from typing import Any, Optional, Tuple, Union, cast
5+
import warnings
56

67
import numpy as np
78

@@ -645,7 +646,11 @@ def get_median(x):
645646
mask = notna(x)
646647
if not skipna and not mask.all():
647648
return np.nan
648-
return np.nanmedian(x[mask])
649+
with warnings.catch_warnings():
650+
# Suppress RuntimeWarning about All-NaN slice
651+
warnings.filterwarnings("ignore", "All-NaN slice encountered")
652+
res = np.nanmedian(x[mask])
653+
return res
649654

650655
values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask)
651656
if not is_float_dtype(values.dtype):
@@ -673,7 +678,11 @@ def get_median(x):
673678
)
674679

675680
# fastpath for the skipna case
676-
return _wrap_results(np.nanmedian(values, axis), dtype)
681+
with warnings.catch_warnings():
682+
# Suppress RuntimeWarning about All-NaN slice
683+
warnings.filterwarnings("ignore", "All-NaN slice encountered")
684+
res = np.nanmedian(values, axis)
685+
return _wrap_results(res, dtype)
677686

678687
# must return the correct shape, but median is not defined for the
679688
# empty set so return nans of shape "everything but the passed axis"

pandas/tests/arrays/test_datetimelike.py

+50 -2
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,7 @@
44
import pytest
55
import pytz
66

7-
from pandas._libs import OutOfBoundsDatetime, Timestamp
7+
from pandas._libs import NaT, OutOfBoundsDatetime, Timestamp
88
from pandas.compat.numpy import np_version_under1p18
99

1010
import pandas as pd
@@ -456,6 +456,54 @@ def test_shift_fill_int_deprecated(self):
456456
expected[1:] = arr[:-1]
457457
tm.assert_equal(result, expected)
458458

459+
def test_median(self, arr1d):
460+
arr = arr1d
461+
if len(arr) % 2 == 0:
462+
# make it easier to define `expected`
463+
arr = arr[:-1]
464+
465+
expected = arr[len(arr) // 2]
466+
467+
result = arr.median()
468+
assert type(result) is type(expected)
469+
assert result == expected
470+
471+
arr[len(arr) // 2] = NaT
472+
if not isinstance(expected, Period):
473+
expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean()
474+
475+
assert arr.median(skipna=False) is NaT
476+
477+
result = arr.median()
478+
assert type(result) is type(expected)
479+
assert result == expected
480+
481+
assert arr[:0].median() is NaT
482+
assert arr[:0].median(skipna=False) is NaT
483+
484+
# 2d Case
485+
arr2 = arr.reshape(-1, 1)
486+
487+
result = arr2.median(axis=None)
488+
assert type(result) is type(expected)
489+
assert result == expected
490+
491+
assert arr2.median(axis=None, skipna=False) is NaT
492+
493+
result = arr2.median(axis=0)
494+
expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype)
495+
tm.assert_equal(result, expected2)
496+
497+
result = arr2.median(axis=0, skipna=False)
498+
expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype)
499+
tm.assert_equal(result, expected2)
500+
501+
result = arr2.median(axis=1)
502+
tm.assert_equal(result, arr)
503+
504+
result = arr2.median(axis=1, skipna=False)
505+
tm.assert_equal(result, arr)
506+
459507

460508
class TestDatetimeArray(SharedTests):
461509
index_cls = pd.DatetimeIndex
@@ -465,7 +513,7 @@ class TestDatetimeArray(SharedTests):
465513
@pytest.fixture
466514
def arr1d(self, tz_naive_fixture, freqstr):
467515
tz = tz_naive_fixture
468-
dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq=freqstr, tz=tz)
516+
dti = pd.date_range("2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz)
469517
dta = dti._data
470518
return dta
471519

pandas/tests/arrays/test_datetimes.py

+1 -1
Original file line number | Diff line number | Diff line change
@@ -515,7 +515,7 @@ def test_median_empty(self, skipna, tz):
515515
tm.assert_equal(result, expected)
516516

517517
result = arr.median(axis=1, skipna=skipna)
518-
expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
518+
expected = type(arr)._from_sequence([], dtype=arr.dtype)
519519
tm.assert_equal(result, expected)
520520

521521
def test_median(self, arr1d):

0 commit comments

Comments
 (0)