Skip to content

Commit ff46413

Browse files
authored
ENH: API ratelimit available for all handlers. Fixes #632 (#635)
* ENH: API ratelimit available for all handlers * DOC: fix docstring for BaseMixin.get_list
1 parent 17c83e7 commit ff46413

File tree

6 files changed

+470
-67
lines changed

6 files changed

+470
-67
lines changed

gramex/gramex.yaml

+27-2
Original file line numberDiff line numberDiff line change
@@ -276,15 +276,40 @@ app:
276276
# Save in a JSON store
277277
type: json
278278
path: $GRAMEXDATA/session.json
279-
# Flush every 5 seconds
279+
# Flush every 5 seconds. Clear expired sessions every hour
280280
flush: 5
281-
# Clear expired sessions every hour
282281
purge: 3600
283282
# Cookies expire after 31 days
284283
expiry: 31
285284
# Browsers cannot use JS to access session cookie. Only HTTP access allowed, for security
286285
httponly: true
287286

287+
# Configure how rate limiting works.
288+
ratelimit:
289+
# Save in a JSON store
290+
type: json
291+
path: $GRAMEXDATA/ratelimit.json
292+
# Flush every 30 seconds. Clear expired rate limit keys every hour
293+
flush: 30
294+
purge: 3600
295+
# These can be used in keys as pre-defined functions
296+
keys:
297+
hourly:
298+
function: "pd.Timestamp.utcnow().strftime('%Y-%m-%d %H')"
299+
expiry: "int((pd.Timestamp.utcnow().ceil(freq='H') - pd.Timestamp.utcnow()).total_seconds())"
300+
daily:
301+
function: "pd.Timestamp.utcnow().strftime('%Y-%m-%d')"
302+
expiry: "int((pd.Timestamp.utcnow().normalize() + pd.Timedelta(days=1) - pd.Timestamp.utcnow()).total_seconds())"
303+
weekly:
304+
function: "pd.Timestamp.utcnow().strftime('%Y %U')"
305+
expiry: "int((pd.Timestamp.utcnow().normalize() + pd.Timedelta(days=7) - pd.Timestamp.utcnow()).total_seconds())"
306+
monthly:
307+
function: "pd.Timestamp.utcnow().strftime('%Y-%m')"
308+
expiry: "int((pd.Timestamp.utcnow().normalize() + pd.offsets.MonthBegin() - pd.Timestamp.utcnow()).total_seconds())"
309+
yearly:
310+
function: "pd.Timestamp.utcnow().strftime('%Y')"
311+
expiry: "int((pd.Timestamp.utcnow().normalize() + pd.offsets.YearBegin() - pd.Timestamp.utcnow()).total_seconds())"
312+
288313
# The storelocations: section defines where Gramex stores its data.
289314
storelocations:
290315
# The `otp:` section defines where to store one-time passwords

gramex/handlers/basehandler.py

+204-23
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import tornado.gen
1010
import gramex
1111
import gramex.cache
12-
from typing import Union
12+
from typing import Union, Optional, List, Any
1313
from binascii import b2a_base64
1414
from fnmatch import fnmatch
1515
from http.cookies import Morsel
@@ -19,16 +19,16 @@
1919
from tornado.websocket import WebSocketHandler
2020
from gramex import conf, __version__
2121
from gramex.config import merge, objectpath, app_log
22-
from gramex.transforms import build_transform, build_log_info
22+
from gramex.transforms import build_transform, build_log_info, handler_expr
2323
from gramex.transforms.template import CacheLoader
24-
from gramex.http import UNAUTHORIZED, FORBIDDEN, BAD_REQUEST, METHOD_NOT_ALLOWED
24+
from gramex.http import UNAUTHORIZED, FORBIDDEN, BAD_REQUEST, METHOD_NOT_ALLOWED, TOO_MANY_REQUESTS
2525
from gramex.cache import get_store
2626

2727
# We don't use these, but these stores used to be defined here. Programs may import these
2828
from gramex.cache import KeyStore, JSONStore, HDF5Store, SQLiteStore, RedisStore # noqa
2929

3030
server_header = f'Gramex/{__version__}'
31-
session_store_cache = {}
31+
_store_cache = {}
3232

