@@ -471,20 +471,20 @@ def _interpolate_1d(
471
471
if valid .all ():
472
472
return
473
473
474
- # These are sets of index pointers to invalid values... i.e. {0, 1, etc...
475
- all_nans = set ( np .flatnonzero (invalid ) )
474
+ # These are index pointers to invalid values... i.e. {0, 1, etc...
475
+ all_nans = np .flatnonzero (invalid )
476
476
477
477
first_valid_index = find_valid_index (how = "first" , is_valid = valid )
478
478
if first_valid_index is None : # no nan found in start
479
479
first_valid_index = 0
480
- start_nans = set ( range ( first_valid_index ) )
480
+ start_nans = np . arange ( first_valid_index )
481
481
482
482
last_valid_index = find_valid_index (how = "last" , is_valid = valid )
483
483
if last_valid_index is None : # no nan found in end
484
484
last_valid_index = len (yvalues )
485
- end_nans = set ( range ( 1 + last_valid_index , len (valid ) ))
485
+ end_nans = np . arange ( 1 + last_valid_index , len (valid ))
486
486
487
- # Like the sets above, preserve_nans contains indices of invalid values,
487
+ # preserve_nans also contains indices of invalid values,
488
488
# but in this case, it is the final set of indices that need to be
489
489
# preserved as NaN after the interpolation.
490
490
@@ -493,27 +493,25 @@ def _interpolate_1d(
493
493
# are more than 'limit' away from the prior non-NaN.
494
494
495
495
# set preserve_nans based on direction using _interp_limit
496
- preserve_nans : list | set
497
496
if limit_direction == "forward" :
498
- preserve_nans = start_nans | set ( _interp_limit (invalid , limit , 0 ))
497
+ preserve_nans = np . union1d ( start_nans , _interp_limit (invalid , limit , 0 ))
499
498
elif limit_direction == "backward" :
500
- preserve_nans = end_nans | set ( _interp_limit (invalid , 0 , limit ))
499
+ preserve_nans = np . union1d ( end_nans , _interp_limit (invalid , 0 , limit ))
501
500
else :
502
501
# both directions... just use _interp_limit
503
- preserve_nans = set (_interp_limit (invalid , limit , limit ))
502
+ preserve_nans = np . unique (_interp_limit (invalid , limit , limit ))
504
503
505
504
# if limit_area is set, add either mid or outside indices
506
505
# to preserve_nans GH #16284
507
506
if limit_area == "inside" :
508
507
# preserve NaNs on the outside
509
- preserve_nans |= start_nans | end_nans
508
+ preserve_nans = np .union1d (preserve_nans , start_nans )
509
+ preserve_nans = np .union1d (preserve_nans , end_nans )
510
510
elif limit_area == "outside" :
511
511
# preserve NaNs on the inside
512
- mid_nans = all_nans - start_nans - end_nans
513
- preserve_nans |= mid_nans
514
-
515
- # sort preserve_nans and convert to list
516
- preserve_nans = sorted (preserve_nans )
512
+ mid_nans = np .setdiff1d (all_nans , start_nans , assume_unique = True )
513
+ mid_nans = np .setdiff1d (mid_nans , end_nans , assume_unique = True )
514
+ preserve_nans = np .union1d (preserve_nans , mid_nans )
517
515
518
516
is_datetimelike = yvalues .dtype .kind in "mM"
519
517
@@ -1027,7 +1025,7 @@ def clean_reindex_fill_method(method) -> ReindexMethod | None:
1027
1025
1028
1026
def _interp_limit (
1029
1027
invalid : npt .NDArray [np .bool_ ], fw_limit : int | None , bw_limit : int | None
1030
- ):
1028
+ ) -> np . ndarray :
1031
1029
"""
1032
1030
Get indexers of values that won't be filled
1033
1031
because they exceed the limits.
@@ -1059,20 +1057,23 @@ def _interp_limit(invalid, fw_limit, bw_limit):
1059
1057
# 1. operate on the reversed array
1060
1058
# 2. subtract the returned indices from N - 1
1061
1059
N = len (invalid )
1062
- f_idx = set ()
1063
- b_idx = set ()
1060
+ f_idx = np .array ([], dtype = np .int64 )
1061
+ b_idx = np .array ([], dtype = np .int64 )
1062
+ assume_unique = True
1064
1063
1065
1064
def inner (invalid , limit : int ):
1066
1065
limit = min (limit , N )
1067
- windowed = _rolling_window (invalid , limit + 1 ).all (1 )
1068
- idx = set (np .where (windowed )[0 ] + limit ) | set (
1069
- np .where ((~ invalid [: limit + 1 ]).cumsum () == 0 )[0 ]
1066
+ windowed = np .lib .stride_tricks .sliding_window_view (invalid , limit + 1 ).all (1 )
1067
+ idx = np .union1d (
1068
+ np .where (windowed )[0 ] + limit ,
1069
+ np .where ((~ invalid [: limit + 1 ]).cumsum () == 0 )[0 ],
1070
1070
)
1071
1071
return idx
1072
1072
1073
1073
if fw_limit is not None :
1074
1074
if fw_limit == 0 :
1075
- f_idx = set (np .where (invalid )[0 ])
1075
+ f_idx = np .where (invalid )[0 ]
1076
+ assume_unique = False
1076
1077
else :
1077
1078
f_idx = inner (invalid , fw_limit )
1078
1079
@@ -1082,26 +1083,8 @@ def inner(invalid, limit: int):
1082
1083
# just use forwards
1083
1084
return f_idx
1084
1085
else :
1085
- b_idx_inv = list (inner (invalid [::- 1 ], bw_limit ))
1086
- b_idx = set (N - 1 - np .asarray (b_idx_inv ))
1086
+ b_idx = N - 1 - inner (invalid [::- 1 ], bw_limit )
1087
1087
if fw_limit == 0 :
1088
1088
return b_idx
1089
1089
1090
- return f_idx & b_idx
1091
-
1092
-
1093
- def _rolling_window (a : npt .NDArray [np .bool_ ], window : int ) -> npt .NDArray [np .bool_ ]:
1094
- """
1095
- [True, True, False, True, False], 2 ->
1096
-
1097
- [
1098
- [True, True],
1099
- [True, False],
1100
- [False, True],
1101
- [True, False],
1102
- ]
1103
- """
1104
- # https://stackoverflow.com/a/6811241
1105
- shape = a .shape [:- 1 ] + (a .shape [- 1 ] - window + 1 , window )
1106
- strides = a .strides + (a .strides [- 1 ],)
1107
- return np .lib .stride_tricks .as_strided (a , shape = shape , strides = strides )
1090
+ return np .intersect1d (f_idx , b_idx , assume_unique = assume_unique )
0 commit comments