Skip to content

Commit 63cc520

Browse files
jbrockmendelfeefladder
authored andcommitted
PERF: avoid repeating checks in interpolation (pandas-dev#42963)
1 parent 9fb581d commit 63cc520

File tree

1 file changed

+71
-66
lines changed

1 file changed

+71
-66
lines changed

pandas/core/missing.py

+71-66
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def interpolate_array_2d(
215215
**kwargs,
216216
):
217217
"""
218-
Wrapper to dispatch to either interpolate_2d or interpolate_2d_with_fill.
218+
Wrapper to dispatch to either interpolate_2d or _interpolate_2d_with_fill.
219219
"""
220220
try:
221221
m = clean_fill_method(method)
@@ -237,7 +237,7 @@ def interpolate_array_2d(
237237
else:
238238
assert index is not None # for mypy
239239

240-
interp_values = interpolate_2d_with_fill(
240+
interp_values = _interpolate_2d_with_fill(
241241
data=data,
242242
index=index,
243243
axis=axis,
@@ -251,7 +251,7 @@ def interpolate_array_2d(
251251
return interp_values
252252

253253

254-
def interpolate_2d_with_fill(
254+
def _interpolate_2d_with_fill(
255255
data: np.ndarray, # floating dtype
256256
index: Index,
257257
axis: int,
@@ -263,11 +263,11 @@ def interpolate_2d_with_fill(
263263
**kwargs,
264264
) -> np.ndarray:
265265
"""
266-
Column-wise application of interpolate_1d.
266+
Column-wise application of _interpolate_1d.
267267
268268
Notes
269269
-----
270-
The signature does differs from interpolate_1d because it only
270+
The signature does differs from _interpolate_1d because it only
271271
includes what is needed for Block.interpolate.
272272
"""
273273
# validate the interp method
@@ -276,13 +276,44 @@ def interpolate_2d_with_fill(
276276
if is_valid_na_for_dtype(fill_value, data.dtype):
277277
fill_value = na_value_for_dtype(data.dtype, compat=False)
278278

279+
if method == "time":
280+
if not needs_i8_conversion(index.dtype):
281+
raise ValueError(
282+
"time-weighted interpolation only works "
283+
"on Series or DataFrames with a "
284+
"DatetimeIndex"
285+
)
286+
method = "values"
287+
288+
valid_limit_directions = ["forward", "backward", "both"]
289+
limit_direction = limit_direction.lower()
290+
if limit_direction not in valid_limit_directions:
291+
raise ValueError(
292+
"Invalid limit_direction: expecting one of "
293+
f"{valid_limit_directions}, got '{limit_direction}'."
294+
)
295+
296+
if limit_area is not None:
297+
valid_limit_areas = ["inside", "outside"]
298+
limit_area = limit_area.lower()
299+
if limit_area not in valid_limit_areas:
300+
raise ValueError(
301+
f"Invalid limit_area: expecting one of {valid_limit_areas}, got "
302+
f"{limit_area}."
303+
)
304+
305+
# default limit is unlimited GH #16282
306+
limit = algos.validate_limit(nobs=None, limit=limit)
307+
308+
indices = _index_to_interp_indices(index, method)
309+
279310
def func(yvalues: np.ndarray) -> np.ndarray:
280311
# process 1-d slices in the axis direction, returning it
281312

282313
# should the axis argument be handled below in apply_along_axis?
283-
# i.e. not an arg to interpolate_1d
284-
return interpolate_1d(
285-
xvalues=index,
314+
# i.e. not an arg to _interpolate_1d
315+
return _interpolate_1d(
316+
indices=indices,
286317
yvalues=yvalues,
287318
method=method,
288319
limit=limit,
@@ -297,8 +328,30 @@ def func(yvalues: np.ndarray) -> np.ndarray:
297328
return np.apply_along_axis(func, axis, data)
298329

299330

300-
def interpolate_1d(
301-
xvalues: Index,
331+
def _index_to_interp_indices(index: Index, method: str) -> np.ndarray:
332+
"""
333+
Convert Index to ndarray of indices to pass to NumPy/SciPy.
334+
"""
335+
xarr = index._values
336+
if needs_i8_conversion(xarr.dtype):
337+
# GH#1646 for dt64tz
338+
xarr = xarr.view("i8")
339+
340+
if method == "linear":
341+
inds = xarr
342+
inds = cast(np.ndarray, inds)
343+
else:
344+
inds = np.asarray(xarr)
345+
346+
if method in ("values", "index"):
347+
if inds.dtype == np.object_:
348+
inds = lib.maybe_convert_objects(inds)
349+
350+
return inds
351+
352+
353+
def _interpolate_1d(
354+
indices: np.ndarray,
302355
yvalues: np.ndarray,
303356
method: str | None = "linear",
304357
limit: int | None = None,
@@ -311,51 +364,23 @@ def interpolate_1d(
311364
):
312365
"""
313366
Logic for the 1-d interpolation. The result should be 1-d, inputs
314-
xvalues and yvalues will each be 1-d arrays of the same length.
367+
indices and yvalues will each be 1-d arrays of the same length.
315368
316369
Bounds_error is currently hardcoded to False since non-scipy ones don't
317370
take it as an argument.
318371
"""
372+
319373
invalid = isna(yvalues)
320374
valid = ~invalid
321375

322376
if not valid.any():
323-
result = np.empty(xvalues.shape, dtype=np.float64)
377+
result = np.empty(indices.shape, dtype=np.float64)
324378
result.fill(np.nan)
325379
return result
326380

327381
if valid.all():
328382
return yvalues
329383

330-
if method == "time":
331-
if not needs_i8_conversion(xvalues.dtype):
332-
raise ValueError(
333-
"time-weighted interpolation only works "
334-
"on Series or DataFrames with a "
335-
"DatetimeIndex"
336-
)
337-
method = "values"
338-
339-
valid_limit_directions = ["forward", "backward", "both"]
340-
limit_direction = limit_direction.lower()
341-
if limit_direction not in valid_limit_directions:
342-
raise ValueError(
343-
"Invalid limit_direction: expecting one of "
344-
f"{valid_limit_directions}, got '{limit_direction}'."
345-
)
346-
347-
if limit_area is not None:
348-
valid_limit_areas = ["inside", "outside"]
349-
limit_area = limit_area.lower()
350-
if limit_area not in valid_limit_areas:
351-
raise ValueError(
352-
f"Invalid limit_area: expecting one of {valid_limit_areas}, got "
353-
f"{limit_area}."
354-
)
355-
356-
# default limit is unlimited GH #16282
357-
limit = algos.validate_limit(nobs=None, limit=limit)
358-
359384
# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
360385
all_nans = set(np.flatnonzero(invalid))
361386

@@ -369,8 +394,6 @@ def interpolate_1d(
369394
last_valid_index = len(yvalues)
370395
end_nans = set(range(1 + last_valid_index, len(valid)))
371396

372-
mid_nans = all_nans - start_nans - end_nans
373-
374397
# Like the sets above, preserve_nans contains indices of invalid values,
375398
# but in this case, it is the final set of indices that need to be
376399
# preserved as NaN after the interpolation.
@@ -396,44 +419,26 @@ def interpolate_1d(
396419
preserve_nans |= start_nans | end_nans
397420
elif limit_area == "outside":
398421
# preserve NaNs on the inside
422+
mid_nans = all_nans - start_nans - end_nans
399423
preserve_nans |= mid_nans
400424

401425
# sort preserve_nans and convert to list
402426
preserve_nans = sorted(preserve_nans)
403427

404428
result = yvalues.copy()
405429

406-
# xarr to pass to NumPy/SciPy
407-
xarr = xvalues._values
408-
if needs_i8_conversion(xarr.dtype):
409-
# GH#1646 for dt64tz
410-
xarr = xarr.view("i8")
411-
412-
if method == "linear":
413-
inds = xarr
414-
else:
415-
inds = np.asarray(xarr)
416-
417-
if method in ("values", "index"):
418-
if inds.dtype == np.object_:
419-
inds = lib.maybe_convert_objects(inds)
420-
421430
if method in NP_METHODS:
422431
# np.interp requires sorted X values, #21037
423432

424-
# error: Argument 1 to "argsort" has incompatible type "Union[ExtensionArray,
425-
# Any]"; expected "Union[Union[int, float, complex, str, bytes, generic],
426-
# Sequence[Union[int, float, complex, str, bytes, generic]],
427-
# Sequence[Sequence[Any]], _SupportsArray]"
428-
indexer = np.argsort(inds[valid]) # type: ignore[arg-type]
433+
indexer = np.argsort(indices[valid])
429434
result[invalid] = np.interp(
430-
inds[invalid], inds[valid][indexer], yvalues[valid][indexer]
435+
indices[invalid], indices[valid][indexer], yvalues[valid][indexer]
431436
)
432437
else:
433438
result[invalid] = _interpolate_scipy_wrapper(
434-
inds[valid],
439+
indices[valid],
435440
yvalues[valid],
436-
inds[invalid],
441+
indices[invalid],
437442
method=method,
438443
fill_value=fill_value,
439444
bounds_error=bounds_error,

0 commit comments

Comments
 (0)