Skip to content

Commit 56abcd9

Browse files
Remove duplicates in _AndList/_OrList filters.
1 parent 0b4f04e commit 56abcd9

File tree

2 files changed

+112
-14
lines changed

2 files changed

+112
-14
lines changed

src/prompt_toolkit/filters/base.py

Lines changed: 63 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ class Filter(metaclass=ABCMeta):
1515
"""
1616

1717
def __init__(self) -> None:
18-
self._and_cache: dict[Filter, _AndList] = {}
19-
self._or_cache: dict[Filter, _OrList] = {}
18+
self._and_cache: dict[Filter, Filter] = {}
19+
self._or_cache: dict[Filter, Filter] = {}
2020
self._invert_result: Filter | None = None
2121

2222
@abstractmethod
@@ -40,7 +40,7 @@ def __and__(self, other: Filter) -> Filter:
4040
if other in self._and_cache:
4141
return self._and_cache[other]
4242

43-
result = _AndList([self, other])
43+
result = _AndList.create([self, other])
4444
self._and_cache[other] = result
4545
return result
4646

@@ -58,7 +58,7 @@ def __or__(self, other: Filter) -> Filter:
5858
if other in self._or_cache:
5959
return self._or_cache[other]
6060

61-
result = _OrList([self, other])
61+
result = _OrList.create([self, other])
6262
self._or_cache[other] = result
6363
return result
6464

@@ -86,20 +86,49 @@ def __bool__(self) -> None:
8686
)
8787

8888

89+
def _remove_duplicates(filters: list[Filter]) -> list[Filter]:
90+
result = []
91+
for f in filters:
92+
if f not in result:
93+
result.append(f)
94+
return result
95+
96+
8997
class _AndList(Filter):
9098
"""
9199
Result of &-operation between several filters.
92100
"""
93101

94-
def __init__(self, filters: Iterable[Filter]) -> None:
102+
def __init__(self, filters: list[Filter]) -> None:
95103
super().__init__()
96-
self.filters: list[Filter] = []
104+
self.filters = filters
105+
106+
@classmethod
107+
def create(cls, filters: Iterable[Filter]) -> Filter:
108+
"""
109+
Create a new filter by applying an `&` operator between them.
110+
111+
If there's only one unique filter in the given iterable, it will return
112+
that one filter instead of an `_AndList`.
113+
"""
114+
filters_2: list[Filter] = []
97115

98116
for f in filters:
99117
if isinstance(f, _AndList): # Turn nested _AndLists into one.
100-
self.filters.extend(f.filters)
118+
filters_2.extend(f.filters)
101119
else:
102-
self.filters.append(f)
120+
filters_2.append(f)
121+
122+
# Remove duplicates. This could speed up execution, and doesn't make a
123+
# difference for the evaluation.
124+
filters = _remove_duplicates(filters_2)
125+
126+
# If only one filter is left, return that without wrapping into an
127+
# `_AndList`.
128+
if len(filters) == 1:
129+
return filters[0]
130+
131+
return cls(filters)
103132

104133
def __call__(self) -> bool:
105134
return all(f() for f in self.filters)
@@ -113,15 +142,36 @@ class _OrList(Filter):
113142
Result of |-operation between several filters.
114143
"""
115144

116-
def __init__(self, filters: Iterable[Filter]) -> None:
145+
def __init__(self, filters: list[Filter]) -> None:
117146
super().__init__()
118-
self.filters: list[Filter] = []
147+
self.filters = filters
148+
149+
@classmethod
150+
def create(cls, filters: Iterable[Filter]) -> Filter:
151+
"""
152+
Create a new filter by applying an `|` operator between them.
153+
154+
If there's only one unique filter in the given iterable, it will return
155+
that one filter instead of an `_OrList`.
156+
"""
157+
filters_2: list[Filter] = []
119158

120159
for f in filters:
121-
if isinstance(f, _OrList): # Turn nested _OrLists into one.
122-
self.filters.extend(f.filters)
160+
if isinstance(f, _OrList): # Turn nested _AndLists into one.
161+
filters_2.extend(f.filters)
123162
else:
124-
self.filters.append(f)
163+
filters_2.append(f)
164+
165+
# Remove duplicates. This could speed up execution, and doesn't make a
166+
# difference for the evaluation.
167+
filters = _remove_duplicates(filters_2)
168+
169+
# If only one filter is left, return that without wrapping into an
170+
# `_AndList`.
171+
if len(filters) == 1:
172+
return filters[0]
173+
174+
return cls(filters)
125175

126176
def __call__(self) -> bool:
127177
return any(f() for f in self.filters)

tests/test_filter.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3-
import pytest
43
import gc
54

5+
import pytest
6+
67
from prompt_toolkit.filters import Always, Condition, Filter, Never, to_filter
8+
from prompt_toolkit.filters.base import _AndList, _OrList
79

810

911
def test_never():
@@ -44,6 +46,32 @@ def test_and():
4446
assert c3() == (a and b)
4547

4648

49+
def test_nested_and():
50+
for a in (True, False):
51+
for b in (True, False):
52+
for c in (True, False):
53+
c1 = Condition(lambda: a)
54+
c2 = Condition(lambda: b)
55+
c3 = Condition(lambda: c)
56+
c4 = (c1 & c2) & c3
57+
58+
assert isinstance(c4, Filter)
59+
assert c4() == (a and b and c)
60+
61+
62+
def test_nested_or():
63+
for a in (True, False):
64+
for b in (True, False):
65+
for c in (True, False):
66+
c1 = Condition(lambda: a)
67+
c2 = Condition(lambda: b)
68+
c3 = Condition(lambda: c)
69+
c4 = (c1 | c2) | c3
70+
71+
assert isinstance(c4, Filter)
72+
assert c4() == (a or b or c)
73+
74+
4775
def test_to_filter():
4876
f1 = to_filter(True)
4977
f2 = to_filter(False)
@@ -75,6 +103,7 @@ def test_filter_cache_regression_1():
75103
y = (cond & cond) & cond
76104
assert x == y
77105

106+
78107
def test_filter_cache_regression_2():
79108
cond1 = Condition(lambda: True)
80109
cond2 = Condition(lambda: True)
@@ -83,3 +112,22 @@ def test_filter_cache_regression_2():
83112
x = (cond1 & cond2) & cond3
84113
y = (cond1 & cond2) & cond3
85114
assert x == y
115+
116+
117+
def test_filter_remove_duplicates():
118+
cond1 = Condition(lambda: True)
119+
cond2 = Condition(lambda: True)
120+
121+
# When a condition is appended to itself using an `&` or `|` operator, it
122+
# should not be present twice. Having it twice in the `_AndList` or
123+
# `_OrList` will make them more expensive to evaluate.
124+
125+
assert isinstance(cond1 & cond1, Condition)
126+
assert isinstance(cond1 & cond1 & cond1, Condition)
127+
assert isinstance(cond1 & cond1 & cond2, _AndList)
128+
assert len((cond1 & cond1 & cond2).filters) == 2
129+
130+
assert isinstance(cond1 | cond1, Condition)
131+
assert isinstance(cond1 | cond1 | cond1, Condition)
132+
assert isinstance(cond1 | cond1 | cond2, _OrList)
133+
assert len((cond1 | cond1 | cond2).filters) == 2

0 commit comments

Comments
 (0)