3333
# Python 3.8+ supports SameSite cookie attribute. Monkey-patch it for Python 3.7
3434
# https://stackoverflow.com/a/50813092/100904
@@ -52,6 +52,8 @@ def setup(
5252
error=None,
5353
xsrf_cookies=None,
5454
cors: Union[None, bool, dict] = None,
55+
ratelimit: Optional[dict] = None,
56+
# If you add any explicit kwargs here, add them to special_keys too.
5557
**kwargs,
5658
):
5759
'''
@@ -69,6 +71,7 @@ def setup(
6971
# Note: call setup_session before setup_auth to ensure that
7072
# override_user is run before authorize
7173
cls.setup_session(conf.app.get('session'))
74+
cls.setup_ratelimit(ratelimit, conf.app.get('ratelimit'))
7275
cls.setup_auth(auth)
7376
cls.setup_error(error)
7477
cls.setup_xsrf(xsrf_cookies)
@@ -80,7 +83,7 @@ def setup(
8083
if conf.app.settings.get('debug', False):
8184
cls.log_exception = cls.debug_exception
8285

83-
# A list of special keys for BaseHandler. Can be extended by other classes.
86+
# A list of special keys handled by BaseHandler. Can be extended by other classes.
8487
special_keys = [
8588
'transform',
8689
'redirect',
@@ -92,6 +95,7 @@ def setup(
9295
'xsrf_cookies',
9396
'cors',
9497
'headers',
98+
'ratelimit',
9599
]
96100

97101
@classmethod
@@ -107,10 +111,28 @@ def clear_special_keys(cls, kwargs, *args):
107111
return kwargs
108112

109113
@classmethod
110-
def get_list(cls, val: Union[list, tuple, str], key: str = '', eg: str = '', caps=True) -> set:
111-
'''
112-
Convert val="GET, PUT" into {"GET", "PUT"}.
113-
If val is not a string or list/tuple, raise ValueError("url.{key} invalid. e.g. {eg}")
114+
def get_list(
115+
cls, val: Union[list, tuple, str], key: str = '', eg: str = '', caps: bool = True
116+
) -> set:
117+
'''Split comma-separated values into a set.
118+
119+
Process kwargs that can be a comma-separated string or a list,
120+
like BaseMixin's `methods:`, `cors.origins`, `cors.methods`, `cors.headers`,
121+
`ratelimit.keys`, etc.
122+
123+
Examples:
124+
>>> get_list('GET, PUT') == {'GET', 'PUT'}
125+
>>> get_list(['GET', ' ,get '], caps=True) == {'GET'}
126+
>>> get_list([' GET , PUT', ' ,POST, ']) == {'GET', 'PUT', 'POST'}
127+
128+
Parameters:
129+
val: Input to split. If val is not str/list/tuple, raise `ValueError`
130+
key: `url:` key to display in error message
131+
eg: Example values to display in error message
132+
caps: True to convert values to uppercase
133+
134+
Returns:
135+
Unique comma-separated values
114136
'''
115137
if isinstance(val, (list, tuple)):
116138
val = ' '.join(val)
@@ -312,20 +334,24 @@ def _purge_keys(data):
312334
return keys
313335

