From 65fe4e460b607c41d62408fb664d35e21ca4f52b Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 26 Oct 2020 10:29:06 -0700 Subject: [PATCH 1/2] REF: avoid special case in DTA/TDA.median, flesh out tests --- pandas/core/arrays/datetimelike.py | 21 +++++----- pandas/core/nanops.py | 7 +++- pandas/tests/arrays/test_datetimelike.py | 52 +++++++++++++++++++++++- pandas/tests/arrays/test_datetimes.py | 2 +- 4 files changed, 67 insertions(+), 15 deletions(-) diff --git a/pandas/core/arrays/datetimelike.py b/pandas/core/arrays/datetimelike.py index 4523ea1030ef1..b5f71e5bbe853 100644 --- a/pandas/core/arrays/datetimelike.py +++ b/pandas/core/arrays/datetimelike.py @@ -1359,21 +1359,20 @@ def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwarg if axis is not None and abs(axis) >= self.ndim: raise ValueError("abs(axis) must be less than ndim") - if self.size == 0: - if self.ndim == 1 or axis is None: - return NaT - shape = list(self.shape) - del shape[axis] - shape = [1 if x == 0 else x for x in shape] - result = np.empty(shape, dtype="i8") - result.fill(iNaT) + if is_period_dtype(self.dtype): + # pass datetime64 values to nanops to get correct NaT semantics + result = nanops.nanmedian( + self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna + ) + result = result.view("i8") + if axis is None or self.ndim == 1: + return self._box_func(result) return self._from_backing_data(result) - mask = self.isna() - result = nanops.nanmedian(self.asi8, axis=axis, skipna=skipna, mask=mask) + result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna) if axis is None or self.ndim == 1: return self._box_func(result) - return self._from_backing_data(result.astype("i8")) + return self._from_backing_data(result) class DatelikeOps(DatetimeLikeArrayMixin): diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index c7b6e132f9a74..0a4658fceba4a 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -339,7 +339,12 @@ def _wrap_results(result, dtype: DtypeObj, fill_value=None): assert not isna(fill_value), "Expected non-null fill_value" if result == fill_value: result = np.nan - result = Timestamp(result, tz=tz) + if tz is not None: + result = Timestamp(result, tz=tz) + elif isna(result): + result = np.datetime64("NaT", "ns") + else: + result = np.int64(result).view("datetime64[ns]") else: # If we have float dtype, taking a view will give the wrong result result = result.astype(dtype) diff --git a/pandas/tests/arrays/test_datetimelike.py b/pandas/tests/arrays/test_datetimelike.py index 463196eaa36bf..f621479e4f311 100644 --- a/pandas/tests/arrays/test_datetimelike.py +++ b/pandas/tests/arrays/test_datetimelike.py @@ -4,7 +4,7 @@ import pytest import pytz -from pandas._libs import OutOfBoundsDatetime, Timestamp +from pandas._libs import NaT, OutOfBoundsDatetime, Timestamp from pandas.compat.numpy import np_version_under1p18 import pandas as pd @@ -456,6 +456,54 @@ def test_shift_fill_int_deprecated(self): expected[1:] = arr[:-1] tm.assert_equal(result, expected) + def test_median(self, arr1d): + arr = arr1d + if len(arr) % 2 == 0: + # make it easier to define `expected` + arr = arr[:-1] + + expected = arr[len(arr) // 2] + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + arr[len(arr) // 2] = NaT + if not isinstance(expected, Period): + expected = arr[len(arr) // 2 - 1 : len(arr) // 2 + 2].mean() + + assert arr.median(skipna=False) is NaT + + result = arr.median() + assert type(result) is type(expected) + assert result == expected + + assert arr[:0].median() is NaT + assert arr[:0].median(skipna=False) is NaT + + # 2d Case + arr2 = arr.reshape(-1, 1) + + result = arr2.median(axis=None) + assert type(result) is type(expected) + assert result == expected + + assert arr2.median(axis=None, skipna=False) is NaT + + result = arr2.median(axis=0) + expected2 = type(arr)._from_sequence([expected], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=0, skipna=False) + expected2 = type(arr)._from_sequence([NaT], dtype=arr.dtype) + tm.assert_equal(result, expected2) + + result = arr2.median(axis=1) + tm.assert_equal(result, arr) + + result = arr2.median(axis=1, skipna=False) + tm.assert_equal(result, arr) + class TestDatetimeArray(SharedTests): index_cls = pd.DatetimeIndex @@ -465,7 +513,7 @@ class TestDatetimeArray(SharedTests): @pytest.fixture def arr1d(self, tz_naive_fixture, freqstr): tz = tz_naive_fixture - dti = pd.date_range("2016-01-01 01:01:00", periods=3, freq=freqstr, tz=tz) + dti = pd.date_range("2016-01-01 01:01:00", periods=5, freq=freqstr, tz=tz) dta = dti._data return dta diff --git a/pandas/tests/arrays/test_datetimes.py b/pandas/tests/arrays/test_datetimes.py index 78721fc2fe1c1..e7ac32e7c9ccc 100644 --- a/pandas/tests/arrays/test_datetimes.py +++ b/pandas/tests/arrays/test_datetimes.py @@ -514,7 +514,7 @@ def test_median_empty(self, skipna, tz): tm.assert_equal(result, expected) result = arr.median(axis=1, skipna=skipna) - expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype) + expected = type(arr)._from_sequence([], dtype=arr.dtype) tm.assert_equal(result, expected) def test_median(self, arr1d): From c52b0bce6f4f38aa1e1eccd1d6142ee1cb8f526b Mon Sep 17 00:00:00 2001 From: Brock Date: Mon, 26 Oct 2020 13:53:05 -0700 Subject: [PATCH 2/2] suppress warning --- pandas/core/nanops.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/pandas/core/nanops.py b/pandas/core/nanops.py index 0a4658fceba4a..aee5e5262073e 100644 --- a/pandas/core/nanops.py +++ b/pandas/core/nanops.py @@ -2,6 +2,7 @@ import itertools import operator from typing import Any, Optional, Tuple, Union, cast +import warnings import numpy as np @@ -643,7 +644,11 @@ def get_median(x): mask = notna(x) if not skipna and not mask.all(): return np.nan - return np.nanmedian(x[mask]) + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings("ignore", "All-NaN slice encountered") + res = np.nanmedian(x[mask]) + return res values, mask, dtype, _, _ = _get_values(values, skipna, mask=mask) if not is_float_dtype(values.dtype): @@ -671,7 +676,11 @@ def get_median(x): ) # fastpath for the skipna case - return _wrap_results(np.nanmedian(values, axis), dtype) + with warnings.catch_warnings(): + # Suppress RuntimeWarning about All-NaN slice + warnings.filterwarnings("ignore", "All-NaN slice encountered") + res = np.nanmedian(values, axis) + return _wrap_results(res, dtype) # must return the correct shape, but median is not defined for the # empty set so return nans of shape "everything but the passed axis"