Skip to content

Commit d6f3264

Browse files
committed
feat: add a validation middleware
1 parent 7c7da51 commit d6f3264

File tree

9 files changed

+1381
-112
lines changed

9 files changed

+1381
-112
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+106-5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030

3131
from aws_lambda_powertools.event_handler import content_types
3232
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
33+
from aws_lambda_powertools.event_handler.openapi.types import (
34+
COMPONENT_REF_PREFIX,
35+
METHODS_WITH_BODY,
36+
validation_error_definition,
37+
validation_error_response_definition,
38+
)
3339
from aws_lambda_powertools.event_handler.response import Response
3440
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
3541
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -67,7 +73,9 @@
6773
Tag,
6874
)
6975
from aws_lambda_powertools.event_handler.openapi.params import Dependant
70-
from aws_lambda_powertools.event_handler.openapi.types import TypeModelOrEnum
76+
from aws_lambda_powertools.event_handler.openapi.types import (
77+
TypeModelOrEnum,
78+
)
7179

7280

7381
class ProxyEventType(Enum):
@@ -268,6 +276,9 @@ def __init__(
268276
# _dependant is used to cache the dependant model for the handler function
269277
self._dependant: Optional["Dependant"] = None
270278

279+
# _body_field is used to cache the dependant model for the body field
280+
self._body_field: Optional["ModelField"] = None
281+
271282
def __call__(
272283
self,
273284
router_middlewares: List[Callable],
@@ -367,6 +378,15 @@ def dependant(self) -> "Dependant":
367378

368379
return self._dependant
369380

381+
@property
382+
def body_field(self) -> Optional["ModelField"]:
383+
if self._body_field is None:
384+
from aws_lambda_powertools.event_handler.openapi.params import _get_body_field
385+
386+
self._body_field = _get_body_field(dependant=self.dependant, name=self.operation_id)
387+
388+
return self._body_field
389+
370390
def _get_openapi_path(
371391
self,
372392
*,
@@ -375,9 +395,7 @@ def _get_openapi_path(
375395
model_name_map: Dict["TypeModelOrEnum", str],
376396
field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"],
377397
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
378-
from aws_lambda_powertools.event_handler.openapi.dependant import (
379-
get_flat_params,
380-
)
398+
from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params
381399

382400
path = {}
383401
definitions: Dict[str, Any] = {}
@@ -398,6 +416,15 @@ def _get_openapi_path(
398416
all_parameters.update(required_parameters)
399417
operation["parameters"] = list(all_parameters.values())
400418

419+
if self.method.upper() in METHODS_WITH_BODY:
420+
request_body_oai = self._openapi_operation_request_body(
421+
body_field=self.body_field,
422+
model_name_map=model_name_map,
423+
field_mapping=field_mapping,
424+
)
425+
if request_body_oai:
426+
operation["requestBody"] = request_body_oai
427+
401428
responses = operation.setdefault("responses", {})
402429
success_response = responses.setdefault("200", {})
403430
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
@@ -413,6 +440,24 @@ def _get_openapi_path(
413440
),
414441
)
415442

443+
# Validation responses
444+
operation["responses"]["422"] = {
445+
"description": "Validation Error",
446+
"content": {
447+
"application/json": {
448+
"schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"},
449+
},
450+
},
451+
}
452+
453+
if "ValidationError" not in definitions:
454+
definitions.update(
455+
{
456+
"ValidationError": validation_error_definition,
457+
"HTTPValidationError": validation_error_response_definition,
458+
},
459+
)
460+
416461
path[self.method.lower()] = operation
417462

418463
# Generate the response schema
@@ -444,6 +489,38 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]
444489

445490
return operation
446491

492+
@staticmethod
493+
def _openapi_operation_request_body(
494+
*,
495+
body_field: Optional["ModelField"],
496+
model_name_map: Dict["TypeModelOrEnum", str],
497+
field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"],
498+
) -> Optional[Dict[str, Any]]:
499+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field
500+
from aws_lambda_powertools.event_handler.openapi.params import Body
501+
502+
if not body_field:
503+
return None
504+
505+
if not isinstance(body_field, ModelField):
506+
raise AssertionError(f"Expected ModelField, got {body_field}")
507+
508+
body_schema = get_schema_from_model_field(
509+
field=body_field,
510+
model_name_map=model_name_map,
511+
field_mapping=field_mapping,
512+
)
513+
514+
field_info = cast(Body, body_field.field_info)
515+
request_media_type = field_info.media_type
516+
required = body_field.required
517+
request_body_oai: Dict[str, Any] = {}
518+
if required:
519+
request_body_oai["required"] = required
520+
request_media_content: Dict[str, Any] = {"schema": body_schema}
521+
request_body_oai["content"] = {request_media_type: request_media_content}
522+
return request_body_oai
523+
447524
@staticmethod
448525
def _openapi_operation_parameters(
449526
*,
@@ -1097,6 +1174,7 @@ def __init__(
10971174
debug: Optional[bool] = None,
10981175
serializer: Optional[Callable[[Dict], str]] = None,
10991176
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
1177+
enable_validation: Optional[bool] = False,
11001178
):
11011179
"""
11021180
Parameters
@@ -1114,6 +1192,8 @@ def __init__(
11141192
optional list of prefixes to be removed from the request path before doing the routing.
11151193
This is often used with api gateways with multiple custom mappings.
11161194
Each prefix can be a static string or a compiled regex pattern
1195+
enable_validation: Optional[bool]
1196+
Enables validation of the request body against the route schema, by default False.
11171197
"""
11181198
self._proxy_type = proxy_type
11191199
self._dynamic_routes: List[Route] = []
@@ -1124,13 +1204,27 @@ def __init__(
11241204
self._cors_enabled: bool = cors is not None
11251205
self._cors_methods: Set[str] = {"OPTIONS"}
11261206
self._debug = self._has_debug(debug)
1207+
self._enable_validation = enable_validation
11271208
self._strip_prefixes = strip_prefixes
11281209
self.context: Dict = {} # early init as customers might add context before event resolution
11291210
self.processed_stack_frames = []
11301211

11311212
# Allow for a custom serializer or a concise json serialization
11321213
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
11331214

1215+
if self._enable_validation:
1216+
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware
1217+
1218+
self.use([OpenAPIValidationMiddleware()])
1219+
1220+
# When using validation, we need to skip the serializer, as the middleware is doing it automatically
1221+
# However, if the user is using a custom serializer, we need to abort
1222+
if serializer:
1223+
raise ValueError("Cannot use a custom serializer when using validation")
1224+
1225+
# Install a dummy serializer
1226+
self._serializer = lambda args: args # type: ignore
1227+
11341228
def get_openapi_schema(
11351229
self,
11361230
*,
@@ -1711,21 +1805,28 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
17111805
Returns a list of fields from the routes
17121806
"""
17131807

1808+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField
17141809
from aws_lambda_powertools.event_handler.openapi.dependant import (
17151810
get_flat_params,
17161811
)
17171812

1813+
body_fields_from_routes: List["ModelField"] = []
17181814
responses_from_routes: List["ModelField"] = []
17191815
request_fields_from_routes: List["ModelField"] = []
17201816

17211817
for route in routes:
1818+
if route.body_field:
1819+
if not isinstance(route.body_field, ModelField):
1820+
raise AssertionError("A request body myst be a Pydantic Field")
1821+
body_fields_from_routes.append(route.body_field)
1822+
17221823
params = get_flat_params(route.dependant)
17231824
request_fields_from_routes.extend(params)
17241825

17251826
if route.dependant.return_param:
17261827
responses_from_routes.append(route.dependant.return_param)
17271828

1728-
flat_models = list(responses_from_routes + request_fields_from_routes)
1829+
flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
17291830
return flat_models
17301831

17311832

0 commit comments

Comments
 (0)