Skip to content

Commit e0b93cc

Browse files
authored
REF: de-duplicate period-dispatch (#50215)
* REF: de-duplicate period-dispatch * mypy fixup
1 parent 749d59d commit e0b93cc

File tree

2 files changed

+42
-47
lines changed

2 files changed

+42
-47
lines changed

pandas/core/arrays/datetimelike.py

+38-34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
datetime,
55
timedelta,
66
)
7+
from functools import wraps
78
import operator
89
from typing import (
910
TYPE_CHECKING,
@@ -57,6 +58,7 @@
5758
DatetimeLikeScalar,
5859
Dtype,
5960
DtypeObj,
61+
F,
6062
NpDtype,
6163
PositionalIndexer2D,
6264
PositionalIndexerTuple,
@@ -157,6 +159,31 @@
157159
DatetimeLikeArrayT = TypeVar("DatetimeLikeArrayT", bound="DatetimeLikeArrayMixin")
158160

159161

162+
def _period_dispatch(meth: F) -> F:
163+
"""
164+
For PeriodArray methods, dispatch to DatetimeArray and re-wrap the results
165+
in PeriodArray. We cannot use ._ndarray directly for the affected
166+
methods because the i8 data has different semantics on NaT values.
167+
"""
168+
169+
@wraps(meth)
170+
def new_meth(self, *args, **kwargs):
171+
if not is_period_dtype(self.dtype):
172+
return meth(self, *args, **kwargs)
173+
174+
arr = self.view("M8[ns]")
175+
result = meth(arr, *args, **kwargs)
176+
if result is NaT:
177+
return NaT
178+
elif isinstance(result, Timestamp):
179+
return self._box_func(result.value)
180+
181+
res_i8 = result.view("i8")
182+
return self._from_backing_data(res_i8)
183+
184+
return cast(F, new_meth)
185+
186+
160187
class DatetimeLikeArrayMixin(OpsMixin, NDArrayBackedExtensionArray):
161188
"""
162189
Shared Base/Mixin class for DatetimeArray, TimedeltaArray, PeriodArray
@@ -1546,6 +1573,15 @@ def __isub__(self: DatetimeLikeArrayT, other) -> DatetimeLikeArrayT:
15461573
# --------------------------------------------------------------
15471574
# Reductions
15481575

1576+
@_period_dispatch
1577+
def _quantile(
1578+
self: DatetimeLikeArrayT,
1579+
qs: npt.NDArray[np.float64],
1580+
interpolation: str,
1581+
) -> DatetimeLikeArrayT:
1582+
return super()._quantile(qs=qs, interpolation=interpolation)
1583+
1584+
@_period_dispatch
15491585
def min(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs):
15501586
"""
15511587
Return the minimum value of the Array or minimum along
@@ -1560,21 +1596,10 @@ def min(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs):
15601596
nv.validate_min((), kwargs)
15611597
nv.validate_minmax_axis(axis, self.ndim)
15621598

1563-
if is_period_dtype(self.dtype):
1564-
# pass datetime64 values to nanops to get correct NaT semantics
1565-
result = nanops.nanmin(
1566-
self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna
1567-
)
1568-
if result is NaT:
1569-
return NaT
1570-
result = result.view("i8")
1571-
if axis is None or self.ndim == 1:
1572-
return self._box_func(result)
1573-
return self._from_backing_data(result)
1574-
15751599
result = nanops.nanmin(self._ndarray, axis=axis, skipna=skipna)
15761600
return self._wrap_reduction_result(axis, result)
15771601

1602+
@_period_dispatch
15781603
def max(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs):
15791604
"""
15801605
Return the maximum value of the Array or maximum along
@@ -1589,18 +1614,6 @@ def max(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs):
15891614
nv.validate_max((), kwargs)
15901615
nv.validate_minmax_axis(axis, self.ndim)
15911616

1592-
if is_period_dtype(self.dtype):
1593-
# pass datetime64 values to nanops to get correct NaT semantics
1594-
result = nanops.nanmax(
1595-
self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna
1596-
)
1597-
if result is NaT:
1598-
return result
1599-
result = result.view("i8")
1600-
if axis is None or self.ndim == 1:
1601-
return self._box_func(result)
1602-
return self._from_backing_data(result)
1603-
16041617
result = nanops.nanmax(self._ndarray, axis=axis, skipna=skipna)
16051618
return self._wrap_reduction_result(axis, result)
16061619

@@ -1641,22 +1654,13 @@ def mean(self, *, skipna: bool = True, axis: AxisInt | None = 0):
16411654
)
16421655
return self._wrap_reduction_result(axis, result)
16431656

1657+
@_period_dispatch
16441658
def median(self, *, axis: AxisInt | None = None, skipna: bool = True, **kwargs):
16451659
nv.validate_median((), kwargs)
16461660

16471661
if axis is not None and abs(axis) >= self.ndim:
16481662
raise ValueError("abs(axis) must be less than ndim")
16491663

1650-
if is_period_dtype(self.dtype):
1651-
# pass datetime64 values to nanops to get correct NaT semantics
1652-
result = nanops.nanmedian(
1653-
self._ndarray.view("M8[ns]"), axis=axis, skipna=skipna
1654-
)
1655-
result = result.view("i8")
1656-
if axis is None or self.ndim == 1:
1657-
return self._box_func(result)
1658-
return self._from_backing_data(result)
1659-
16601664
result = nanops.nanmedian(self._ndarray, axis=axis, skipna=skipna)
16611665
return self._wrap_reduction_result(axis, result)
16621666

pandas/core/arrays/period.py

+4-13
Original file line numberDiff line numberDiff line change
@@ -672,31 +672,22 @@ def searchsorted(
672672
) -> npt.NDArray[np.intp] | np.intp:
673673
npvalue = self._validate_setitem_value(value).view("M8[ns]")
674674

675-
# Cast to M8 to get datetime-like NaT placement
675+
# Cast to M8 to get datetime-like NaT placement,
676+
# similar to dtl._period_dispatch
676677
m8arr = self._ndarray.view("M8[ns]")
677678
return m8arr.searchsorted(npvalue, side=side, sorter=sorter)
678679

679680
def fillna(self, value=None, method=None, limit=None) -> PeriodArray:
680681
if method is not None:
681-
# view as dt64 so we get treated as timelike in core.missing
682+
# view as dt64 so we get treated as timelike in core.missing,
683+
# similar to dtl._period_dispatch
682684
dta = self.view("M8[ns]")
683685
result = dta.fillna(value=value, method=method, limit=limit)
684686
# error: Incompatible return value type (got "Union[ExtensionArray,
685687
# ndarray[Any, Any]]", expected "PeriodArray")
686688
return result.view(self.dtype) # type: ignore[return-value]
687689
return super().fillna(value=value, method=method, limit=limit)
688690

689-
def _quantile(
690-
self: PeriodArray,
691-
qs: npt.NDArray[np.float64],
692-
interpolation: str,
693-
) -> PeriodArray:
694-
# dispatch to DatetimeArray implementation
695-
dtres = self.view("M8[ns]")._quantile(qs, interpolation)
696-
# error: Incompatible return value type (got "Union[ExtensionArray,
697-
# ndarray[Any, Any]]", expected "PeriodArray")
698-
return dtres.view(self.dtype) # type: ignore[return-value]
699-
700691
# ------------------------------------------------------------------
701692
# Arithmetic Methods
702693

0 commit comments

Comments
 (0)