|
| 1 | +# Implementation of Token Bucket Algorithm |
| 2 | +# Token `rate` is added to the bucket every `frequency` in seconds. |
| 3 | +# The bucket can hold tokens up to `capacity` (full). |
| 4 | +# The bucket starts full. |
| 5 | +# Each request consume one token. |
| 6 | +# If a token arrives when the bucket is full, token is discarded. |
| 7 | +# If a request arrives when bucket is empty, request will be discarded. |
| 8 | +# If bucket has tokens available, requests will pass. |
| 9 | +# https://en.wikipedia.org/wiki/Token_bucket |
| 10 | +import threading |
| 11 | +import time |
| 12 | + |
| 13 | + |
| 14 | +class TokenBucketRateLimiter: |
| 15 | + def __init__(self, rate: int, capacity: int, frequency: int) -> None: |
| 16 | + """ |
| 17 | + Initialize a Token Bucket rate limiter. |
| 18 | +
|
| 19 | + :param rate: Number of tokens added to the bucket per refill |
| 20 | + :param capacity: Maximum number of tokens the bucket can hold. |
| 21 | + :param frequency: Frequency of refill in seconds |
| 22 | + >>> bucket = TokenBucketRateLimiter(4, 4, 60) |
| 23 | + >>> bucket.tokens |
| 24 | + 4 |
| 25 | + >>> bucket.capacity |
| 26 | + 4 |
| 27 | + >>> bucket.frequency |
| 28 | + 60 |
| 29 | + """ |
| 30 | + self.rate = rate # Tokens added per refill |
| 31 | + self.capacity = capacity # Maximum capacity of the bucket |
| 32 | + self.frequency = frequency # Frequency tokens are refilled |
| 33 | + self.tokens = capacity # Current tokens in the bucket |
| 34 | + self.last_checked = time.time() # Time when tokens were last checked |
| 35 | + self.lock = threading.Lock() # To make the rate limiter thread-safe |
| 36 | + |
| 37 | + def _add_tokens(self) -> None: |
| 38 | + """ |
| 39 | + Refill tokens only when a full minute has passed. |
| 40 | + >>> bucket = TokenBucketRateLimiter(1, 4, 60) |
| 41 | + >>> bucket.tokens # Initially has 4 token (rate) |
| 42 | + 4 |
| 43 | + >>> bucket._add_tokens() |
| 44 | + >>> bucket.tokens # Bucket already full |
| 45 | + 4 |
| 46 | + """ |
| 47 | + current_time = time.time() |
| 48 | + elapsed_time = current_time - self.last_checked |
| 49 | + |
| 50 | + if elapsed_time >= self.frequency: |
| 51 | + minutes_passed = int(elapsed_time // self.frequency) |
| 52 | + |
| 53 | + # Add tokens based on rate |
| 54 | + added_tokens = minutes_passed * self.rate |
| 55 | + self.tokens = min(self.capacity, self.tokens + added_tokens) |
| 56 | + |
| 57 | + # Update the last checked time |
| 58 | + self.last_checked += minutes_passed * self.frequency |
| 59 | + |
| 60 | + def allow_request(self) -> bool: |
| 61 | + """ |
| 62 | + Check if a request is allowed. |
| 63 | + If there are enough tokens, it consumes one token. |
| 64 | + :return: True if the request is allowed, False otherwise. |
| 65 | + >>> bucket = TokenBucketRateLimiter(1, 2, 60) |
| 66 | + >>> bucket.allow_request() # Token is available, request passes |
| 67 | + True |
| 68 | + >>> bucket.allow_request() # Token is available, request passes |
| 69 | + True |
| 70 | + >>> bucket.allow_request() # No token left, request is dropped |
| 71 | + False |
| 72 | + """ |
| 73 | + with self.lock: |
| 74 | + self._add_tokens() |
| 75 | + if self.tokens >= 1: |
| 76 | + self.tokens -= 1 |
| 77 | + return True |
| 78 | + return False |
| 79 | + |
| 80 | + |
| 81 | +if __name__ == "__main__": |
| 82 | + import doctest |
| 83 | + |
| 84 | + doctest.testmod() |
| 85 | + |
| 86 | + print("Allow 4 requests per minute, capacity of 4") |
| 87 | + bucket = TokenBucketRateLimiter(4, 4, 60) |
| 88 | + total_requests = 10 |
| 89 | + delay_in_seconds = 10 |
| 90 | + print("Simulate 1 request per 10 seconds...") |
| 91 | + for i in range(total_requests): |
| 92 | + result = "pass" if bucket.allow_request() else "dropped" |
| 93 | + print( |
| 94 | + f"Request {i+1}/{total_requests} \ |
| 95 | + timeline: {i*delay_in_seconds} seconds = {result}" |
| 96 | + ) |
| 97 | + time.sleep(delay_in_seconds) |
0 commit comments