Skip to content

Commit 79b2d66

Browse files
jbrockmendel authored and ukarroum committed
REF: avoid special-casing inside DTA/TDA.mean (pandas-dev#37422)
1 parent 6c7d391 commit 79b2d66

File tree

4 files changed

+103
-19
lines changed

4 files changed

+103
-19
lines changed

pandas/core/arrays/datetimelike.py

+8 -16
Original file line number | Diff line number | Diff line change
@@ -1304,7 +1304,7 @@ def max(self, axis=None, skipna=True, *args, **kwargs):
13041304
# Don't have to worry about NA `result`, since no NA went in.
13051305
return self._box_func(result)
13061306

1307-
def mean(self, skipna=True):
1307+
def mean(self, skipna=True, axis: Optional[int] = 0):
13081308
"""
13091309
Return the mean value of the Array.
13101310
@@ -1314,6 +1314,7 @@ def mean(self, skipna=True):
13141314
----------
13151315
skipna : bool, default True
13161316
Whether to ignore any NaT elements.
1317+
axis : int, optional, default 0
13171318
13181319
Returns
13191320
-------
@@ -1337,21 +1338,12 @@ def mean(self, skipna=True):
13371338
"obj.to_timestamp(how='start').mean()"
13381339
)
13391340

1340-
mask = self.isna()
1341-
if skipna:
1342-
values = self[~mask]
1343-
elif mask.any():
1344-
return NaT
1345-
else:
1346-
values = self
1347-
1348-
if not len(values):
1349-
# short-circuit for empty max / min
1350-
return NaT
1351-
1352-
result = nanops.nanmean(values.view("i8"), skipna=skipna)
1353-
# Don't have to worry about NA `result`, since no NA went in.
1354-
return self._box_func(result)
1341+
result = nanops.nanmean(
1342+
self._ndarray, axis=axis, skipna=skipna, mask=self.isna()
1343+
)
1344+
if axis is None or self.ndim == 1:
1345+
return self._box_func(result)
1346+
return self._from_backing_data(result)
13551347

13561348
def median(self, axis: Optional[int] = None, skipna: bool = True, *args, **kwargs):
13571349
nv.validate_median(args, kwargs)

pandas/core/nanops.py

+8 -2
Original file line number | Diff line number | Diff line change
@@ -339,7 +339,12 @@ def _wrap_results(result, dtype: DtypeObj, fill_value=None):
339339
assert not isna(fill_value), "Expected non-null fill_value"
340340
if result == fill_value:
341341
result = np.nan
342-
result = Timestamp(result, tz=tz)
342+
if tz is not None:
343+
result = Timestamp(result, tz=tz)
344+
elif isna(result):
345+
result = np.datetime64("NaT", "ns")
346+
else:
347+
result = np.int64(result).view("datetime64[ns]")
343348
else:
344349
# If we have float dtype, taking a view will give the wrong result
345350
result = result.astype(dtype)
@@ -386,8 +391,9 @@ def _na_for_min_count(
386391

387392
if values.ndim == 1:
388393
return fill_value
394+
elif axis is None:
395+
return fill_value
389396
else:
390-
assert axis is not None # assertion to make mypy happy
391397
result_shape = values.shape[:axis] + values.shape[axis + 1 :]
392398

393399
result = np.full(result_shape, fill_value, dtype=values.dtype)

pandas/tests/arrays/test_datetimes.py

+52
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@
99
from pandas.core.dtypes.dtypes import DatetimeTZDtype
1010

1111
import pandas as pd
12+
from pandas import NaT
1213
import pandas._testing as tm
1314
from pandas.core.arrays import DatetimeArray
1415
from pandas.core.arrays.datetimes import sequence_to_dt64ns
@@ -566,3 +567,54 @@ def test_median_2d(self, arr1d):
566567
result = arr.median(axis=1, skipna=False)
567568
expected = type(arr)._from_sequence([pd.NaT], dtype=arr.dtype)
568569
tm.assert_equal(result, expected)
570+
571+
def test_mean(self, arr1d):
572+
arr = arr1d
573+
574+
# manually verified result
575+
expected = arr[0] + 0.4 * pd.Timedelta(days=1)
576+
577+
result = arr.mean()
578+
assert result == expected
579+
result = arr.mean(skipna=False)
580+
assert result is pd.NaT
581+
582+
result = arr.dropna().mean(skipna=False)
583+
assert result == expected
584+
585+
result = arr.mean(axis=0)
586+
assert result == expected
587+
588+
def test_mean_2d(self):
589+
dti = pd.date_range("2016-01-01", periods=6, tz="US/Pacific")
590+
dta = dti._data.reshape(3, 2)
591+
592+
result = dta.mean(axis=0)
593+
expected = dta[1]
594+
tm.assert_datetime_array_equal(result, expected)
595+
596+
result = dta.mean(axis=1)
597+
expected = dta[:, 0] + pd.Timedelta(hours=12)
598+
tm.assert_datetime_array_equal(result, expected)
599+
600+
result = dta.mean(axis=None)
601+
expected = dti.mean()
602+
assert result == expected
603+
604+
@pytest.mark.parametrize("skipna", [True, False])
605+
def test_mean_empty(self, arr1d, skipna):
606+
arr = arr1d[:0]
607+
608+
assert arr.mean(skipna=skipna) is NaT
609+
610+
arr2d = arr.reshape(0, 3)
611+
result = arr2d.mean(axis=0, skipna=skipna)
612+
expected = DatetimeArray._from_sequence([NaT, NaT, NaT], dtype=arr.dtype)
613+
tm.assert_datetime_array_equal(result, expected)
614+
615+
result = arr2d.mean(axis=1, skipna=skipna)
616+
expected = arr # i.e. 1D, empty
617+
tm.assert_datetime_array_equal(result, expected)
618+
619+
result = arr2d.mean(axis=None, skipna=skipna)
620+
assert result is NaT

pandas/tests/arrays/test_timedeltas.py

+35 -1
Original file line number | Diff line number | Diff line change
@@ -177,7 +177,7 @@ def test_neg_freq(self):
177177

178178

179179
class TestReductions:
180-
@pytest.mark.parametrize("name", ["std", "min", "max", "median"])
180+
@pytest.mark.parametrize("name", ["std", "min", "max", "median", "mean"])
181181
@pytest.mark.parametrize("skipna", [True, False])
182182
def test_reductions_empty(self, name, skipna):
183183
tdi = pd.TimedeltaIndex([])
@@ -334,3 +334,37 @@ def test_median(self):
334334

335335
result = tdi.median(skipna=False)
336336
assert result is pd.NaT
337+
338+
def test_mean(self):
339+
tdi = pd.TimedeltaIndex(["0H", "3H", "NaT", "5H06m", "0H", "2H"])
340+
arr = tdi._data
341+
342+
# manually verified result
343+
expected = pd.Timedelta(arr.dropna()._ndarray.mean())
344+
345+
result = arr.mean()
346+
assert result == expected
347+
result = arr.mean(skipna=False)
348+
assert result is pd.NaT
349+
350+
result = arr.dropna().mean(skipna=False)
351+
assert result == expected
352+
353+
result = arr.mean(axis=0)
354+
assert result == expected
355+
356+
def test_mean_2d(self):
357+
tdi = pd.timedelta_range("14 days", periods=6)
358+
tda = tdi._data.reshape(3, 2)
359+
360+
result = tda.mean(axis=0)
361+
expected = tda[1]
362+
tm.assert_timedelta_array_equal(result, expected)
363+
364+
result = tda.mean(axis=1)
365+
expected = tda[:, 0] + pd.Timedelta(hours=12)
366+
tm.assert_timedelta_array_equal(result, expected)
367+
368+
result = tda.mean(axis=None)
369+
expected = tdi.mean()
370+
assert result == expected

0 commit comments

Comments
 (0)