
Commit 516953d

mroeschke and pmhatre1 authored and committed
REF: Use numpy set methods in interpolate (pandas-dev#57997)
* Use numpy arrays instead of sets in interp
* Enable assume_unique in intersect1d
* Typing
1 parent 708664c commit 516953d
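
For context, a minimal standalone sketch (values and variable names are illustrative, not from the commit) of the substitution this refactor makes: Python set unions, differences, and intersections over integer index positions become np.union1d, np.setdiff1d, and np.intersect1d, with assume_unique=True as an optimization when the inputs are already duplicate-free.

import numpy as np

# Hypothetical arrays of NaN index positions; both are unique by construction.
a = np.array([0, 1, 4, 7])
b = np.array([4, 7, 9])

# Set-based (old style) and array-based (new style) operations agree on membership.
assert set(np.union1d(a, b)) == set(a) | set(b)
assert set(np.setdiff1d(a, b, assume_unique=True)) == set(a) - set(b)
assert set(np.intersect1d(a, b, assume_unique=True)) == set(a) & set(b)

# np.union1d and np.unique also return sorted arrays, which is why the diff below
# can drop the explicit sorted(preserve_nans) step.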

File tree

1 file changed: +25 -42 lines changed


pandas/core/missing.py  +25 -42
@@ -471,20 +471,20 @@ def _interpolate_1d(
     if valid.all():
         return

-    # These are sets of index pointers to invalid values... i.e. {0, 1, etc...
-    all_nans = set(np.flatnonzero(invalid))
+    # These index pointers to invalid values... i.e. {0, 1, etc...
+    all_nans = np.flatnonzero(invalid)

     first_valid_index = find_valid_index(how="first", is_valid=valid)
     if first_valid_index is None:  # no nan found in start
         first_valid_index = 0
-    start_nans = set(range(first_valid_index))
+    start_nans = np.arange(first_valid_index)

     last_valid_index = find_valid_index(how="last", is_valid=valid)
     if last_valid_index is None:  # no nan found in end
         last_valid_index = len(yvalues)
-    end_nans = set(range(1 + last_valid_index, len(valid)))
+    end_nans = np.arange(1 + last_valid_index, len(valid))

-    # Like the sets above, preserve_nans contains indices of invalid values,
+    # preserve_nans contains indices of invalid values,
     # but in this case, it is the final set of indices that need to be
     # preserved as NaN after the interpolation.

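To illustrate this hunk, a small made-up example (the mask below is not from pandas) of how the three index collections become plain integer arrays; each comes out unique and sorted, which is what later makes assume_unique safe:

import numpy as np

invalid = np.array([True, True, False, False, True, False, True])  # hypothetical NaN mask
valid = ~invalid

all_nans = np.flatnonzero(invalid)                      # array([0, 1, 4, 6]); was set(...)
first_valid_index = 2                                   # first position where valid is True
start_nans = np.arange(first_valid_index)               # array([0, 1]); was set(range(...))
last_valid_index = 5                                    # last position where valid is True
end_nans = np.arange(1 + last_valid_index, len(valid))  # array([6]); was set(range(...))
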
@@ -493,27 +493,25 @@ def _interpolate_1d(
     # are more than 'limit' away from the prior non-NaN.

     # set preserve_nans based on direction using _interp_limit
-    preserve_nans: list | set
     if limit_direction == "forward":
-        preserve_nans = start_nans | set(_interp_limit(invalid, limit, 0))
+        preserve_nans = np.union1d(start_nans, _interp_limit(invalid, limit, 0))
     elif limit_direction == "backward":
-        preserve_nans = end_nans | set(_interp_limit(invalid, 0, limit))
+        preserve_nans = np.union1d(end_nans, _interp_limit(invalid, 0, limit))
     else:
         # both directions... just use _interp_limit
-        preserve_nans = set(_interp_limit(invalid, limit, limit))
+        preserve_nans = np.unique(_interp_limit(invalid, limit, limit))

     # if limit_area is set, add either mid or outside indices
     # to preserve_nans GH #16284
     if limit_area == "inside":
         # preserve NaNs on the outside
-        preserve_nans |= start_nans | end_nans
+        preserve_nans = np.union1d(preserve_nans, start_nans)
+        preserve_nans = np.union1d(preserve_nans, end_nans)
     elif limit_area == "outside":
         # preserve NaNs on the inside
-        mid_nans = all_nans - start_nans - end_nans
-        preserve_nans |= mid_nans
-
-    # sort preserve_nans and convert to list
-    preserve_nans = sorted(preserve_nans)
+        mid_nans = np.setdiff1d(all_nans, start_nans, assume_unique=True)
+        mid_nans = np.setdiff1d(mid_nans, end_nans, assume_unique=True)
+        preserve_nans = np.union1d(preserve_nans, mid_nans)

     is_datetimelike = yvalues.dtype.kind in "mM"

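Continuing the made-up index arrays from above, a sketch of the limit_area="outside" branch: two np.setdiff1d calls with assume_unique=True reproduce the old all_nans - start_nans - end_nans set difference, and np.union1d keeps preserve_nans sorted and duplicate-free:

import numpy as np

all_nans = np.array([0, 1, 4, 6])
start_nans = np.array([0, 1])
end_nans = np.array([6])
preserve_nans = np.array([], dtype=np.intp)  # assume nothing preserved yet

# Old: mid_nans = all_nans - start_nans - end_nans (as Python sets)
mid_nans = np.setdiff1d(all_nans, start_nans, assume_unique=True)  # array([4, 6])
mid_nans = np.setdiff1d(mid_nans, end_nans, assume_unique=True)    # array([4])
preserve_nans = np.union1d(preserve_nans, mid_nans)                # array([4])
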
@@ -1027,7 +1025,7 @@ def clean_reindex_fill_method(method) -> ReindexMethod | None:

 def _interp_limit(
     invalid: npt.NDArray[np.bool_], fw_limit: int | None, bw_limit: int | None
-):
+) -> np.ndarray:
     """
     Get indexers of values that won't be filled
     because they exceed the limits.
@@ -1059,20 +1057,23 @@ def _interp_limit(invalid, fw_limit, bw_limit):
     # 1. operate on the reversed array
     # 2. subtract the returned indices from N - 1
     N = len(invalid)
-    f_idx = set()
-    b_idx = set()
+    f_idx = np.array([], dtype=np.int64)
+    b_idx = np.array([], dtype=np.int64)
+    assume_unique = True

     def inner(invalid, limit: int):
         limit = min(limit, N)
-        windowed = _rolling_window(invalid, limit + 1).all(1)
-        idx = set(np.where(windowed)[0] + limit) | set(
-            np.where((~invalid[: limit + 1]).cumsum() == 0)[0]
+        windowed = np.lib.stride_tricks.sliding_window_view(invalid, limit + 1).all(1)
+        idx = np.union1d(
+            np.where(windowed)[0] + limit,
+            np.where((~invalid[: limit + 1]).cumsum() == 0)[0],
         )
         return idx

     if fw_limit is not None:
         if fw_limit == 0:
-            f_idx = set(np.where(invalid)[0])
+            f_idx = np.where(invalid)[0]
+            assume_unique = False
         else:
             f_idx = inner(invalid, fw_limit)

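As a standalone check (arrays are illustrative), np.lib.stride_tricks.sliding_window_view produces the same windows as the hand-rolled as_strided helper that the last hunk deletes, through a public, bounds-checked API:

import numpy as np

invalid = np.array([True, True, False, True, False])
window = 2

# New: public sliding-window API
new = np.lib.stride_tricks.sliding_window_view(invalid, window)

# Old: the _rolling_window helper removed below (shape/stride arithmetic by hand)
shape = invalid.shape[:-1] + (invalid.shape[-1] - window + 1, window)
strides = invalid.strides + (invalid.strides[-1],)
old = np.lib.stride_tricks.as_strided(invalid, shape=shape, strides=strides)

assert (new == old).all()
assert new.all(1).tolist() == [True, False, False, False]  # runs of `window` consecutive NaNs
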
@@ -1082,26 +1083,8 @@ def inner(invalid, limit: int):
             # just use forwards
             return f_idx
         else:
-            b_idx_inv = list(inner(invalid[::-1], bw_limit))
-            b_idx = set(N - 1 - np.asarray(b_idx_inv))
+            b_idx = N - 1 - inner(invalid[::-1], bw_limit)
             if fw_limit == 0:
                 return b_idx

-    return f_idx & b_idx
-
-
-def _rolling_window(a: npt.NDArray[np.bool_], window: int) -> npt.NDArray[np.bool_]:
-    """
-    [True, True, False, True, False], 2 ->
-
-    [
-        [True, True],
-        [True, False],
-        [False, True],
-        [True, False],
-    ]
-    """
-    # https://stackoverflow.com/a/6811241
-    shape = a.shape[:-1] + (a.shape[-1] - window + 1, window)
-    strides = a.strides + (a.strides[-1],)
-    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)
+    return np.intersect1d(f_idx, b_idx, assume_unique=assume_unique)
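
Finally, a small sketch (made-up indices) of the new return statement: np.intersect1d replaces the f_idx & b_idx set intersection, and assume_unique=True skips an internal deduplication pass when both index arrays are known to contain no repeats:

import numpy as np

f_idx = np.array([2, 3, 5, 8])  # forward-limit candidates, unique
b_idx = np.array([8, 5, 1])     # backward-limit candidates (N - 1 - ... yields descending order)

old = set(f_idx) & set(b_idx)                           # {8, 5}
new = np.intersect1d(f_idx, b_idx, assume_unique=True)  # array([5, 8]), already sorted

assert set(new) == old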
