diff --git a/pandas/core/internals/blocks.py b/pandas/core/internals/blocks.py index d8779dae7c384..d2d34ab94e26e 100644 --- a/pandas/core/internals/blocks.py +++ b/pandas/core/internals/blocks.py @@ -1207,22 +1207,16 @@ def _interpolate( ) # process 1-d slices in the axis direction - def func(yvalues: np.ndarray) -> np.ndarray: - - # process a 1-d slice, returning it - # should the axis argument be handled below in apply_along_axis? - # i.e. not an arg to missing.interpolate_1d - return missing.interpolate_1d( - xvalues=index, - yvalues=yvalues, - method=method, - limit=limit, - limit_direction=limit_direction, - limit_area=limit_area, - fill_value=fill_value, - bounds_error=False, - **kwargs, - ) + func = missing.Interpolator1d( + xvalues=index, + method=method, + limit=limit, + limit_direction=limit_direction, + limit_area=limit_area, + fill_value=fill_value, + bounds_error=False, + **kwargs, + ).interpolate # interp each column independently interp_values = np.apply_along_axis(func, axis, data) diff --git a/pandas/core/missing.py b/pandas/core/missing.py index 7802c5cbdbfb3..16fad3ef9a5cd 100644 --- a/pandas/core/missing.py +++ b/pandas/core/missing.py @@ -168,18 +168,7 @@ def find_valid_index(values, how: str): return idxpos -def interpolate_1d( - xvalues: np.ndarray, - yvalues: np.ndarray, - method: Optional[str] = "linear", - limit: Optional[int] = None, - limit_direction: str = "forward", - limit_area: Optional[str] = None, - fill_value: Optional[Any] = None, - bounds_error: bool = False, - order: Optional[int] = None, - **kwargs, -): +class Interpolator1d: """ Logic for the 1-d interpolation. The result should be 1-d, inputs xvalues and yvalues will each be 1-d arrays of the same length. @@ -187,124 +176,162 @@ def interpolate_1d( Bounds_error is currently hardcoded to False since non-scipy ones don't take it as an argument. """ - invalid = isna(yvalues) - valid = ~invalid - - if not valid.any(): - # have to call np.asarray(xvalues) since xvalues could be an Index - # which can't be mutated - result = np.empty_like(np.asarray(xvalues), dtype=np.float64) - result.fill(np.nan) - return result - if valid.all(): - return yvalues - - if method == "time": - if not getattr(xvalues, "is_all_dates", None): - # if not issubclass(xvalues.dtype.type, np.datetime64): - raise ValueError( - "time-weighted interpolation only works " - "on Series or DataFrames with a " - "DatetimeIndex" + def __init__( + self, + xvalues: np.ndarray, + method: Optional[str] = "linear", + limit: Optional[int] = None, + limit_direction: str = "forward", + limit_area: Optional[str] = None, + fill_value: Optional[Any] = None, + bounds_error: bool = False, + order: Optional[int] = None, + **kwargs, + ): + method = self._validate_method(method, xvalues) + xvalues = self._convert_xvalues(xvalues, method) + + # default limit is unlimited GH #16282 + self.limit = algos._validate_limit(nobs=None, limit=limit) + self.limit_direction = self._validate_limit_direction(limit_direction) + self.limit_area = self._validate_limit_area(limit_area) + + def _sp_func(yvalues, valid, invalid): + return _interpolate_scipy_wrapper( + xvalues[valid], + yvalues[valid], + xvalues[invalid], + method=method, + fill_value=fill_value, + bounds_error=bounds_error, + order=order, + **kwargs, ) - method = "values" - - valid_limit_directions = ["forward", "backward", "both"] - limit_direction = limit_direction.lower() - if limit_direction not in valid_limit_directions: - raise ValueError( - "Invalid limit_direction: expecting one of " - f"{valid_limit_directions}, got '{limit_direction}'." - ) - if limit_area is not None: - valid_limit_areas = ["inside", "outside"] - limit_area = limit_area.lower() - if limit_area not in valid_limit_areas: + if method in NP_METHODS: + self.interpolator = NumPyInterpolator(xvalues).interpolate + else: + self.interpolator = _sp_func + + def _convert_xvalues(self, xvalues, method): + """ + Convert xvalues to pass to NumPy/SciPy. + """ + xvalues = getattr(xvalues, "values", xvalues) + if method == "linear": + inds = xvalues + else: + inds = np.asarray(xvalues) + + # hack for DatetimeIndex, #1646 + if needs_i8_conversion(inds.dtype): + inds = inds.view(np.int64) + + if method in ("values", "index"): + if inds.dtype == np.object_: + inds = lib.maybe_convert_objects(inds) + return inds + + def _validate_method(self, method, xvalues): + if method == "time": + if not getattr(xvalues, "is_all_dates", None): + raise ValueError( + "time-weighted interpolation only works " + "on Series or DataFrames with a " + "DatetimeIndex" + ) + method = "values" + return method + + def _validate_limit_direction(self, limit_direction): + valid_limit_directions = ["forward", "backward", "both"] + limit_direction = limit_direction.lower() + if limit_direction not in valid_limit_directions: raise ValueError( - f"Invalid limit_area: expecting one of {valid_limit_areas}, got " - f"{limit_area}." + "Invalid limit_direction: expecting one of " + f"{valid_limit_directions}, got '{limit_direction}'." ) + return limit_direction + + def _validate_limit_area(self, limit_area): + if limit_area is not None: + valid_limit_areas = ["inside", "outside"] + limit_area = limit_area.lower() + if limit_area not in valid_limit_areas: + raise ValueError( + f"Invalid limit_area: expecting one of {valid_limit_areas}, got " + f"{limit_area}." + ) + return limit_area + + def _update_invalid_to_preserve_nans(self, yvalues, valid, invalid) -> None: + # These are sets of index pointers to invalid values... i.e. {0, 1, etc... + all_nans = set(np.flatnonzero(invalid)) + start_nans = set(range(find_valid_index(yvalues, "first"))) + end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid))) + mid_nans = all_nans - start_nans - end_nans + + # Like the sets above, preserve_nans contains indices of invalid values, + # but in this case, it is the final set of indices that need to be + # preserved as NaN after the interpolation. + + # For example if limit_direction='forward' then preserve_nans will + # contain indices of NaNs at the beginning of the series, and NaNs that + # are more than'limit' away from the prior non-NaN. + + # set preserve_nans based on direction using _interp_limit + preserve_nans: Union[List, Set] + if self.limit_direction == "forward": + preserve_nans = start_nans | set(_interp_limit(invalid, self.limit, 0)) + elif self.limit_direction == "backward": + preserve_nans = end_nans | set(_interp_limit(invalid, 0, self.limit)) + else: + # both directions... just use _interp_limit + preserve_nans = set(_interp_limit(invalid, self.limit, self.limit)) - # default limit is unlimited GH #16282 - limit = algos._validate_limit(nobs=None, limit=limit) - - # These are sets of index pointers to invalid values... i.e. {0, 1, etc... - all_nans = set(np.flatnonzero(invalid)) - start_nans = set(range(find_valid_index(yvalues, "first"))) - end_nans = set(range(1 + find_valid_index(yvalues, "last"), len(valid))) - mid_nans = all_nans - start_nans - end_nans - - # Like the sets above, preserve_nans contains indices of invalid values, - # but in this case, it is the final set of indices that need to be - # preserved as NaN after the interpolation. - - # For example if limit_direction='forward' then preserve_nans will - # contain indices of NaNs at the beginning of the series, and NaNs that - # are more than'limit' away from the prior non-NaN. - - # set preserve_nans based on direction using _interp_limit - preserve_nans: Union[List, Set] - if limit_direction == "forward": - preserve_nans = start_nans | set(_interp_limit(invalid, limit, 0)) - elif limit_direction == "backward": - preserve_nans = end_nans | set(_interp_limit(invalid, 0, limit)) - else: - # both directions... just use _interp_limit - preserve_nans = set(_interp_limit(invalid, limit, limit)) + # if limit_area is set, add either mid or outside indices + # to preserve_nans GH #16284 + if self.limit_area == "inside": + # preserve NaNs on the outside + preserve_nans |= start_nans | end_nans + elif self.limit_area == "outside": + # preserve NaNs on the inside + preserve_nans |= mid_nans - # if limit_area is set, add either mid or outside indices - # to preserve_nans GH #16284 - if limit_area == "inside": - # preserve NaNs on the outside - preserve_nans |= start_nans | end_nans - elif limit_area == "outside": - # preserve NaNs on the inside - preserve_nans |= mid_nans + # sort preserve_nans and covert to list + preserve_nans = sorted(preserve_nans) + invalid[preserve_nans] = False - # sort preserve_nans and covert to list - preserve_nans = sorted(preserve_nans) + def interpolate(self, yvalues: np.ndarray) -> np.ndarray: + invalid = isna(yvalues) + valid = ~invalid - yvalues = getattr(yvalues, "values", yvalues) - result = yvalues.copy() + if not valid.any() or valid.all(): + return yvalues - # xvalues to pass to NumPy/SciPy + yvalues = getattr(yvalues, "values", yvalues) + result = yvalues.copy() - xvalues = getattr(xvalues, "values", xvalues) - if method == "linear": - inds = xvalues - else: - inds = np.asarray(xvalues) + self._update_invalid_to_preserve_nans(yvalues, valid, invalid) - # hack for DatetimeIndex, #1646 - if needs_i8_conversion(inds.dtype): - inds = inds.view(np.int64) + result[invalid] = self.interpolator(yvalues, valid, invalid) + return result - if method in ("values", "index"): - if inds.dtype == np.object_: - inds = lib.maybe_convert_objects(inds) - if method in NP_METHODS: - # np.interp requires sorted X values, #21037 - indexer = np.argsort(inds[valid]) - result[invalid] = np.interp( - inds[invalid], inds[valid][indexer], yvalues[valid][indexer] - ) - else: - result[invalid] = _interpolate_scipy_wrapper( - inds[valid], - yvalues[valid], - inds[invalid], - method=method, - fill_value=fill_value, - bounds_error=bounds_error, - order=order, - **kwargs, - ) +class NumPyInterpolator: + # np.interp requires sorted X values, #21037 + def __init__(self, xvalues: np.ndarray): + self.xvalues = xvalues + self.indexer = np.argsort(xvalues) + self.xvalues_sorted = xvalues[self.indexer] - result[preserve_nans] = np.nan - return result + def interpolate(self, yvalues, valid, invalid): + valid_sorted = valid[self.indexer] + x = self.xvalues[invalid] + xp = self.xvalues_sorted[valid_sorted] + yp = yvalues[self.indexer][valid_sorted] + return np.interp(x, xp, yp) def _interpolate_scipy_wrapper(