Skip to content

Commit 7113926

Browse files
jbrockmendelukarroum
authored andcommitted
CLN: de-duplicate asof_locs (#37567)
1 parent 9bf2e46 commit 7113926

File tree

2 files changed

+8
-19
lines changed

2 files changed

+8
-19
lines changed

pandas/core/indexes/base.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -4503,7 +4503,7 @@ def asof(self, label):
45034503
loc = loc.indices(len(self))[-1]
45044504
return self[loc]
45054505

4506-
def asof_locs(self, where, mask):
4506+
def asof_locs(self, where: "Index", mask) -> np.ndarray:
45074507
"""
45084508
Return the locations (indices) of labels in the index.
45094509
@@ -4531,13 +4531,13 @@ def asof_locs(self, where, mask):
45314531
which correspond to the return values of the `asof` function
45324532
for every element in `where`.
45334533
"""
4534-
locs = self.values[mask].searchsorted(where.values, side="right")
4534+
locs = self._values[mask].searchsorted(where._values, side="right")
45354535
locs = np.where(locs > 0, locs - 1, 0)
45364536

45374537
result = np.arange(len(self))[mask].take(locs)
45384538

45394539
first = mask.argmax()
4540-
result[(locs == 0) & (where.values < self.values[first])] = -1
4540+
result[(locs == 0) & (where._values < self._values[first])] = -1
45414541

45424542
return result
45434543

pandas/core/indexes/period.py

+5-16
Original file line numberDiff line numberDiff line change
@@ -404,28 +404,17 @@ def __array_wrap__(self, result, context=None):
404404
# cannot pass _simple_new as it is
405405
return type(self)(result, freq=self.freq, name=self.name)
406406

407-
def asof_locs(self, where, mask: np.ndarray) -> np.ndarray:
407+
def asof_locs(self, where: Index, mask: np.ndarray) -> np.ndarray:
408408
"""
409409
where : array of timestamps
410410
mask : array of booleans where data is not NA
411411
"""
412-
where_idx = where
413-
if isinstance(where_idx, DatetimeIndex):
414-
where_idx = PeriodIndex(where_idx._values, freq=self.freq)
415-
elif not isinstance(where_idx, PeriodIndex):
412+
if isinstance(where, DatetimeIndex):
413+
where = PeriodIndex(where._values, freq=self.freq)
414+
elif not isinstance(where, PeriodIndex):
416415
raise TypeError("asof_locs `where` must be DatetimeIndex or PeriodIndex")
417-
elif where_idx.freq != self.freq:
418-
raise raise_on_incompatible(self, where_idx)
419416

420-
locs = self.asi8[mask].searchsorted(where_idx.asi8, side="right")
421-
422-
locs = np.where(locs > 0, locs - 1, 0)
423-
result = np.arange(len(self))[mask].take(locs)
424-
425-
first = mask.argmax()
426-
result[(locs == 0) & (where_idx.asi8 < self.asi8[first])] = -1
427-
428-
return result
417+
return super().asof_locs(where, mask)
429418

430419
@doc(Index.astype)
431420
def astype(self, dtype, copy: bool = True, how="start"):

0 commit comments

Comments
 (0)