@@ -143,12 +143,6 @@ def interpolate_1d(xvalues, yvalues, method='linear', limit=None,
143
143
'DatetimeIndex' )
144
144
method = 'values'
145
145
146
- def _interp_limit (invalid , fw_limit , bw_limit ):
147
- "Get idx of values that won't be filled b/c they exceed the limits."
148
- for x in np .where (invalid )[0 ]:
149
- if invalid [max (0 , x - fw_limit ):x + bw_limit + 1 ].all ():
150
- yield x
151
-
152
146
valid_limit_directions = ['forward' , 'backward' , 'both' ]
153
147
limit_direction = limit_direction .lower ()
154
148
if limit_direction not in valid_limit_directions :
@@ -180,21 +174,29 @@ def _interp_limit(invalid, fw_limit, bw_limit):
180
174
181
175
# default limit is unlimited GH #16282
182
176
if limit is None :
183
- limit = len (xvalues )
177
+ # limit = len(xvalues)
178
+ pass
184
179
elif not is_integer (limit ):
185
180
raise ValueError ('Limit must be an integer' )
186
181
elif limit < 1 :
187
182
raise ValueError ('Limit must be greater than 0' )
188
183
189
184
# each possible limit_direction
190
- if limit_direction == 'forward' :
185
+ # TODO: do we need sorted?
186
+ if limit_direction == 'forward' and limit is not None :
191
187
violate_limit = sorted (start_nans |
192
188
set (_interp_limit (invalid , limit , 0 )))
193
- elif limit_direction == 'backward' :
189
+ elif limit_direction == 'forward' :
190
+ violate_limit = sorted (start_nans )
191
+ elif limit_direction == 'backward' and limit is not None :
194
192
violate_limit = sorted (end_nans |
195
193
set (_interp_limit (invalid , 0 , limit )))
196
- elif limit_direction == 'both' :
194
+ elif limit_direction == 'backward' :
195
+ violate_limit = sorted (end_nans )
196
+ elif limit_direction == 'both' and limit is not None :
197
197
violate_limit = sorted (_interp_limit (invalid , limit , limit ))
198
+ else :
199
+ violate_limit = []
198
200
199
201
xvalues = getattr (xvalues , 'values' , xvalues )
200
202
yvalues = getattr (yvalues , 'values' , yvalues )
@@ -630,3 +632,58 @@ def fill_zeros(result, x, y, name, fill):
630
632
result = result .reshape (shape )
631
633
632
634
return result
635
+
636
+
637
+ def _interp_limit (invalid , fw_limit , bw_limit ):
638
+ """Get idx of values that won't be filled b/c they exceed the limits.
639
+
640
+ This is equivalent to the more readable, but slower
641
+
642
+ .. code-block:: python
643
+
644
+ for x in np.where(invalid)[0]:
645
+ if invalid[max(0, x - fw_limit):x + bw_limit + 1].all():
646
+ yield x
647
+ """
648
+ # handle forward first; the backward direction is the same except
649
+ # 1. operate on the reversed array
650
+ # 2. subtract the returned indicies from N - 1
651
+ N = len (invalid )
652
+
653
+ def inner (invalid , limit ):
654
+ limit = min (limit , N )
655
+ windowed = _rolling_window (invalid , limit + 1 ).all (1 )
656
+ idx = (set (np .where (windowed )[0 ] + limit ) |
657
+ set (np .where ((~ invalid [:limit + 1 ]).cumsum () == 0 )[0 ]))
658
+ return idx
659
+
660
+ if fw_limit == 0 :
661
+ f_idx = set (np .where (invalid )[0 ])
662
+ else :
663
+ f_idx = inner (invalid , fw_limit )
664
+
665
+ if bw_limit == 0 :
666
+ # then we don't even need to care about backwards, just use forwards
667
+ return f_idx
668
+ else :
669
+ b_idx = set (N - 1 - np .asarray (list (inner (invalid [::- 1 ], bw_limit ))))
670
+ if fw_limit == 0 :
671
+ return b_idx
672
+ return f_idx & b_idx
673
+
674
+
675
+ def _rolling_window (a , window ):
676
+ """
677
+ [True, True, False, True, False], 2 ->
678
+
679
+ [
680
+ [True, True],
681
+ [True, False],
682
+ [False, True],
683
+ [True, False],
684
+ ]
685
+ """
686
+ # https://stackoverflow.com/a/6811241
687
+ shape = a .shape [:- 1 ] + (a .shape [- 1 ] - window + 1 , window )
688
+ strides = a .strides + (a .strides [- 1 ],)
689
+ return np .lib .stride_tricks .as_strided (a , shape = shape , strides = strides )
0 commit comments