From d021a0cb287fa7393f85a7e824653ac4de98b64d Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 1 Nov 2020 13:45:43 -0800 Subject: [PATCH 1/2] CLN: de-duplicate asof_locs --- pandas/core/indexes/base.py | 6 +++--- pandas/core/indexes/period.py | 21 ++++----------------- 2 files changed, 7 insertions(+), 20 deletions(-) diff --git a/pandas/core/indexes/base.py b/pandas/core/indexes/base.py index 1938722225b98..b9dc8b3014a20 100644 --- a/pandas/core/indexes/base.py +++ b/pandas/core/indexes/base.py @@ -4490,7 +4490,7 @@ def asof(self, label): loc = loc.indices(len(self))[-1] return self[loc] - def asof_locs(self, where, mask): + def asof_locs(self, where: "Index", mask) -> np.ndarray: """ Return the locations (indices) of labels in the index. @@ -4518,13 +4518,13 @@ def asof_locs(self, where, mask): which correspond to the return values of the `asof` function for every element in `where`. """ - locs = self.values[mask].searchsorted(where.values, side="right") + locs = self._values[mask].searchsorted(where._values, side="right") locs = np.where(locs > 0, locs - 1, 0) result = np.arange(len(self))[mask].take(locs) first = mask.argmax() - result[(locs == 0) & (where.values < self.values[first])] = -1 + result[(locs == 0) & (where._values < self._values[first])] = -1 return result diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index e90d31c6957b7..a11a1ae05fcd4 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -404,28 +404,15 @@ def __array_wrap__(self, result, context=None): # cannot pass _simple_new as it is return type(self)(result, freq=self.freq, name=self.name) - def asof_locs(self, where, mask: np.ndarray) -> np.ndarray: + def asof_locs(self, where: Index, mask: np.ndarray) -> np.ndarray: """ where : array of timestamps mask : array of booleans where data is not NA """ - where_idx = where - if isinstance(where_idx, DatetimeIndex): - where_idx = PeriodIndex(where_idx._values, freq=self.freq) - elif not isinstance(where_idx, PeriodIndex): - raise TypeError("asof_locs `where` must be DatetimeIndex or PeriodIndex") - elif where_idx.freq != self.freq: - raise raise_on_incompatible(self, where_idx) + if isinstance(where, DatetimeIndex): + where = PeriodIndex(where._values, freq=self.freq) - locs = self.asi8[mask].searchsorted(where_idx.asi8, side="right") - - locs = np.where(locs > 0, locs - 1, 0) - result = np.arange(len(self))[mask].take(locs) - - first = mask.argmax() - result[(locs == 0) & (where_idx.asi8 < self.asi8[first])] = -1 - - return result + return super().asof_locs(where, mask) @doc(Index.astype) def astype(self, dtype, copy: bool = True, how="start"): From 22097cf58326783f896099862e39c07df0908bc1 Mon Sep 17 00:00:00 2001 From: Brock Date: Sun, 1 Nov 2020 16:12:24 -0800 Subject: [PATCH 2/2] restore check --- pandas/core/indexes/period.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pandas/core/indexes/period.py b/pandas/core/indexes/period.py index a11a1ae05fcd4..44c20ad0de848 100644 --- a/pandas/core/indexes/period.py +++ b/pandas/core/indexes/period.py @@ -411,6 +411,8 @@ def asof_locs(self, where: Index, mask: np.ndarray) -> np.ndarray: """ if isinstance(where, DatetimeIndex): where = PeriodIndex(where._values, freq=self.freq) + elif not isinstance(where, PeriodIndex): + raise TypeError("asof_locs `where` must be DatetimeIndex or PeriodIndex") return super().asof_locs(where, mask)