Skip to content

PERF: interpolate_1d returns function to apply columnwise #34728

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 10 additions & 16 deletions pandas/core/internals/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,22 +1207,16 @@ def _interpolate(
)
# process 1-d slices in the axis direction

def func(yvalues: np.ndarray) -> np.ndarray:

# process a 1-d slice, returning it
# should the axis argument be handled below in apply_along_axis?
# i.e. not an arg to missing.interpolate_1d
return missing.interpolate_1d(
xvalues=index,
yvalues=yvalues,
method=method,
limit=limit,
limit_direction=limit_direction,
limit_area=limit_area,
fill_value=fill_value,
bounds_error=False,
**kwargs,
)
func = missing.Interpolator1d(
xvalues=index,
method=method,
limit=limit,
limit_direction=limit_direction,
limit_area=limit_area,
fill_value=fill_value,
bounds_error=False,
**kwargs,
).interpolate

# interp each column independently
interp_values = np.apply_along_axis(func, axis, data)
Expand Down
257 changes: 142 additions & 115 deletions pandas/core/missing.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,143 +168,170 @@ def find_valid_index(values, how: str):
return idxpos


def interpolate_1d(
xvalues: np.ndarray,
yvalues: np.ndarray,
method: Optional[str] = "linear",
limit: Optional[int] = None,
limit_direction: str = "forward",
limit_area: Optional[str] = None,
fill_value: Optional[Any] = None,
bounds_error: bool = False,
order: Optional[int] = None,
**kwargs,
):
class Interpolator1d:
"""
Logic for the 1-d interpolation. The result should be 1-d, inputs
xvalues and yvalues will each be 1-d arrays of the same length.

Bounds_error is currently hardcoded to False since non-scipy ones don't
take it as an argument.
"""
invalid = isna(yvalues)
valid = ~invalid

if not valid.any():
# have to call np.asarray(xvalues) since xvalues could be an Index
# which can't be mutated
result = np.empty_like(np.asarray(xvalues), dtype=np.float64)
result.fill(np.nan)
return result

if valid.all():
return yvalues

if method == "time":
if not getattr(xvalues, "is_all_dates", None):
# if not issubclass(xvalues.dtype.type, np.datetime64):
raise ValueError(
"time-weighted interpolation only works "
"on Series or DataFrames with a "
"DatetimeIndex"
def __init__(
self,
xvalues: np.ndarray,
method: Optional[str] = "linear",
limit: Optional[int] = None,
limit_direction: str = "forward",
limit_area: Optional[str] = None,
fill_value: Optional[Any] = None,
bounds_error: bool = False,
order: Optional[int] = None,
**kwargs,
):
method = self._validate_method(method, xvalues)
xvalues = self._convert_xvalues(xvalues, method)

# default limit is unlimited GH #16282
self.limit = algos._validate_limit(nobs=None, limit=limit)
self.limit_direction = self._validate_limit_direction(limit_direction)
self.limit_area = self._validate_limit_area(limit_area)

def _sp_func(yvalues, valid, invalid):
return _interpolate_scipy_wrapper(
xvalues[valid],
yvalues[valid],
xvalues[invalid],
method=method,
fill_value=fill_value,
bounds_error=bounds_error,
order=order,
**kwargs,
)
method = "values"

valid_limit_directions = ["forward", "backward", "both"]
limit_direction = limit_direction.lower()
if limit_direction not in valid_limit_directions:
raise ValueError(
"Invalid limit_direction: expecting one of "
f"{valid_limit_directions}, got '{limit_direction}'."
)

if limit_area is not None:
valid_limit_areas = ["inside", "outside"]
limit_area = limit_area.lower()
if limit_area not in valid_limit_areas:
if method in NP_METHODS:
self.interpolator = NumPyInterpolator(xvalues).interpolate
else:
self.interpolator = _sp_func

def _convert_xvalues(self, xvalues, method):
"""
Convert xvalues to pass to NumPy/SciPy.
"""
xvalues = getattr(xvalues, "values", xvalues)
if method == "linear":
inds = xvalues
else:
inds = np.asarray(xvalues)

# hack for DatetimeIndex, #1646
if needs_i8_conversion(inds.dtype):
inds = inds.view(np.int64)

if method in ("values", "index"):
if inds.dtype == np.object_:
inds = lib.maybe_convert_objects(inds)
return inds

def _validate_method(self, method, xvalues):
if method == "time":
if not getattr(xvalues, "is_all_dates", None):
raise ValueError(
"time-weighted interpolation only works "
"on Series or DataFrames with a "
"DatetimeIndex"
)
method = "values"
return method

def _validate_limit_direction(self, limit_direction):
valid_limit_directions = ["forward", "backward", "both"]
limit_direction = limit_direction.lower()
if limit_direction not in valid_limit_directions:
raise ValueError(
f"Invalid limit_area: expecting one of {valid_limit_areas}, got "
f"{limit_area}."
"Invalid limit_direction: expecting one of "
f"{valid_limit_directions}, got '{limit_direction}'."
)
return limit_direction

def _validate_limit_area(self, limit_area):
if limit_area is not None:
valid_limit_areas = ["inside", "outside"]
limit_area = limit_area.lower()
if limit_area not in valid_limit_areas:
raise ValueError(
f"Invalid limit_area: expecting one of {valid_limit_areas}, got "
f"{limit_area}."
)
return limit_area

def _update_invalid_to_preserve_nans(self, yvalues, valid, invalid) -> None:
# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
all_nans = set(np.flatnonzero(invalid))
start_nans = set(range(find_valid_index(yvalues, "first")))
end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid)))
mid_nans = all_nans - start_nans - end_nans

