Skip to content

Commit 1fcff5c

Browse files
committed
feat: add a validation middleware
1 parent f7e6bc1 commit 1fcff5c

File tree

9 files changed

+1381
-112
lines changed

9 files changed

+1381
-112
lines changed

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+106-5
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030

3131
from aws_lambda_powertools.event_handler import content_types
3232
from aws_lambda_powertools.event_handler.exceptions import NotFoundError, ServiceError
33+
from aws_lambda_powertools.event_handler.openapi.types import (
34+
COMPONENT_REF_PREFIX,
35+
METHODS_WITH_BODY,
36+
validation_error_definition,
37+
validation_error_response_definition,
38+
)
3339
from aws_lambda_powertools.event_handler.response import Response
3440
from aws_lambda_powertools.shared.functions import powertools_dev_is_set
3541
from aws_lambda_powertools.shared.json_encoder import Encoder
@@ -66,7 +72,9 @@
6672
Tag,
6773
)
6874
from aws_lambda_powertools.event_handler.openapi.params import Dependant
69-
from aws_lambda_powertools.event_handler.openapi.types import TypeModelOrEnum
75+
from aws_lambda_powertools.event_handler.openapi.types import (
76+
TypeModelOrEnum,
77+
)
7078

7179

7280
class ProxyEventType(Enum):
@@ -266,6 +274,9 @@ def __init__(
266274
# _dependant is used to cache the dependant model for the handler function
267275
self._dependant: Optional["Dependant"] = None
268276

277+
# _body_field is used to cache the dependant model for the body field
278+
self._body_field: Optional["ModelField"] = None
279+
269280
def __call__(
270281
self,
271282
router_middlewares: List[Callable],
@@ -365,6 +376,15 @@ def dependant(self) -> "Dependant":
365376

366377
return self._dependant
367378

379+
@property
380+
def body_field(self) -> Optional["ModelField"]:
381+
if self._body_field is None:
382+
from aws_lambda_powertools.event_handler.openapi.params import _get_body_field
383+
384+
self._body_field = _get_body_field(dependant=self.dependant, name=self.operation_id)
385+
386+
return self._body_field
387+
368388
def _get_openapi_path(
369389
self,
370390
*,
@@ -373,9 +393,7 @@ def _get_openapi_path(
373393
model_name_map: Dict["TypeModelOrEnum", str],
374394
field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"],
375395
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
376-
from aws_lambda_powertools.event_handler.openapi.dependant import (
377-
get_flat_params,
378-
)
396+
from aws_lambda_powertools.event_handler.openapi.dependant import get_flat_params
379397

380398
path = {}
381399
definitions: Dict[str, Any] = {}
@@ -396,6 +414,15 @@ def _get_openapi_path(
396414
all_parameters.update(required_parameters)
397415
operation["parameters"] = list(all_parameters.values())
398416

417+
if self.method.upper() in METHODS_WITH_BODY:
418+
request_body_oai = self._openapi_operation_request_body(
419+
body_field=self.body_field,
420+
model_name_map=model_name_map,
421+
field_mapping=field_mapping,
422+
)
423+
if request_body_oai:
424+
operation["requestBody"] = request_body_oai
425+
399426
responses = operation.setdefault("responses", {})
400427
success_response = responses.setdefault("200", {})
401428
success_response["description"] = self.response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
@@ -411,6 +438,24 @@ def _get_openapi_path(
411438
),
412439
)
413440

441+
# Validation responses
442+
operation["responses"]["422"] = {
443+
"description": "Validation Error",
444+
"content": {
445+
"application/json": {
446+
"schema": {"$ref": COMPONENT_REF_PREFIX + "HTTPValidationError"},
447+
},
448+
},
449+
}
450+
451+
if "ValidationError" not in definitions:
452+
definitions.update(
453+
{
454+
"ValidationError": validation_error_definition,
455+
"HTTPValidationError": validation_error_response_definition,
456+
},
457+
)
458+
414459
path[self.method.lower()] = operation
415460

416461
# Generate the response schema
@@ -442,6 +487,38 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]
442487

443488
return operation
444489

490+
@staticmethod
491+
def _openapi_operation_request_body(
492+
*,
493+
body_field: Optional["ModelField"],
494+
model_name_map: Dict["TypeModelOrEnum", str],
495+
field_mapping: Dict[Tuple["ModelField", Literal["validation", "serialization"]], "JsonSchemaValue"],
496+
) -> Optional[Dict[str, Any]]:
497+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField, get_schema_from_model_field
498+
from aws_lambda_powertools.event_handler.openapi.params import Body
499+
500+
if not body_field:
501+
return None
502+
503+
if not isinstance(body_field, ModelField):
504+
raise AssertionError(f"Expected ModelField, got {body_field}")
505+
506+
body_schema = get_schema_from_model_field(
507+
field=body_field,
508+
model_name_map=model_name_map,
509+
field_mapping=field_mapping,
510+
)
511+
512+
field_info = cast(Body, body_field.field_info)
513+
request_media_type = field_info.media_type
514+
required = body_field.required
515+
request_body_oai: Dict[str, Any] = {}
516+
if required:
517+
request_body_oai["required"] = required
518+
request_media_content: Dict[str, Any] = {"schema": body_schema}
519+
request_body_oai["content"] = {request_media_type: request_media_content}
520+
return request_body_oai
521+
445522
@staticmethod
446523
def _openapi_operation_parameters(
447524
*,
@@ -1095,6 +1172,7 @@ def __init__(
10951172
debug: Optional[bool] = None,
10961173
serializer: Optional[Callable[[Dict], str]] = None,
10971174
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
1175+
enable_validation: Optional[bool] = False,
10981176
):
10991177
"""
11001178
Parameters
@@ -1112,6 +1190,8 @@ def __init__(
11121190
optional list of prefixes to be removed from the request path before doing the routing.
11131191
This is often used with api gateways with multiple custom mappings.
11141192
Each prefix can be a static string or a compiled regex pattern
1193+
enable_validation: Optional[bool]
1194+
Enables validation of the request body against the route schema, by default False.
11151195
"""
11161196
self._proxy_type = proxy_type
11171197
self._dynamic_routes: List[Route] = []
@@ -1122,13 +1202,27 @@ def __init__(
11221202
self._cors_enabled: bool = cors is not None
11231203
self._cors_methods: Set[str] = {"OPTIONS"}
11241204
self._debug = self._has_debug(debug)
1205+
self._enable_validation = enable_validation
11251206
self._strip_prefixes = strip_prefixes
11261207
self.context: Dict = {} # early init as customers might add context before event resolution
11271208
self.processed_stack_frames = []
11281209

11291210
# Allow for a custom serializer or a concise json serialization
11301211
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
11311212

1213+
if self._enable_validation:
1214+
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware
1215+
1216+
self.use([OpenAPIValidationMiddleware()])
1217+
1218+
# When using validation, we need to skip the serializer, as the middleware is doing it automatically
1219+
# However, if the user is using a custom serializer, we need to abort
1220+
if serializer:
1221+
raise ValueError("Cannot use a custom serializer when using validation")
1222+
1223+
# Install a dummy serializer
1224+
self._serializer = lambda args: args # type: ignore
1225+
11321226
def get_openapi_schema(
11331227
self,
11341228
*,
@@ -1706,21 +1800,28 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
17061800
Returns a list of fields from the routes
17071801
"""
17081802

1803+
from aws_lambda_powertools.event_handler.openapi.compat import ModelField
17091804
from aws_lambda_powertools.event_handler.openapi.dependant import (
17101805
get_flat_params,
17111806
)
17121807

1808+
body_fields_from_routes: List["ModelField"] = []
17131809
responses_from_routes: List["ModelField"] = []
17141810
request_fields_from_routes: List["ModelField"] = []
17151811

17161812
for route in routes:
1813+
if route.body_field:
1814+
if not isinstance(route.body_field, ModelField):
1815+
raise AssertionError("A request body myst be a Pydantic Field")
1816+
body_fields_from_routes.append(route.body_field)
1817+
17171818
params = get_flat_params(route.dependant)
17181819
request_fields_from_routes.extend(params)
17191820

17201821
if route.dependant.return_param:
17211822
responses_from_routes.append(route.dependant.return_param)
17221823

1723-
flat_models = list(responses_from_routes + request_fields_from_routes)
1824+
flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
17241825
return flat_models
17251826

17261827

0 commit comments

Comments
 (0)