Skip to content

Commit edd401b

Browse files
authored
PERF: Calculate mask in interpolate only once (#50326)
* PERF: Calculate mask in interpolate only once * Fix test * Fix mypy
1 parent ac762af commit edd401b

File tree

2 files changed

+14
-9
lines changed

2 files changed

+14
-9
lines changed

pandas/core/generic.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -11643,7 +11643,7 @@ def _find_valid_index(self, *, how: str) -> Hashable | None:
1164311643
-------
1164411644
idx_first_valid : type of index
1164511645
"""
11646-
idxpos = find_valid_index(self._values, how=how)
11646+
idxpos = find_valid_index(self._values, how=how, is_valid=~isna(self._values))
1164711647
if idxpos is None:
1164811648
return None
1164911649
return self.index[idxpos]

pandas/core/missing.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,9 @@ def clean_interp_method(method: str, index: Index, **kwargs) -> str:
170170
return method
171171

172172

173-
def find_valid_index(values, *, how: str) -> int | None:
173+
def find_valid_index(
174+
values, *, how: str, is_valid: npt.NDArray[np.bool_]
175+
) -> int | None:
174176
"""
175177
Retrieves the index of the first valid value.
176178
@@ -179,6 +181,8 @@ def find_valid_index(values, *, how: str) -> int | None:
179181
values : ndarray or ExtensionArray
180182
how : {'first', 'last'}
181183
Use this parameter to change between the first or last valid index.
184+
is_valid: np.ndarray
185+
Mask to find na_values.
182186
183187
Returns
184188
-------
@@ -189,8 +193,6 @@ def find_valid_index(values, *, how: str) -> int | None:
189193
if len(values) == 0: # early stop
190194
return None
191195

192-
is_valid = ~isna(values)
193-
194196
if values.ndim == 2:
195197
is_valid = is_valid.any(axis=1) # reduce axis 1
196198

@@ -204,7 +206,9 @@ def find_valid_index(values, *, how: str) -> int | None:
204206

205207
if not chk_notna:
206208
return None
207-
return idxpos
209+
# Incompatible return value type (got "signedinteger[Any]",
210+
# expected "Optional[int]")
211+
return idxpos # type: ignore[return-value]
208212

209213

210214
def interpolate_array_2d(
@@ -400,12 +404,12 @@ def _interpolate_1d(
400404
# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
401405
all_nans = set(np.flatnonzero(invalid))
402406

403-
first_valid_index = find_valid_index(yvalues, how="first")
407+
first_valid_index = find_valid_index(yvalues, how="first", is_valid=valid)
404408
if first_valid_index is None: # no nan found in start
405409
first_valid_index = 0
406410
start_nans = set(range(first_valid_index))
407411

408-
last_valid_index = find_valid_index(yvalues, how="last")
412+
last_valid_index = find_valid_index(yvalues, how="last", is_valid=valid)
409413
if last_valid_index is None: # no nan found in end
410414
last_valid_index = len(yvalues)
411415
end_nans = set(range(1 + last_valid_index, len(valid)))
@@ -738,12 +742,13 @@ def _interpolate_with_limit_area(
738742
"""
739743

740744
invalid = isna(values)
745+
is_valid = ~invalid
741746

742747
if not invalid.all():
743-
first = find_valid_index(values, how="first")
748+
first = find_valid_index(values, how="first", is_valid=is_valid)
744749
if first is None:
745750
first = 0
746-
last = find_valid_index(values, how="last")
751+
last = find_valid_index(values, how="last", is_valid=is_valid)
747752
if last is None:
748753
last = len(values)
749754

0 commit comments

Comments
 (0)