Skip to content

Commit 8c6ff7a

Browse files
committed
Moved code to derive indices of "NaNs to preserve" in separate function
1 parent 3cb371e commit 8c6ff7a

File tree

1 file changed

+57
-37
lines changed

1 file changed

+57
-37
lines changed

pandas/core/missing.py

+57-37
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,61 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None, max_gap=None,
173173
elif max_gap < 1:
174174
raise ValueError('max_gap must be greater than 0')
175175

176+
preserve_nans = _derive_indices_of_nans_to_preserve(
177+
yvalues=yvalues, valid=valid, invalid=invalid,
178+
limit=limit, limit_area=limit_area, limit_direction=limit_direction,
179+
max_gap=max_gap)
180+
181+
xvalues = getattr(xvalues, 'values', xvalues)
182+
yvalues = getattr(yvalues, 'values', yvalues)
183+
result = yvalues.copy()
184+
185+
if method in ['linear', 'time', 'index', 'values']:
186+
if method in ('values', 'index'):
187+
inds = np.asarray(xvalues)
188+
# hack for DatetimeIndex, #1646
189+
if needs_i8_conversion(inds.dtype.type):
190+
inds = inds.view(np.int64)
191+
if inds.dtype == np.object_:
192+
inds = lib.maybe_convert_objects(inds)
193+
else:
194+
inds = xvalues
195+
result[invalid] = np.interp(inds[invalid], inds[valid], yvalues[valid])
196+
result[preserve_nans] = np.nan
197+
return result
198+
199+
sp_methods = ['nearest', 'zero', 'slinear', 'quadratic', 'cubic',
200+
'barycentric', 'krogh', 'spline', 'polynomial',
201+
'from_derivatives', 'piecewise_polynomial', 'pchip', 'akima']
202+
203+
if method in sp_methods:
204+
inds = np.asarray(xvalues)
205+
# hack for DatetimeIndex, #1646
206+
if issubclass(inds.dtype.type, np.datetime64):
207+
inds = inds.view(np.int64)
208+
result[invalid] = _interpolate_scipy_wrapper(inds[valid],
209+
yvalues[valid],
210+
inds[invalid],
211+
method=method,
212+
fill_value=fill_value,
213+
bounds_error=bounds_error,
214+
order=order, **kwargs)
215+
result[preserve_nans] = np.nan
216+
return result
217+
218+
219+
def _derive_indices_of_nans_to_preserve(yvalues, invalid, valid,
220+
limit, limit_area, limit_direction,
221+
max_gap):
222+
""" Derive the indices of NaNs that shall be preserved after interpolation
223+
224+
This function is called by `interpolate_1d` and takes the arguments with
225+
the same name from there. In `interpolate_1d`, after performing the
226+
interpolation the list of indices of NaNs to preserve is used to put
227+
NaNs in the desired locations.
228+
229+
"""
230+
176231
from pandas import Series
177232
ys = Series(yvalues)
178233

@@ -220,7 +275,7 @@ def bfill_nan(arr):
220275
diff[invalid] = np.nan
221276
diff = bfill_nan(diff)
222277
# hack to avoid having trailing NaNs in `diff`. Fill these
223-
# with `max_gap`. Everthing smaller than `max_gap` won't matter
278+
# with `max_gap`. Everything smaller than `max_gap` won't matter
224279
# in the following.
225280
diff[np.isnan(diff)] = max_gap
226281
preserve_nans = set(np.flatnonzero((diff > max_gap) & invalid))
@@ -237,42 +292,7 @@ def bfill_nan(arr):
237292
# sort preserve_nans and covert to list
238293
preserve_nans = sorted(preserve_nans)
239294

240-
xvalues = getattr(xvalues, 'values', xvalues)
241-
yvalues = getattr(yvalues, 'values', yvalues)
242-
result = yvalues.copy()
243-
244-
if method in ['linear', 'time', 'index', 'values']:
245-
if method in ('values', 'index'):
246-
inds = np.asarray(xvalues)
247-
# hack for DatetimeIndex, #1646
248-
if needs_i8_conversion(inds.dtype.type):
249-
inds = inds.view(np.int64)
250-
if inds.dtype == np.object_:
251-
inds = lib.maybe_convert_objects(inds)
252-
else:
253-
inds = xvalues
254-
result[invalid] = np.interp(inds[invalid], inds[valid], yvalues[valid])
255-
result[preserve_nans] = np.nan
256-
return result
257-
258-
sp_methods = ['nearest', 'zero', 'slinear', 'quadratic', 'cubic',
259-
'barycentric', 'krogh', 'spline', 'polynomial',
260-
'from_derivatives', 'piecewise_polynomial', 'pchip', 'akima']
261-
262-
if method in sp_methods:
263-
inds = np.asarray(xvalues)
264-
# hack for DatetimeIndex, #1646
265-
if issubclass(inds.dtype.type, np.datetime64):
266-
inds = inds.view(np.int64)
267-
result[invalid] = _interpolate_scipy_wrapper(inds[valid],
268-
yvalues[valid],
269-
inds[invalid],
270-
method=method,
271-
fill_value=fill_value,
272-
bounds_error=bounds_error,
273-
order=order, **kwargs)
274-
result[preserve_nans] = np.nan
275-
return result
295+
return preserve_nans
276296

277297

278298
def _interpolate_scipy_wrapper(x, y, new_x, method, fill_value=None,

0 commit comments

Comments
 (0)