# Like the sets above, preserve_nans contains indices of invalid values,
# but in this case, it is the final set of indices that need to be
# preserved as NaN after the interpolation.

# For example if limit_direction='forward' then preserve_nans will
# contain indices of NaNs at the beginning of the series, and NaNs that
# are more than'limit' away from the prior non-NaN.

# set preserve_nans based on direction using _interp_limit
preserve_nans: Union[List, Set]
if self.limit_direction == "forward":
preserve_nans = start_nans | set(_interp_limit(invalid, self.limit, 0))
elif self.limit_direction == "backward":
preserve_nans = end_nans | set(_interp_limit(invalid, 0, self.limit))
else:
# both directions... just use _interp_limit
preserve_nans = set(_interp_limit(invalid, self.limit, self.limit))

# default limit is unlimited GH #16282
limit = algos._validate_limit(nobs=None, limit=limit)

# These are sets of index pointers to invalid values... i.e. {0, 1, etc...
all_nans = set(np.flatnonzero(invalid))
start_nans = set(range(find_valid_index(yvalues, "first")))
end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid)))
mid_nans = all_nans - start_nans - end_nans

# Like the sets above, preserve_nans contains indices of invalid values,
# but in this case, it is the final set of indices that need to be
# preserved as NaN after the interpolation.

# For example if limit_direction='forward' then preserve_nans will
# contain indices of NaNs at the beginning of the series, and NaNs that
# are more than'limit' away from the prior non-NaN.

# set preserve_nans based on direction using _interp_limit
preserve_nans: Union[List, Set]
if limit_direction == "forward":
preserve_nans = start_nans | set(_interp_limit(invalid, limit, 0))
elif limit_direction == "backward":
preserve_nans = end_nans | set(_interp_limit(invalid, 0, limit))
else:
# both directions... just use _interp_limit
preserve_nans = set(_interp_limit(invalid, limit, limit))
# if limit_area is set, add either mid or outside indices
# to preserve_nans GH #16284
if self.limit_area == "inside":
# preserve NaNs on the outside
preserve_nans |= start_nans | end_nans
elif self.limit_area == "outside":
# preserve NaNs on the inside
preserve_nans |= mid_nans

# if limit_area is set, add either mid or outside indices
# to preserve_nans GH #16284
if limit_area == "inside":
# preserve NaNs on the outside
preserve_nans |= start_nans | end_nans
elif limit_area == "outside":
# preserve NaNs on the inside
preserve_nans |= mid_nans
# sort preserve_nans and covert to list
preserve_nans = sorted(preserve_nans)
invalid[preserve_nans] = False

# sort preserve_nans and covert to list
preserve_nans = sorted(preserve_nans)
def interpolate(self, yvalues: np.ndarray) -> np.ndarray:
invalid = isna(yvalues)
valid = ~invalid

yvalues = getattr(yvalues, "values", yvalues)
result = yvalues.copy()
if not valid.any() or valid.all():
return yvalues

# xvalues to pass to NumPy/SciPy
yvalues = getattr(yvalues, "values", yvalues)
result = yvalues.copy()

xvalues = getattr(xvalues, "values", xvalues)
if method == "linear":
inds = xvalues
else:
inds = np.asarray(xvalues)
self._update_invalid_to_preserve_nans(yvalues, valid, invalid)

# hack for DatetimeIndex, #1646
if needs_i8_conversion(inds.dtype):
inds = inds.view(np.int64)
result[invalid] = self.interpolator(yvalues, valid, invalid)
return result

if method in ("values", "index"):
if inds.dtype == np.object_:
inds = lib.maybe_convert_objects(inds)

if method in NP_METHODS:
# np.interp requires sorted X values, #21037
indexer = np.argsort(inds[valid])
result[invalid] = np.interp(
inds[invalid], inds[valid][indexer], yvalues[valid][indexer]
)
else:
result[invalid] = _interpolate_scipy_wrapper(
inds[valid],
yvalues[valid],
inds[invalid],
method=method,
fill_value=fill_value,
bounds_error=bounds_error,
order=order,
**kwargs,
)
class NumPyInterpolator:
# np.interp requires sorted X values, #21037
def __init__(self, xvalues: np.ndarray):
self.xvalues = xvalues
self.indexer = np.argsort(xvalues)
self.xvalues_sorted = xvalues[self.indexer]

result[preserve_nans] = np.nan
return result
def interpolate(self, yvalues, valid, invalid):
valid_sorted = valid[self.indexer]
x = self.xvalues[invalid]
xp = self.xvalues_sorted[valid_sorted]
yp = yvalues[self.indexer][valid_sorted]
return np.interp(x, xp, yp)


def _interpolate_scipy_wrapper(
Expand Down