Skip to content

Commit 1d6b49d

Browse files
author
Michael Brewer
committed
fix: add more input validation
1 parent e492769 commit 1d6b49d

File tree

2 files changed

+40
-4
lines changed

2 files changed

+40
-4
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+14-4
Original file line numberDiff line numberDiff line change
@@ -542,13 +542,23 @@ def _resolve(self) -> ResponseBuilder:
542542

543543
def _remove_prefix(self, path: str) -> str:
544544
"""Remove the configured prefix from the path"""
545-
if self._strip_prefixes:
546-
for prefix in self._strip_prefixes:
547-
if path.startswith(prefix + "/"):
548-
return path[len(prefix) :]
545+
if not isinstance(self._strip_prefixes, list):
546+
return path
547+
548+
for prefix in self._strip_prefixes:
549+
if self._path_starts_with(path, prefix):
550+
return path[len(prefix) :]
549551

550552
return path
551553

554+
@staticmethod
555+
def _path_starts_with(path: str, prefix: str):
556+
"""Returns true if the `path` starts with a prefix plus a `/`"""
557+
if not isinstance(prefix, str) or len(prefix) == 0:
558+
return False
559+
560+
return path.startswith(prefix + "/")
561+
552562
def _not_found(self, method: str) -> ResponseBuilder:
553563
"""Called when no matching route was found and includes support for the cors preflight response"""
554564
headers = {}

Diff for: tests/functional/event_handler/test_api_gateway.py

+26
Original file line numberDiff line numberDiff line change
@@ -797,3 +797,29 @@ def foo():
797797

798798
# THEN a route for `/foo` should be found
799799
assert response["statusCode"] == 200
800+
801+
802+
@pytest.mark.parametrize(
803+
"prefix",
804+
[
805+
pytest.param("/foo", id="String are not supported"),
806+
pytest.param({"/foo"}, id="Sets are not supported"),
807+
pytest.param({"foo": "/foo"}, id="Dicts are not supported"),
808+
pytest.param(tuple("/foo"), id="Tuples are not supported"),
809+
pytest.param([None, 1, "", False], id="List of invalid values"),
810+
],
811+
)
812+
def test_ignore_invalid(prefix):
813+
# GIVEN an invalid prefix
814+
app = ApiGatewayResolver(strip_prefixes=prefix)
815+
816+
@app.get("/foo/status")
817+
def foo():
818+
...
819+
820+
# WHEN calling handler
821+
response = app({"httpMethod": "GET", "path": "/foo/status"}, None)
822+
823+
# THEN a route for `/foo/status` should be found
824+
# so no prefix was stripped from the request path
825+
assert response["statusCode"] == 200

0 commit comments

Comments
 (0)