diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index b7204fa41c0..c1cdde63db9 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -37,7 +37,7 @@ class ProxyEventType(Enum): ALBEvent = "ALBEvent" -class CORSConfig(object): +class CORSConfig: """CORS Config Examples @@ -265,6 +265,7 @@ def __init__( cors: Optional[CORSConfig] = None, debug: Optional[bool] = None, serializer: Optional[Callable[[Dict], str]] = None, + strip_prefixes: Optional[List[str]] = None, ): """ Parameters @@ -276,6 +277,11 @@ def __init__( debug: Optional[bool] Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG" environment variable + serializer : Callable, optional + function to serialize `obj` to a JSON formatted `str`, by default json.dumps + strip_prefixes: List[str], optional + optional list of prefixes to be removed from the request path before doing the routing. This is often used + with api gateways with multiple custom mappings. """ self._proxy_type = proxy_type self._routes: List[Route] = [] @@ -285,6 +291,7 @@ def __init__( self._debug = resolve_truthy_env_var_choice( env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug ) + self._strip_prefixes = strip_prefixes # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -521,7 +528,7 @@ def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: def _resolve(self) -> ResponseBuilder: """Resolves the response or return the not found response""" method = self.current_event.http_method.upper() - path = self.current_event.path + path = self._remove_prefix(self.current_event.path) for route in self._routes: if method != route.method: continue @@ -533,6 +540,25 @@ def _resolve(self) -> ResponseBuilder: logger.debug(f"No match found for path {path} and method {method}") return self._not_found(method) + def _remove_prefix(self, path: str) -> str: + """Remove the configured prefix from the path""" + if not isinstance(self._strip_prefixes, list): + return path + + for prefix in self._strip_prefixes: + if self._path_starts_with(path, prefix): + return path[len(prefix) :] + + return path + + @staticmethod + def _path_starts_with(path: str, prefix: str): + """Returns true if the `path` starts with a prefix plus a `/`""" + if not isinstance(prefix, str) or len(prefix) == 0: + return False + + return path.startswith(prefix + "/") + def _not_found(self, method: str) -> ResponseBuilder: """Called when no matching route was found and includes support for the cors preflight response""" headers = {} diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 1272125da8b..3c959747daf 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -769,3 +769,57 @@ def get_color() -> Dict: body = response["body"] expected = '{"color": 1, "variations": ["dark", "light"]}' assert expected == body + + +@pytest.mark.parametrize( + "path", + [ + pytest.param("/pay/foo", id="path matched pay prefix"), + pytest.param("/payment/foo", id="path matched payment prefix"), + pytest.param("/foo", id="path does not start with any of the prefixes"), + ], +) +def test_remove_prefix(path: str): + # GIVEN events paths `/pay/foo`, `/payment/foo` or `/foo` + # AND a configured strip_prefixes of `/pay` and `/payment` + app = ApiGatewayResolver(strip_prefixes=["/pay", "/payment"]) + + @app.get("/pay/foo") + def pay_foo(): + raise ValueError("should not be matching") + + @app.get("/foo") + def foo(): + ... + + # WHEN calling handler + response = app({"httpMethod": "GET", "path": path}, None) + + # THEN a route for `/foo` should be found + assert response["statusCode"] == 200 + + +@pytest.mark.parametrize( + "prefix", + [ + pytest.param("/foo", id="String are not supported"), + pytest.param({"/foo"}, id="Sets are not supported"), + pytest.param({"foo": "/foo"}, id="Dicts are not supported"), + pytest.param(tuple("/foo"), id="Tuples are not supported"), + pytest.param([None, 1, "", False], id="List of invalid values"), + ], +) +def test_ignore_invalid(prefix): + # GIVEN an invalid prefix + app = ApiGatewayResolver(strip_prefixes=prefix) + + @app.get("/foo/status") + def foo(): + ... + + # WHEN calling handler + response = app({"httpMethod": "GET", "path": "/foo/status"}, None) + + # THEN a route for `/foo/status` should be found + # so no prefix was stripped from the request path + assert response["statusCode"] == 200