|
| 1 | +import base64 |
| 2 | +import json |
| 3 | +import re |
| 4 | +import zlib |
| 5 | +from enum import Enum |
| 6 | +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union |
| 7 | + |
| 8 | +from aws_lambda_powertools.shared.json_encoder import Encoder |
| 9 | +from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 |
| 10 | +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent |
| 11 | +from aws_lambda_powertools.utilities.typing import LambdaContext |
| 12 | + |
| 13 | + |
| 14 | +class ProxyEventType(Enum): |
| 15 | + http_api_v1 = "APIGatewayProxyEvent" |
| 16 | + http_api_v2 = "APIGatewayProxyEventV2" |
| 17 | + alb_event = "ALBEvent" |
| 18 | + api_gateway = http_api_v1 |
| 19 | + |
| 20 | + |
| 21 | +class CORSConfig(object): |
| 22 | + """CORS Config""" |
| 23 | + |
| 24 | + _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] |
| 25 | + |
| 26 | + def __init__( |
| 27 | + self, |
| 28 | + allow_origin: str = "*", |
| 29 | + allow_headers: List[str] = None, |
| 30 | + expose_headers: List[str] = None, |
| 31 | + max_age: int = None, |
| 32 | + allow_credentials: bool = False, |
| 33 | + ): |
| 34 | + """ |
| 35 | + Parameters |
| 36 | + ---------- |
| 37 | + allow_origin: str |
| 38 | + The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should |
| 39 | + only be used during development. |
| 40 | + allow_headers: str |
| 41 | + The list of additional allowed headers. This list is added to list of |
| 42 | + built in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`, |
| 43 | + `X-Api-Key`, `X-Amz-Security-Token`. |
| 44 | + expose_headers: str |
| 45 | + A list of values to return for the Access-Control-Expose-Headers |
| 46 | + max_age: int |
| 47 | + The value for the `Access-Control-Max-Age` |
| 48 | + allow_credentials: bool |
| 49 | + A boolean value that sets the value of `Access-Control-Allow-Credentials` |
| 50 | + """ |
| 51 | + self.allow_origin = allow_origin |
| 52 | + self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) |
| 53 | + self.expose_headers = expose_headers or [] |
| 54 | + self.max_age = max_age |
| 55 | + self.allow_credentials = allow_credentials |
| 56 | + |
| 57 | + def to_dict(self) -> Dict[str, str]: |
| 58 | + headers = { |
| 59 | + "Access-Control-Allow-Origin": self.allow_origin, |
| 60 | + "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), |
| 61 | + } |
| 62 | + if self.expose_headers: |
| 63 | + headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers) |
| 64 | + if self.max_age is not None: |
| 65 | + headers["Access-Control-Max-Age"] = str(self.max_age) |
| 66 | + if self.allow_credentials is True: |
| 67 | + headers["Access-Control-Allow-Credentials"] = "true" |
| 68 | + return headers |
| 69 | + |
| 70 | + |
| 71 | +class Route: |
| 72 | + def __init__( |
| 73 | + self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] |
| 74 | + ): |
| 75 | + self.method = method.upper() |
| 76 | + self.rule = rule |
| 77 | + self.func = func |
| 78 | + self.cors = cors |
| 79 | + self.compress = compress |
| 80 | + self.cache_control = cache_control |
| 81 | + |
| 82 | + |
| 83 | +class Response: |
| 84 | + def __init__( |
| 85 | + self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None |
| 86 | + ): |
| 87 | + self.status_code = status_code |
| 88 | + self.body = body |
| 89 | + self.base64_encoded = False |
| 90 | + self.headers: Dict = headers or {} |
| 91 | + if content_type: |
| 92 | + self.headers.setdefault("Content-Type", content_type) |
| 93 | + |
| 94 | + def add_cors(self, cors: CORSConfig): |
| 95 | + self.headers.update(cors.to_dict()) |
| 96 | + |
| 97 | + def add_cache_control(self, cache_control: str): |
| 98 | + self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache" |
| 99 | + |
| 100 | + def compress(self): |
| 101 | + self.headers["Content-Encoding"] = "gzip" |
| 102 | + if isinstance(self.body, str): |
| 103 | + self.body = bytes(self.body, "utf-8") |
| 104 | + gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) |
| 105 | + self.body = gzip.compress(self.body) + gzip.flush() |
| 106 | + |
| 107 | + def to_dict(self) -> Dict[str, Any]: |
| 108 | + if isinstance(self.body, bytes): |
| 109 | + self.base64_encoded = True |
| 110 | + self.body = base64.b64encode(self.body).decode() |
| 111 | + return { |
| 112 | + "statusCode": self.status_code, |
| 113 | + "headers": self.headers, |
| 114 | + "body": self.body, |
| 115 | + "isBase64Encoded": self.base64_encoded, |
| 116 | + } |
| 117 | + |
| 118 | + |
| 119 | +class ApiGatewayResolver: |
| 120 | + current_event: BaseProxyEvent |
| 121 | + lambda_context: LambdaContext |
| 122 | + |
| 123 | + def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): |
| 124 | + self._proxy_type = proxy_type |
| 125 | + self._routes: List[Route] = [] |
| 126 | + self._cors = cors |
| 127 | + self._cors_methods: Set[str] = {"OPTIONS"} |
| 128 | + |
| 129 | + def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 130 | + return self.route(rule, "GET", cors, compress, cache_control) |
| 131 | + |
| 132 | + def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 133 | + return self.route(rule, "POST", cors, compress, cache_control) |
| 134 | + |
| 135 | + def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 136 | + return self.route(rule, "PUT", cors, compress, cache_control) |
| 137 | + |
| 138 | + def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 139 | + return self.route(rule, "DELETE", cors, compress, cache_control) |
| 140 | + |
| 141 | + def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 142 | + return self.route(rule, "PATCH", cors, compress, cache_control) |
| 143 | + |
| 144 | + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): |
| 145 | + def register_resolver(func: Callable): |
| 146 | + self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) |
| 147 | + if cors: |
| 148 | + self._cors_methods.add(method.upper()) |
| 149 | + return func |
| 150 | + |
| 151 | + return register_resolver |
| 152 | + |
| 153 | + def resolve(self, event, context) -> Dict[str, Any]: |
| 154 | + self.current_event = self._to_data_class(event) |
| 155 | + self.lambda_context = context |
| 156 | + route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path) |
| 157 | + if route is None: # No matching route was found |
| 158 | + return response.to_dict() |
| 159 | + |
| 160 | + if route.cors: |
| 161 | + response.add_cors(self._cors or CORSConfig()) |
| 162 | + if route.cache_control: |
| 163 | + response.add_cache_control(route.cache_control) |
| 164 | + if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): |
| 165 | + response.compress() |
| 166 | + |
| 167 | + return response.to_dict() |
| 168 | + |
| 169 | + @staticmethod |
| 170 | + def _compile_regex(rule: str): |
| 171 | + rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) |
| 172 | + return re.compile("^{}$".format(rule_regex)) |
| 173 | + |
| 174 | + def _to_data_class(self, event: Dict) -> BaseProxyEvent: |
| 175 | + if self._proxy_type == ProxyEventType.http_api_v1: |
| 176 | + return APIGatewayProxyEvent(event) |
| 177 | + if self._proxy_type == ProxyEventType.http_api_v2: |
| 178 | + return APIGatewayProxyEventV2(event) |
| 179 | + return ALBEvent(event) |
| 180 | + |
| 181 | + def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: |
| 182 | + for route in self._routes: |
| 183 | + if method != route.method: |
| 184 | + continue |
| 185 | + match: Optional[re.Match] = route.rule.match(path) |
| 186 | + if match: |
| 187 | + return self._call_route(route, match.groupdict()) |
| 188 | + |
| 189 | + headers = {} |
| 190 | + if self._cors: |
| 191 | + headers.update(self._cors.to_dict()) |
| 192 | + if method == "OPTIONS": # Preflight |
| 193 | + headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) |
| 194 | + return None, Response(status_code=204, content_type=None, body=None, headers=headers) |
| 195 | + |
| 196 | + return None, Response( |
| 197 | + status_code=404, |
| 198 | + content_type="application/json", |
| 199 | + body=json.dumps({"message": f"No route found for '{method}.{path}'"}), |
| 200 | + headers=headers, |
| 201 | + ) |
| 202 | + |
| 203 | + def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]: |
| 204 | + return route, self._to_response(route.func(**args)) |
| 205 | + |
| 206 | + @staticmethod |
| 207 | + def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: |
| 208 | + if isinstance(result, Response): |
| 209 | + return result |
| 210 | + elif isinstance(result, dict): |
| 211 | + return Response( |
| 212 | + status_code=200, |
| 213 | + content_type="application/json", |
| 214 | + body=json.dumps(result, separators=(",", ":"), cls=Encoder), |
| 215 | + ) |
| 216 | + else: # Tuple[int, str, Union[bytes, str]] |
| 217 | + return Response(*result) |
| 218 | + |
| 219 | + def __call__(self, event, context) -> Any: |
| 220 | + return self.resolve(event, context) |
0 commit comments