Skip to content

Commit 473615e

Browse files
PERF: vectorize _interp_limit (pandas-dev#16592)
* PERF: vectorize _interp_limit * CLN: remove old implementation * fixup! CLN: remove old implementation
1 parent 9e620bc commit 473615e

File tree

1 file changed

+67
-10
lines changed

1 file changed

+67
-10
lines changed

pandas/core/missing.py

+67-10
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,6 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
143143
'DatetimeIndex')
144144
method = 'values'
145145

146-
def _interp_limit(invalid, fw_limit, bw_limit):
147-
"Get idx of values that won't be filled b/c they exceed the limits."
148-
for x in np.where(invalid)[0]:
149-
if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
150-
yield x
151-
152146
valid_limit_directions = ['forward', 'backward', 'both']
153147
limit_direction = limit_direction.lower()
154148
if limit_direction not in valid_limit_directions:
@@ -180,21 +174,29 @@ def _interp_limit(invalid, fw_limit, bw_limit):
180174

181175
# default limit is unlimited GH #16282
182176
if limit is None:
183-
limit = len(xvalues)
177+
# limit = len(xvalues)
178+
pass
184179
elif not is_integer(limit):
185180
raise ValueError('Limit must be an integer')
186181
elif limit < 1:
187182
raise ValueError('Limit must be greater than 0')
188183

189184
# each possible limit_direction
190-
if limit_direction == 'forward':
185+
# TODO: do we need sorted?
186+
if limit_direction == 'forward' and limit is not None:
191187
violate_limit = sorted(start_nans |
192188
set(_interp_limit(invalid, limit, 0)))
193-
elif limit_direction == 'backward':
189+
elif limit_direction == 'forward':
190+
violate_limit = sorted(start_nans)
191+
elif limit_direction == 'backward' and limit is not None:
194192
violate_limit = sorted(end_nans |
195193
set(_interp_limit(invalid, 0, limit)))
196-
elif limit_direction == 'both':
194+
elif limit_direction == 'backward':
195+
violate_limit = sorted(end_nans)
196+
elif limit_direction == 'both' and limit is not None:
197197
violate_limit = sorted(_interp_limit(invalid, limit, limit))
198+
else:
199+
violate_limit = []
198200

199201
xvalues = getattr(xvalues, 'values', xvalues)
200202
yvalues = getattr(yvalues, 'values', yvalues)
@@ -630,3 +632,58 @@ def fill_zeros(result, x, y, name, fill):
630632
result = result.reshape(shape)
631633

632634
return result
635+
636+
637+
def _interp_limit(invalid, fw_limit, bw_limit):
638+
"""Get idx of values that won't be filled b/c they exceed the limits.
639+
640+
This is equivalent to the more readable, but slower
641+
642+
.. code-block:: python
643+
644+
for x in np.where(invalid)[0]:
645+
if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
646+
yield x
647+
"""
648+
# handle forward first; the backward direction is the same except
649+
# 1. operate on the reversed array
650+
# 2. subtract the returned indicies from N - 1
651+
N = len(invalid)
652+
653+
def inner(invalid, limit):
654+
limit = min(limit, N)
655+
windowed = _rolling_window(invalid, limit + 1).all(1)
656+
idx = (set(np.where(windowed)[0] + limit) |
657+
set(np.where((~invalid[:limit + 1]).cumsum() == 0)[0]))
658+
return idx
659+
660+
if fw_limit == 0:
661+
f_idx = set(np.where(invalid)[0])
662+
else:
663+
f_idx = inner(invalid, fw_limit)
664+
665+
if bw_limit == 0:
666+
# then we don't even need to care about backwards, just use forwards
667+
return f_idx
668+
else:
669+
b_idx = set(N - 1 - np.asarray(list(inner(invalid[::-1], bw_limit))))
670+
if fw_limit == 0:
671+
return b_idx
672+
return f_idx & b_idx
673+
674+
675+
def _rolling_window(a, window):
676+
"""
677+
[True, True, False, True, False], 2 ->
678+
679+
[
680+
[True, True],
681+
[True, False],
682+
[False, True],
683+
[True, False],
684+
]
685+
"""
686+
# https://stackoverflow.com/a/6811241
687+
shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
688+
strides = a.strides + (a.strides[-1],)
689+
return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

0 commit comments

Comments
 (0)