30
30
31
31
from aws_lambda_powertools .event_handler import content_types
32
32
from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
33
+ from aws_lambda_powertools .event_handler .openapi .types import (
34
+ COMPONENT_REF_PREFIX ,
35
+ METHODS_WITH_BODY ,
36
+ validation_error_definition ,
37
+ validation_error_response_definition ,
38
+ )
33
39
from aws_lambda_powertools .event_handler .response import Response
34
40
from aws_lambda_powertools .shared .functions import powertools_dev_is_set
35
41
from aws_lambda_powertools .shared .json_encoder import Encoder
67
73
Tag ,
68
74
)
69
75
from aws_lambda_powertools .event_handler .openapi .params import Dependant
70
- from aws_lambda_powertools .event_handler .openapi .types import TypeModelOrEnum
76
+ from aws_lambda_powertools .event_handler .openapi .types import (
77
+ TypeModelOrEnum ,
78
+ )
71
79
72
80
73
81
class ProxyEventType (Enum ):
@@ -268,6 +276,9 @@ def __init__(
268
276
# _dependant is used to cache the dependant model for the handler function
269
277
self ._dependant : Optional ["Dependant" ] = None
270
278
279
+ # _body_field is used to cache the dependant model for the body field
280
+ self ._body_field : Optional ["ModelField" ] = None
281
+
271
282
def __call__ (
272
283
self ,
273
284
router_middlewares : List [Callable ],
@@ -367,6 +378,15 @@ def dependant(self) -> "Dependant":
367
378
368
379
return self ._dependant
369
380
381
+ @property
382
+ def body_field (self ) -> Optional ["ModelField" ]:
383
+ if self ._body_field is None :
384
+ from aws_lambda_powertools .event_handler .openapi .params import _get_body_field
385
+
386
+ self ._body_field = _get_body_field (dependant = self .dependant , name = self .operation_id )
387
+
388
+ return self ._body_field
389
+
370
390
def _get_openapi_path (
371
391
self ,
372
392
* ,
@@ -375,9 +395,7 @@ def _get_openapi_path(
375
395
model_name_map : Dict ["TypeModelOrEnum" , str ],
376
396
field_mapping : Dict [Tuple ["ModelField" , Literal ["validation" , "serialization" ]], "JsonSchemaValue" ],
377
397
) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
378
- from aws_lambda_powertools .event_handler .openapi .dependant import (
379
- get_flat_params ,
380
- )
398
+ from aws_lambda_powertools .event_handler .openapi .dependant import get_flat_params
381
399
382
400
path = {}
383
401
definitions : Dict [str , Any ] = {}
@@ -398,6 +416,15 @@ def _get_openapi_path(
398
416
all_parameters .update (required_parameters )
399
417
operation ["parameters" ] = list (all_parameters .values ())
400
418
419
+ if self .method .upper () in METHODS_WITH_BODY :
420
+ request_body_oai = self ._openapi_operation_request_body (
421
+ body_field = self .body_field ,
422
+ model_name_map = model_name_map ,
423
+ field_mapping = field_mapping ,
424
+ )
425
+ if request_body_oai :
426
+ operation ["requestBody" ] = request_body_oai
427
+
401
428
responses = operation .setdefault ("responses" , {})
402
429
success_response = responses .setdefault ("200" , {})
403
430
success_response ["description" ] = self .response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
@@ -413,6 +440,24 @@ def _get_openapi_path(
413
440
),
414
441
)
415
442
443
+ # Validation responses
444
+ operation ["responses" ]["422" ] = {
445
+ "description" : "Validation Error" ,
446
+ "content" : {
447
+ "application/json" : {
448
+ "schema" : {"$ref" : COMPONENT_REF_PREFIX + "HTTPValidationError" },
449
+ },
450
+ },
451
+ }
452
+
453
+ if "ValidationError" not in definitions :
454
+ definitions .update (
455
+ {
456
+ "ValidationError" : validation_error_definition ,
457
+ "HTTPValidationError" : validation_error_response_definition ,
458
+ },
459
+ )
460
+
416
461
path [self .method .lower ()] = operation
417
462
418
463
# Generate the response schema
@@ -444,6 +489,38 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]
444
489
445
490
return operation
446
491
492
+ @staticmethod
493
+ def _openapi_operation_request_body (
494
+ * ,
495
+ body_field : Optional ["ModelField" ],
496
+ model_name_map : Dict ["TypeModelOrEnum" , str ],
497
+ field_mapping : Dict [Tuple ["ModelField" , Literal ["validation" , "serialization" ]], "JsonSchemaValue" ],
498
+ ) -> Optional [Dict [str , Any ]]:
499
+ from aws_lambda_powertools .event_handler .openapi .compat import ModelField , get_schema_from_model_field
500
+ from aws_lambda_powertools .event_handler .openapi .params import Body
501
+
502
+ if not body_field :
503
+ return None
504
+
505
+ if not isinstance (body_field , ModelField ):
506
+ raise AssertionError (f"Expected ModelField, got { body_field } " )
507
+
508
+ body_schema = get_schema_from_model_field (
509
+ field = body_field ,
510
+ model_name_map = model_name_map ,
511
+ field_mapping = field_mapping ,
512
+ )
513
+
514
+ field_info = cast (Body , body_field .field_info )
515
+ request_media_type = field_info .media_type
516
+ required = body_field .required
517
+ request_body_oai : Dict [str , Any ] = {}
518
+ if required :
519
+ request_body_oai ["required" ] = required
520
+ request_media_content : Dict [str , Any ] = {"schema" : body_schema }
521
+ request_body_oai ["content" ] = {request_media_type : request_media_content }
522
+ return request_body_oai
523
+
447
524
@staticmethod
448
525
def _openapi_operation_parameters (
449
526
* ,
@@ -1097,6 +1174,7 @@ def __init__(
1097
1174
debug : Optional [bool ] = None ,
1098
1175
serializer : Optional [Callable [[Dict ], str ]] = None ,
1099
1176
strip_prefixes : Optional [List [Union [str , Pattern ]]] = None ,
1177
+ enable_validation : Optional [bool ] = False ,
1100
1178
):
1101
1179
"""
1102
1180
Parameters
@@ -1114,6 +1192,8 @@ def __init__(
1114
1192
optional list of prefixes to be removed from the request path before doing the routing.
1115
1193
This is often used with api gateways with multiple custom mappings.
1116
1194
Each prefix can be a static string or a compiled regex pattern
1195
+ enable_validation: Optional[bool]
1196
+ Enables validation of the request body against the route schema, by default False.
1117
1197
"""
1118
1198
self ._proxy_type = proxy_type
1119
1199
self ._dynamic_routes : List [Route ] = []
@@ -1124,13 +1204,27 @@ def __init__(
1124
1204
self ._cors_enabled : bool = cors is not None
1125
1205
self ._cors_methods : Set [str ] = {"OPTIONS" }
1126
1206
self ._debug = self ._has_debug (debug )
1207
+ self ._enable_validation = enable_validation
1127
1208
self ._strip_prefixes = strip_prefixes
1128
1209
self .context : Dict = {} # early init as customers might add context before event resolution
1129
1210
self .processed_stack_frames = []
1130
1211
1131
1212
# Allow for a custom serializer or a concise json serialization
1132
1213
self ._serializer = serializer or partial (json .dumps , separators = ("," , ":" ), cls = Encoder )
1133
1214
1215
+ if self ._enable_validation :
1216
+ from aws_lambda_powertools .event_handler .middlewares .openapi_validation import OpenAPIValidationMiddleware
1217
+
1218
+ self .use ([OpenAPIValidationMiddleware ()])
1219
+
1220
+ # When using validation, we need to skip the serializer, as the middleware is doing it automatically
1221
+ # However, if the user is using a custom serializer, we need to abort
1222
+ if serializer :
1223
+ raise ValueError ("Cannot use a custom serializer when using validation" )
1224
+
1225
+ # Install a dummy serializer
1226
+ self ._serializer = lambda args : args # type: ignore
1227
+
1134
1228
def get_openapi_schema (
1135
1229
self ,
1136
1230
* ,
@@ -1711,21 +1805,28 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
1711
1805
Returns a list of fields from the routes
1712
1806
"""
1713
1807
1808
+ from aws_lambda_powertools .event_handler .openapi .compat import ModelField
1714
1809
from aws_lambda_powertools .event_handler .openapi .dependant import (
1715
1810
get_flat_params ,
1716
1811
)
1717
1812
1813
+ body_fields_from_routes : List ["ModelField" ] = []
1718
1814
responses_from_routes : List ["ModelField" ] = []
1719
1815
request_fields_from_routes : List ["ModelField" ] = []
1720
1816
1721
1817
for route in routes :
1818
+ if route .body_field :
1819
+ if not isinstance (route .body_field , ModelField ):
1820
+ raise AssertionError ("A request body myst be a Pydantic Field" )
1821
+ body_fields_from_routes .append (route .body_field )
1822
+
1722
1823
params = get_flat_params (route .dependant )
1723
1824
request_fields_from_routes .extend (params )
1724
1825
1725
1826
if route .dependant .return_param :
1726
1827
responses_from_routes .append (route .dependant .return_param )
1727
1828
1728
- flat_models = list (responses_from_routes + request_fields_from_routes )
1829
+ flat_models = list (responses_from_routes + request_fields_from_routes + body_fields_from_routes )
1729
1830
return flat_models
1730
1831
1731
1832
0 commit comments