9
9
import tornado .gen
10
10
import gramex
11
11
import gramex .cache
12
- from typing import Union
12
+ from typing import Union , Optional , List , Any
13
13
from binascii import b2a_base64
14
14
from fnmatch import fnmatch
15
15
from http .cookies import Morsel
19
19
from tornado .websocket import WebSocketHandler
20
20
from gramex import conf , __version__
21
21
from gramex .config import merge , objectpath , app_log
22
- from gramex .transforms import build_transform , build_log_info
22
+ from gramex .transforms import build_transform , build_log_info , handler_expr
23
23
from gramex .transforms .template import CacheLoader
24
- from gramex .http import UNAUTHORIZED , FORBIDDEN , BAD_REQUEST , METHOD_NOT_ALLOWED
24
+ from gramex .http import UNAUTHORIZED , FORBIDDEN , BAD_REQUEST , METHOD_NOT_ALLOWED , TOO_MANY_REQUESTS
25
25
from gramex .cache import get_store
26
26
27
27
# We don't use these, but these stores used to be defined here. Programs may import these
28
28
from gramex .cache import KeyStore , JSONStore , HDF5Store , SQLiteStore , RedisStore # noqa
29
29
30
30
server_header = f'Gramex/{ __version__ } '
31
- session_store_cache = {}
31
+ _store_cache = {}
32
32
33
33
# Python 3.8+ supports SameSite cookie attribute. Monkey-patch it for Python 3.7
34
34
# https://stackoverflow.com/a/50813092/100904
@@ -52,6 +52,8 @@ def setup(
52
52
error = None ,
53
53
xsrf_cookies = None ,
54
54
cors : Union [None , bool , dict ] = None ,
55
+ ratelimit : Optional [dict ] = None ,
56
+ # If you add any explicit kwargs here, add them to special_keys too.
55
57
** kwargs ,
56
58
):
57
59
'''
@@ -69,6 +71,7 @@ def setup(
69
71
# Note: call setup_session before setup_auth to ensure that
70
72
# override_user is run before authorize
71
73
cls .setup_session (conf .app .get ('session' ))
74
+ cls .setup_ratelimit (ratelimit , conf .app .get ('ratelimit' ))
72
75
cls .setup_auth (auth )
73
76
cls .setup_error (error )
74
77
cls .setup_xsrf (xsrf_cookies )
@@ -80,7 +83,7 @@ def setup(
80
83
if conf .app .settings .get ('debug' , False ):
81
84
cls .log_exception = cls .debug_exception
82
85
83
- # A list of special keys for BaseHandler. Can be extended by other classes.
86
+ # A list of special keys handled by BaseHandler. Can be extended by other classes.
84
87
special_keys = [
85
88
'transform' ,
86
89
'redirect' ,
@@ -92,6 +95,7 @@ def setup(
92
95
'xsrf_cookies' ,
93
96
'cors' ,
94
97
'headers' ,
98
+ 'ratelimit' ,
95
99
]
96
100
97
101
@classmethod
@@ -107,10 +111,28 @@ def clear_special_keys(cls, kwargs, *args):
107
111
return kwargs
108
112
109
113
@classmethod
110
- def get_list (cls , val : Union [list , tuple , str ], key : str = '' , eg : str = '' , caps = True ) -> set :
111
- '''
112
- Convert val="GET, PUT" into {"GET", "PUT"}.
113
- If val is not a string or list/tuple, raise ValueError("url.{key} invalid. e.g. {eg}")
114
+ def get_list (
115
+ cls , val : Union [list , tuple , str ], key : str = '' , eg : str = '' , caps : bool = True
116
+ ) -> set :
117
+ '''Split comma-separated values into a set.
118
+
119
+ Process kwargs that can be a comma-separated string or a list,
120
+ like BaseMixin's `methods:`, `cors.origins`, `cors.methods`, `cors.headers`,
121
+ `ratelimit.keys`, etc.
122
+
123
+ Examples:
124
+ >>> get_list('GET, PUT') == {'GET', 'PUT'}
125
+ >>> get_list(['GET', ' ,get '], caps=True) == {'GET'}
126
+ >>> get_list([' GET , PUT', ' ,POST, ']) == {'GET', 'PUT', 'POST'}
127
+
128
+ Parameters:
129
+ val: Input to split. If val is not str/list/tuple, raise `ValueError`
130
+ key: `url:` key to display in error message
131
+ eg: Example values to display in error message
132
+ caps: True to convert values to uppercase
133
+
134
+ Returns:
135
+ Unique comma-separated values
114
136
'''
115
137
if isinstance (val , (list , tuple )):
116
138
val = ' ' .join (val )
@@ -312,20 +334,24 @@ def _purge_keys(data):
312
334
return keys
313
335
314
336
@classmethod
315
- def setup_session (cls , session_conf ):
316
- '''handler.session returns the session object. It is saved on finish.'''
317
- if session_conf is None :
318
- return
319
- key = store_type , store_path = session_conf .get ('type' ), session_conf .get ('path' )
320
- if key not in session_store_cache :
321
- session_store_cache [key ] = get_store (
337
+ def _get_store (cls , conf ):
338
+ key = store_type , store_path = conf .get ('type' ), conf .get ('path' )
339
+ if key not in _store_cache :
340
+ _store_cache [key ] = get_store (
322
341
type = store_type ,
323
342
path = store_path ,
324
- flush = session_conf .get ('flush' ),
325
- purge = session_conf .get ('purge' ),
343
+ flush = conf .get ('flush' ),
344
+ purge = conf .get ('purge' ),
326
345
purge_keys = cls ._purge_keys ,
327
346
)
328
- cls ._session_store = session_store_cache [key ]
347
+ return _store_cache [key ]
348
+
349
+ @classmethod
350
+ def setup_session (cls , session_conf ):
351
+ '''handler.session returns the session object. It is saved on finish.'''
352
+ if session_conf is None :
353
+ return
354
+ cls ._session_store = cls ._get_store (session_conf )
329
355
cls .session = property (cls .get_session )
330
356
cls ._session_expiry = session_conf .get ('expiry' )
331
357
cls ._session_cookie_id = session_conf .get ('cookie' , 'sid' )
@@ -344,6 +370,112 @@ def setup_session(cls, session_conf):
344
370
# Ensure that session is saved AFTER we set last visited
345
371
cls ._on_finish_methods .append (cls .save_session )
346
372
373
+ @classmethod
374
+ def setup_ratelimit (cls , ratelimit : Union [dict , None ], ratelimit_app_conf : Union [dict , None ]):
375
+ '''Initialize rate limiting checks'''
376
+ if ratelimit is None :
377
+ return
378
+ if ratelimit_app_conf is None :
379
+ raise ValueError (f"url:{ cls .name } .ratelimit: no app.ratelimit defined" )
380
+ if 'keys' not in ratelimit :
381
+ raise ValueError (f'url:{ cls .name } .ratelimit.keys: missing' )
382
+ if 'limit' not in ratelimit :
383
+ raise ValueError (f'url:{ cls .name } .ratelimit.limit: missing' )
384
+
385
+ # All ratelimit related info is stored in self._ratelimit
386
+ cls ._ratelimit = AttrDict (key_fn = [])
387
+
388
+ # Default the pool name to `pattern:`
389
+ cls ._ratelimit .pool = ratelimit .get ('pool' , cls .conf .pattern )
390
+
391
+ # Convert keys: into list
392
+ keys_spec = ratelimit ['keys' ]
393
+ # keys: daily, user => keys: [daily, user]
394
+ if isinstance (keys_spec , str ):
395
+ keys_spec = cls .get_list (keys_spec , key = cls .name , eg = 'daily, user' , caps = False )
396
+ # keys: {function: ...} => keys: [{function: ...}]
397
+ elif isinstance (keys_spec , dict ):
398
+ keys_spec = [keys_spec ]
399
+ # keys: must be a list
400
+ elif not isinstance (keys_spec , (list , tuple )):
401
+ raise ValueError (f'url:{ cls .name } .ratelimit.keys: needs dict list, not { keys_spec } ' )
402
+
403
+ # Pre-compile keys: into self._ratelimit.keys = [key_fn, key_fn, ...]
404
+ # key_fn['function'](self) will return nth key
405
+ # key_fn['expiry'](self) will return nth expiry (in seconds)
406
+ predefined_keys = ratelimit_app_conf .get ('keys' , {})
407
+ for index , key_spec in enumerate (keys_spec ):
408
+ if isinstance (key_spec , str ):
409
+ # Look up string keys like daily to predefined_keys.
410
+ if key_spec in predefined_keys :
411
+ key_spec = predefined_keys [key_spec ]
412
+ # Or construct functions for `user.id`, etc
413
+ else :
414
+ try :
415
+ key_spec = {'function' : handler_expr (key_spec )}
416
+ except ValueError :
417
+ raise ValueError (f'url:{ cls .name } .ratelimit.keys: { key_spec } is unknown' )
418
+ # {function: ...} MUST be defined for a key. {expiry: ... } is optional
419
+ if not isinstance (key_spec , dict ) or 'function' not in key_spec :
420
+ raise ValueError (f'url:{ cls .name } .ratelimit.keys: { key_spec } has no function:' )
421
+ # Compile key/expiry functions into cls._ratelimit.keys[index]['function' / 'expiry']
422
+ key_fn = {}
423
+ for fn in ('function' , 'expiry' ):
424
+ if fn in key_spec :
425
+ key_fn [fn ] = build_transform (
426
+ {'function' : key_spec [fn ]},
427
+ vars = {'handler' : None },
428
+ filename = f'url:{ cls .name } .ratelimit.keys[{ index } ].{ fn } ' ,
429
+ iter = False ,
430
+ )
431
+ cls ._ratelimit .key_fn .append (key_fn )
432
+
433
+ # Ensure limit: is a number or a {function: ...}
434
+ limit_spec = ratelimit ['limit' ]
435
+ if isinstance (limit_spec , (int , float )):
436
+ limit_spec = {'function' : limit_spec }
437
+ elif not isinstance (ratelimit ['limit' ], dict ) or 'function' not in ratelimit ['limit' ]:
438
+ example = "{'function': number}"
439
+ raise ValueError (f'url:{ cls .name } .ratelimit.limit: needs { example } , not { limit_spec } ' )
440
+
441
+ # Pre-compile limit: into self._ratelimit.limit_fn
442
+ cls ._ratelimit .limit_fn = build_transform (
443
+ limit_spec ,
444
+ vars = {'handler' : None },
445
+ filename = f'url:{ cls .name } .ratelimit.limit' ,
446
+ iter = False ,
447
+ )
448
+
449
+ cls ._ratelimit .store = cls ._get_store (ratelimit_app_conf )
450
+ cls ._on_init_methods .append (cls .check_ratelimit )
451
+ cls ._on_finish_methods .append (cls .update_ratelimit )
452
+
453
+ @classmethod
454
+ def reset_ratelimit (cls , pool : str , keys : List [Any ], value : int = 0 ) -> bool :
455
+ '''Reset the rate limit usage for a specific pool.
456
+
457
+ Examples:
458
+
459
+ >>> reset_ratelimit('/api', ['2022-01-01', '[email protected] '])
460
+ >>> reset_ratelimit('/api', ['2022-01-01', '[email protected] '], 10)
461
+
462
+ Parameters:
463
+
464
+ pool: Rate limit pool to use. This is the url's `pattern:` unless you specified a
465
+ `kwargs.ratelimit.pool:`
466
+ keys: specific instance to reset. If your `ratelimit.keys` is `[daily, user.id]`,
467
+ keys might look like `['2022-01-01', '[email protected] ']` to clear for that day/user
468
+ value: sets the usage counter to this number (default: `0`)
469
+ '''
470
+ store = cls ._get_store (conf .app .get ('ratelimit' ))
471
+ key = json .dumps ([pool ] + keys )
472
+ val = store .load (key , None )
473
+ if val is not None and 'n' in val :
474
+ val ['n' ] = value
475
+ store .dump (key , val )
476
+ else :
477
+ return False
478
+
347
479
@classmethod
348
480
def setup_redirect (cls , redirect ):
349
481
'''
@@ -655,6 +787,9 @@ def _write_custom_error(self, status_code, **kwargs):
655
787
return
656
788
except Exception :
657
789
app_log .exception (f'url:{ self .name } .error.{ status_code } raised an exception' )
790
+ # HTTP 429 error code reports ratelimits
791
+ if status_code == TOO_MANY_REQUESTS and hasattr (self , '_ratelimit' ):
792
+ self .set_ratelimit_headers ()
658
793
# If error was not written, use the default error
659
794
self ._write_error (status_code , ** kwargs )
660
795
@@ -889,10 +1024,13 @@ def override_user(self):
889
1024
self .session ['user' ] = row ['user' ]
890
1025
891
1026
def set_last_visited (self ):
892
- '''
893
- This method is called by :py:func:`BaseHandler.prepare` when any user
894
- accesses a page. It updates the last visited time in the ``_l`` session
895
- key. It does this only if the ``_i`` key exists.
1027
+ '''Update session last visited time if we track inactive expiry.
1028
+
1029
+ - `session._l` is the last time the user accessed a page.
1030
+ - `session._i` is the seconds of inactivity after which the session expires.
1031
+ - If `session._i` is set (we track inactive expiry), we set ``session._l` to now.
1032
+
1033
+ Called by [prepare][BaseHandler.prepare] when any user accesses a page.
896
1034
'''
897
1035
# For efficiency reasons, don't call get_session every time. Check
898
1036
# session only if there's a valid sid cookie (with possibly long expiry)
@@ -901,6 +1039,49 @@ def set_last_visited(self):
901
1039
if '_i' in session :
902
1040
session ['_l' ] = time .time ()
903
1041
1042
+ def check_ratelimit (self ):
1043
+ '''Raise HTTP 429 if usage exceeds rate limit. Set X-Ratelimit-* HTTP headers'''
1044
+ ratelimit = self ._ratelimit
1045
+ # Get the rate limit key, limit and expiry
1046
+ ratelimit .key = json .dumps (
1047
+ [ratelimit .pool ] + [key_fn ['function' ](self ) for key_fn in ratelimit .key_fn ]
1048
+ )
1049
+ ratelimit .limit = ratelimit .limit_fn (self )
1050
+ expiries = [key_fn ['expiry' ](self ) for key_fn in ratelimit .key_fn if 'expiry' in key_fn ]
1051
+ # If no expiry is specified, store for 100 years
1052
+ ratelimit .expiry = min (expiries + [3155760000 ])
1053
+
1054
+ # Ensure usage does not hit limit
1055
+ ratelimit .usage = ratelimit .store .load (ratelimit .key , {'n' : 0 }).get ('n' , 0 )
1056
+ if ratelimit .usage >= ratelimit .limit :
1057
+ raise HTTPError (TOO_MANY_REQUESTS , f'{ ratelimit .key } hit rate limit { ratelimit .limit } ' )
1058
+ self .set_ratelimit_headers ()
1059
+
1060
+ def update_ratelimit (self ):
1061
+ '''If request succeeds, increase rate limit usage count by 1'''
1062
+ ratelimit = self ._ratelimit
1063
+ # If check_ratelimit failed (e.g. invalid function) and didn't set a key, skip update
1064
+ # If response is a HTTP error, don't count towards rate limit
1065
+ if 'key' not in ratelimit or self .get_status () >= 400 :
1066
+ return
1067
+ # Increment the rate limit by 1
1068
+ usage_obj = ratelimit .store .load (ratelimit .key , {'n' : 0 })
1069
+ usage_obj ['n' ] += 1
1070
+ usage_obj ['_t' ] = time .time () + ratelimit .expiry
1071
+ ratelimit .store .dump (ratelimit .key , usage_obj )
1072
+
1073
+ def set_ratelimit_headers (self ):
1074
+ ratelimit = self ._ratelimit
1075
+ # ratelimit.usage goes 0, 1, 2, ...
1076
+ # If limit is 3, remaining goes 3, 2, 1, ... -- use (limit - usage - 1)
1077
+ # But when usage hits 3, don't show remaining = -1. Show remaining = 0 using max()
1078
+ remaining = max (ratelimit .limit - ratelimit .usage - 1 , 0 )
1079
+ self .set_header ('X-Ratelimit-Limit' , str (ratelimit .limit ))
1080
+ self .set_header ('X-Ratelimit-Remaining' , str (remaining ))
1081
+ self .set_header ('X-RateLimit-Reset' , str (ratelimit .expiry ))
1082
+ if ratelimit .usage >= ratelimit .limit :
1083
+ self .set_header ('Retry-After' , str (ratelimit .expiry ))
1084
+
904
1085
905
1086
class BaseHandler (RequestHandler , BaseMixin ):
906
1087
'''
0 commit comments