314336
@classmethod
315-
def setup_session(cls, session_conf):
316-
'''handler.session returns the session object. It is saved on finish.'''
317-
if session_conf is None:
318-
return
319-
key = store_type, store_path = session_conf.get('type'), session_conf.get('path')
320-
if key not in session_store_cache:
321-
session_store_cache[key] = get_store(
337+
def _get_store(cls, conf):
338+
key = store_type, store_path = conf.get('type'), conf.get('path')
339+
if key not in _store_cache:
340+
_store_cache[key] = get_store(
322341
type=store_type,
323342
path=store_path,
324-
flush=session_conf.get('flush'),
325-
purge=session_conf.get('purge'),
343+
flush=conf.get('flush'),
344+
purge=conf.get('purge'),
326345
purge_keys=cls._purge_keys,
327346
)
328-
cls._session_store = session_store_cache[key]
347+
return _store_cache[key]
348+
349+
@classmethod
350+
def setup_session(cls, session_conf):
351+
'''handler.session returns the session object. It is saved on finish.'''
352+
if session_conf is None:
353+
return
354+
cls._session_store = cls._get_store(session_conf)
329355
cls.session = property(cls.get_session)
330356
cls._session_expiry = session_conf.get('expiry')
331357
cls._session_cookie_id = session_conf.get('cookie', 'sid')
@@ -344,6 +370,112 @@ def setup_session(cls, session_conf):
344370
# Ensure that session is saved AFTER we set last visited
345371
cls._on_finish_methods.append(cls.save_session)
346372

373+
@classmethod
def setup_ratelimit(cls, ratelimit: Union[dict, None], ratelimit_app_conf: Union[dict, None]):
    '''Initialize rate limiting checks for this handler class.

    Compiles the `ratelimit:` kwargs into `cls._ratelimit` and registers
    `check_ratelimit` (on init) and `update_ratelimit` (on finish).

    Parameters:
        ratelimit: the url's `kwargs.ratelimit` section. `None` disables rate limiting.
            Must have `keys:` (comma-separated string, a `{function: ...}` dict, or a list
            of strings / dicts) and `limit:` (a number or a `{function: ...}` dict).
            `pool:` is optional and defaults to the url's `pattern:`.
        ratelimit_app_conf: the `app.ratelimit` section. Provides the store configuration
            and the pre-defined `keys:` (e.g. `daily`) that string keys are looked up in.

    Raises:
        ValueError: if `app.ratelimit` is not defined, `keys:` or `limit:` is missing,
            or either has an unsupported type / unknown key name.
    '''
    if ratelimit is None:
        return
    if ratelimit_app_conf is None:
        raise ValueError(f"url:{cls.name}.ratelimit: no app.ratelimit defined")
    if 'keys' not in ratelimit:
        raise ValueError(f'url:{cls.name}.ratelimit.keys: missing')
    if 'limit' not in ratelimit:
        raise ValueError(f'url:{cls.name}.ratelimit.limit: missing')

    # All ratelimit related info is stored in cls._ratelimit
    cls._ratelimit = AttrDict(key_fn=[])

    # Default the pool name to `pattern:`
    cls._ratelimit.pool = ratelimit.get('pool', cls.conf.pattern)

    # Convert keys: into a list
    keys_spec = ratelimit['keys']
    if isinstance(keys_spec, str):
        # keys: daily, user => keys: [daily, user]
        # Do NOT use cls.get_list() here: it returns a set, so the order of the keys
        # (and hence of the stored JSON key that reset_ratelimit() must reproduce)
        # would not be deterministic. Split on commas / whitespace, keeping the
        # configured order and dropping duplicates.
        keys_spec = list(dict.fromkeys(keys_spec.replace(',', ' ').split()))
    elif isinstance(keys_spec, dict):
        # keys: {function: ...} => keys: [{function: ...}]
        keys_spec = [keys_spec]
    elif not isinstance(keys_spec, (list, tuple)):
        # keys: must be a list
        raise ValueError(f'url:{cls.name}.ratelimit.keys: needs dict list, not {keys_spec}')

    # Pre-compile keys: into cls._ratelimit.key_fn = [key_fn, key_fn, ...]
    #   key_fn['function'](handler) will return nth key
    #   key_fn['expiry'](handler) will return nth expiry (in seconds)
    predefined_keys = ratelimit_app_conf.get('keys', {})
    for index, key_spec in enumerate(keys_spec):
        if isinstance(key_spec, str):
            if key_spec in predefined_keys:
                # Look up string keys like daily in predefined_keys
                key_spec = predefined_keys[key_spec]
            else:
                # Or construct functions for `user.id`, etc
                try:
                    key_spec = {'function': handler_expr(key_spec)}
                except ValueError as err:
                    raise ValueError(
                        f'url:{cls.name}.ratelimit.keys: {key_spec} is unknown'
                    ) from err
        # {function: ...} MUST be defined for a key. {expiry: ...} is optional
        if not isinstance(key_spec, dict) or 'function' not in key_spec:
            raise ValueError(f'url:{cls.name}.ratelimit.keys: {key_spec} has no function:')
        # Compile whichever of function: / expiry: are present
        key_fn = {
            fn: build_transform(
                {'function': key_spec[fn]},
                vars={'handler': None},
                filename=f'url:{cls.name}.ratelimit.keys[{index}].{fn}',
                iter=False,
            )
            for fn in ('function', 'expiry')
            if fn in key_spec
        }
        cls._ratelimit.key_fn.append(key_fn)

    # Ensure limit: is a number or a {function: ...}
    limit_spec = ratelimit['limit']
    if isinstance(limit_spec, (int, float)):
        limit_spec = {'function': limit_spec}
    elif not isinstance(limit_spec, dict) or 'function' not in limit_spec:
        example = "{'function': number}"
        raise ValueError(f'url:{cls.name}.ratelimit.limit: needs {example}, not {limit_spec}')

    # Pre-compile limit: into cls._ratelimit.limit_fn
    cls._ratelimit.limit_fn = build_transform(
        limit_spec,
        vars={'handler': None},
        filename=f'url:{cls.name}.ratelimit.limit',
        iter=False,
    )

    cls._ratelimit.store = cls._get_store(ratelimit_app_conf)
    cls._on_init_methods.append(cls.check_ratelimit)
    cls._on_finish_methods.append(cls.update_ratelimit)
452+
453+
@classmethod
def reset_ratelimit(cls, pool: str, keys: List[Any], value: int = 0) -> bool:
    '''Reset the rate limit usage for a specific pool.

    Examples:

        >>> reset_ratelimit('/api', ['2022-01-01', 'user@example.com'])
        >>> reset_ratelimit('/api', ['2022-01-01', 'user@example.com'], 10)

    Parameters:

        pool: Rate limit pool to use. This is the url's `pattern:` unless you specified a
            `kwargs.ratelimit.pool:`
        keys: specific instance to reset. If your `ratelimit.keys` is `[daily, user.id]`,
            keys might look like `['2022-01-01', 'user@example.com']` to clear for that
            day/user. Must be in the same order as `ratelimit.keys`.
        value: sets the usage counter to this number (default: `0`)

    Returns:
        `True` if a usage record for this pool + keys existed and was updated.
        `False` if there was no such record (nothing is written).
    '''
    store = cls._get_store(conf.app.get('ratelimit'))
    # The store key is the JSON list [pool, key1, key2, ...], as built by check_ratelimit.
    # list(keys) lets callers pass a tuple as well as a list.
    key = json.dumps([pool] + list(keys))
    val = store.load(key, None)
    if val is None or 'n' not in val:
        return False
    val['n'] = value
    store.dump(key, val)
    return True
347479
@classmethod
348480
def setup_redirect(cls, redirect):
349481
'''
@@ -655,6 +787,9 @@ def _write_custom_error(self, status_code, **kwargs):
655787
return
656788
except Exception:
657789
app_log.exception(f'url:{self.name}.error.{status_code} raised an exception')
790+
# HTTP 429 error code reports ratelimits
791+
if status_code == TOO_MANY_REQUESTS and hasattr(self, '_ratelimit'):
792+
self.set_ratelimit_headers()
658793
# If error was not written, use the default error
659794
self._write_error(status_code, **kwargs)
660795

@@ -889,10 +1024,13 @@ def override_user(self):
8891024
self.session['user'] = row['user']
8901025

8911026
def set_last_visited(self):
892-
'''
893-
This method is called by :py:func:`BaseHandler.prepare` when any user
894-
accesses a page. It updates the last visited time in the ``_l`` session
895-
key. It does this only if the ``_i`` key exists.
1027+
'''Update session last visited time if we track inactive expiry.
1028+
1029+
- `session._l` is the last time the user accessed a page.
1030+
- `session._i` is the seconds of inactivity after which the session expires.
1031+
- If `session._i` is set (we track inactive expiry), we set `session._l` to now.
1032+
1033+
Called by [prepare][BaseHandler.prepare] when any user accesses a page.
8961034
'''
8971035
# For efficiency reasons, don't call get_session every time. Check
8981036
# session only if there's a valid sid cookie (with possibly long expiry)
@@ -901,6 +1039,49 @@ def set_last_visited(self):
9011039
if '_i' in session:
9021040
session['_l'] = time.time()
9031041

1042+
def check_ratelimit(self):
    '''Raise HTTP 429 if usage exceeds rate limit. Set X-Ratelimit-* HTTP headers.

    Computes this request's rate limit `key`, `limit`, `expiry` and current `usage`
    and stores them on `self._ratelimit` for `update_ratelimit` and
    `set_ratelimit_headers` to use.

    Raises:
        HTTPError: with status TOO_MANY_REQUESTS if usage has reached the limit.
    '''
    # self._ratelimit initially resolves to the object shared by every request of this
    # handler class. Writing per-request values (key, limit, ...) on it lets interleaved
    # requests overwrite each other and leaves a stale `key` behind, which defeats the
    # `'key' not in ratelimit` guard in update_ratelimit. Work on a shallow per-request
    # copy bound to this instance instead. key_fn, limit_fn and store stay shared.
    shared = self._ratelimit
    ratelimit = self._ratelimit = type(shared)(shared)

    # Get the rate limit key, limit and expiry
    ratelimit.key = json.dumps(
        [ratelimit.pool] + [key_fn['function'](self) for key_fn in ratelimit.key_fn]
    )
    ratelimit.limit = ratelimit.limit_fn(self)
    expiries = [key_fn['expiry'](self) for key_fn in ratelimit.key_fn if 'expiry' in key_fn]
    # If no expiry is specified, store for 100 years (in seconds)
    ratelimit.expiry = min(expiries + [3155760000])

    # Ensure usage does not hit limit
    ratelimit.usage = ratelimit.store.load(ratelimit.key, {'n': 0}).get('n', 0)
    if ratelimit.usage >= ratelimit.limit:
        raise HTTPError(TOO_MANY_REQUESTS, f'{ratelimit.key} hit rate limit {ratelimit.limit}')
    self.set_ratelimit_headers()
1059+
1060+
def update_ratelimit(self):
    '''If request succeeds, increase rate limit usage count by 1.

    Reads `key`, `expiry` and `store` from `self._ratelimit` (set by `check_ratelimit`).
    The stored record's `n` is the usage count and `_t` is the time (epoch seconds)
    after which the record expires.
    '''
    ratelimit = self._ratelimit
    # If check_ratelimit failed (e.g. invalid function) and didn't set a key, skip update
    # If response is a HTTP error, don't count towards rate limit
    if 'key' not in ratelimit or self.get_status() >= 400:
        return
    # Increment the rate limit by 1. Use .get() so that a stored record without 'n'
    # counts as zero usage (as check_ratelimit treats it) instead of raising KeyError
    usage_obj = ratelimit.store.load(ratelimit.key, {'n': 0})
    usage_obj['n'] = usage_obj.get('n', 0) + 1
    usage_obj['_t'] = time.time() + ratelimit.expiry
    ratelimit.store.dump(ratelimit.key, usage_obj)
1072+
1073+
def set_ratelimit_headers(self):
    '''Report rate limit status via X-Ratelimit-* (and Retry-After) response headers.

    Reads `limit`, `usage` and `expiry` from `self._ratelimit`.
    '''
    state = self._ratelimit
    # Usage counts 0, 1, 2, ... so with limit 3 the remaining count after this request
    # is 2, 1, 0, ... i.e. (limit - usage - 1). Never report a negative number.
    left = state.limit - state.usage - 1
    headers = [
        ('X-Ratelimit-Limit', str(state.limit)),
        ('X-Ratelimit-Remaining', str(max(left, 0))),
        ('X-RateLimit-Reset', str(state.expiry)),
    ]
    # Once the limit is reached, also tell the client when to retry
    if not state.usage < state.limit:
        headers.append(('Retry-After', str(state.expiry)))
    for name, value in headers:
        self.set_header(name, value)
1084+
9041085

9051086
class BaseHandler(RequestHandler, BaseMixin):
9061087
'''

gramex/transforms/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from .auth import ensure_single_session
44
from .template import template, sass, scss, ts, vue
55
from .transforms import build_transform, build_pipeline, build_log_info, condition, flattener, once
6-
from .transforms import handler, Header
6+
from .transforms import handler, handler_expr, Header
77

88
# Import common libraries with their popular abbreviations.
99
# This lets build_transform() to use, for e.g., `pd.concat()` instead of `pandas.concat()`.
@@ -24,6 +24,7 @@
2424
'flattener',
2525
'once',
2626
'handler',
27+
'handler_expr',
2728
'Header',
2829
'pd',
2930
'np',

0 commit comments

Comments
 (0)