Skip to content

Commit 7bed7c7

Browse files
authored
Merge pull request #1228 from cmu-delphi/fix_dashboard_signals_parsing
Fix dashboard signals parsing for dashboard queries
2 parents b11875e + 17b98dd commit 7bed7c7

File tree

1 file changed

+19
-10
lines changed

1 file changed

+19
-10
lines changed

src/server/_limiter.py

+19-10
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77
from ._common import app, get_real_ip_addr
88
from ._config import RATE_LIMIT, RATELIMIT_STORAGE_URL, REDIS_HOST, REDIS_PASSWORD
99
from ._exceptions import ValidationFailedException
10-
from ._params import extract_dates, extract_integers, extract_strings
10+
from ._params import extract_dates, extract_integers, extract_strings, parse_source_signal_sets
1111
from ._security import _is_public_route, current_user, resolve_auth_token, ERROR_MSG_RATE_LIMIT, ERROR_MSG_MULTIPLES
1212

1313

14+
1415
def deduct_on_success(response: Response) -> bool:
1516
if response.status_code != 200:
1617
return False
@@ -52,8 +53,9 @@ def get_multiples_count(request):
5253
if "window" in request.args.keys():
5354
multiple_selection_allowed -= 1
5455
for k, v in request.args.items():
55-
if v == "*":
56+
if "*" in v:
5657
multiple_selection_allowed -= 1
58+
continue
5759
try:
5860
vals = multiples.get(k)(k)
5961
if len(vals) >= 2:
@@ -70,16 +72,23 @@ def get_multiples_count(request):
7072

7173
def check_signals_allowlist(request):
7274
signals_allowlist = {":".join(ss_pair) for ss_pair in DashboardSignals().srcsig_list()}
73-
request_signals = []
74-
if "signal" in request.args.keys():
75-
request_signals += extract_strings("signal")
76-
if "signals" in request.args.keys():
77-
request_signals += extract_strings("signals")
78-
if "data_source" in request.args:
79-
request_signals = [f"{request.args['data_source']}:{request_signal}" for request_signal in request_signals]
75+
request_signals = set()
76+
try:
77+
source_signal_sets = parse_source_signal_sets()
78+
except ValidationFailedException:
79+
return False
80+
for source_signal in source_signal_sets:
81+
# source_signal.signal is expected to be eiter list or bool:
82+
# in case of bool, we have wildcard signal -> return False as there are no chances that
83+
# all signals from given source will be whitelisted
84+
# in case of list, we have list of signals
85+
if isinstance(source_signal.signal, bool):
86+
return False
87+
for signal in source_signal.signal:
88+
request_signals.add(f"{source_signal.source}:{signal}")
8089
if len(request_signals) == 0:
8190
return False
82-
return all([signal in signals_allowlist for signal in request_signals])
91+
return request_signals.issubset(signals_allowlist)
8392

8493

8594
def _resolve_tracking_key() -> str:

0 commit comments

Comments
 (0)