30
30
31
31
from aws_lambda_powertools .event_handler import content_types
32
32
from aws_lambda_powertools .event_handler .exceptions import NotFoundError , ServiceError
33
+ from aws_lambda_powertools .event_handler .openapi .types import (
34
+ COMPONENT_REF_PREFIX ,
35
+ METHODS_WITH_BODY ,
36
+ validation_error_definition ,
37
+ validation_error_response_definition ,
38
+ )
33
39
from aws_lambda_powertools .event_handler .response import Response
34
40
from aws_lambda_powertools .shared .functions import powertools_dev_is_set
35
41
from aws_lambda_powertools .shared .json_encoder import Encoder
66
72
Tag ,
67
73
)
68
74
from aws_lambda_powertools .event_handler .openapi .params import Dependant
69
- from aws_lambda_powertools .event_handler .openapi .types import TypeModelOrEnum
75
+ from aws_lambda_powertools .event_handler .openapi .types import (
76
+ TypeModelOrEnum ,
77
+ )
70
78
71
79
72
80
class ProxyEventType (Enum ):
@@ -266,6 +274,9 @@ def __init__(
266
274
# _dependant is used to cache the dependant model for the handler function
267
275
self ._dependant : Optional ["Dependant" ] = None
268
276
277
+ # _body_field is used to cache the dependant model for the body field
278
+ self ._body_field : Optional ["ModelField" ] = None
279
+
269
280
def __call__ (
270
281
self ,
271
282
router_middlewares : List [Callable ],
@@ -365,6 +376,15 @@ def dependant(self) -> "Dependant":
365
376
366
377
return self ._dependant
367
378
379
+ @property
380
+ def body_field (self ) -> Optional ["ModelField" ]:
381
+ if self ._body_field is None :
382
+ from aws_lambda_powertools .event_handler .openapi .params import _get_body_field
383
+
384
+ self ._body_field = _get_body_field (dependant = self .dependant , name = self .operation_id )
385
+
386
+ return self ._body_field
387
+
368
388
def _get_openapi_path (
369
389
self ,
370
390
* ,
@@ -373,9 +393,7 @@ def _get_openapi_path(
373
393
model_name_map : Dict ["TypeModelOrEnum" , str ],
374
394
field_mapping : Dict [Tuple ["ModelField" , Literal ["validation" , "serialization" ]], "JsonSchemaValue" ],
375
395
) -> Tuple [Dict [str , Any ], Dict [str , Any ]]:
376
- from aws_lambda_powertools .event_handler .openapi .dependant import (
377
- get_flat_params ,
378
- )
396
+ from aws_lambda_powertools .event_handler .openapi .dependant import get_flat_params
379
397
380
398
path = {}
381
399
definitions : Dict [str , Any ] = {}
@@ -396,6 +414,15 @@ def _get_openapi_path(
396
414
all_parameters .update (required_parameters )
397
415
operation ["parameters" ] = list (all_parameters .values ())
398
416
417
+ if self .method .upper () in METHODS_WITH_BODY :
418
+ request_body_oai = self ._openapi_operation_request_body (
419
+ body_field = self .body_field ,
420
+ model_name_map = model_name_map ,
421
+ field_mapping = field_mapping ,
422
+ )
423
+ if request_body_oai :
424
+ operation ["requestBody" ] = request_body_oai
425
+
399
426
responses = operation .setdefault ("responses" , {})
400
427
success_response = responses .setdefault ("200" , {})
401
428
success_response ["description" ] = self .response_description or _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION
@@ -411,6 +438,24 @@ def _get_openapi_path(
411
438
),
412
439
)
413
440
441
+ # Validation responses
442
+ operation ["responses" ]["422" ] = {
443
+ "description" : "Validation Error" ,
444
+ "content" : {
445
+ "application/json" : {
446
+ "schema" : {"$ref" : COMPONENT_REF_PREFIX + "HTTPValidationError" },
447
+ },
448
+ },
449
+ }
450
+
451
+ if "ValidationError" not in definitions :
452
+ definitions .update (
453
+ {
454
+ "ValidationError" : validation_error_definition ,
455
+ "HTTPValidationError" : validation_error_response_definition ,
456
+ },
457
+ )
458
+
414
459
path [self .method .lower ()] = operation
415
460
416
461
# Generate the response schema
@@ -442,6 +487,38 @@ def _openapi_operation_metadata(self, operation_ids: Set[str]) -> Dict[str, Any]
442
487
443
488
return operation
444
489
490
+ @staticmethod
491
+ def _openapi_operation_request_body (
492
+ * ,
493
+ body_field : Optional ["ModelField" ],
494
+ model_name_map : Dict ["TypeModelOrEnum" , str ],
495
+ field_mapping : Dict [Tuple ["ModelField" , Literal ["validation" , "serialization" ]], "JsonSchemaValue" ],
496
+ ) -> Optional [Dict [str , Any ]]:
497
+ from aws_lambda_powertools .event_handler .openapi .compat import ModelField , get_schema_from_model_field
498
+ from aws_lambda_powertools .event_handler .openapi .params import Body
499
+
500
+ if not body_field :
501
+ return None
502
+
503
+ if not isinstance (body_field , ModelField ):
504
+ raise AssertionError (f"Expected ModelField, got { body_field } " )
505
+
506
+ body_schema = get_schema_from_model_field (
507
+ field = body_field ,
508
+ model_name_map = model_name_map ,
509
+ field_mapping = field_mapping ,
510
+ )
511
+
512
+ field_info = cast (Body , body_field .field_info )
513
+ request_media_type = field_info .media_type
514
+ required = body_field .required
515
+ request_body_oai : Dict [str , Any ] = {}
516
+ if required :
517
+ request_body_oai ["required" ] = required
518
+ request_media_content : Dict [str , Any ] = {"schema" : body_schema }
519
+ request_body_oai ["content" ] = {request_media_type : request_media_content }
520
+ return request_body_oai
521
+
445
522
@staticmethod
446
523
def _openapi_operation_parameters (
447
524
* ,
@@ -1095,6 +1172,7 @@ def __init__(
1095
1172
debug : Optional [bool ] = None ,
1096
1173
serializer : Optional [Callable [[Dict ], str ]] = None ,
1097
1174
strip_prefixes : Optional [List [Union [str , Pattern ]]] = None ,
1175
+ enable_validation : Optional [bool ] = False ,
1098
1176
):
1099
1177
"""
1100
1178
Parameters
@@ -1112,6 +1190,8 @@ def __init__(
1112
1190
optional list of prefixes to be removed from the request path before doing the routing.
1113
1191
This is often used with api gateways with multiple custom mappings.
1114
1192
Each prefix can be a static string or a compiled regex pattern
1193
+ enable_validation: Optional[bool]
1194
+ Enables validation of the request body against the route schema, by default False.
1115
1195
"""
1116
1196
self ._proxy_type = proxy_type
1117
1197
self ._dynamic_routes : List [Route ] = []
@@ -1122,13 +1202,27 @@ def __init__(
1122
1202
self ._cors_enabled : bool = cors is not None
1123
1203
self ._cors_methods : Set [str ] = {"OPTIONS" }
1124
1204
self ._debug = self ._has_debug (debug )
1205
+ self ._enable_validation = enable_validation
1125
1206
self ._strip_prefixes = strip_prefixes
1126
1207
self .context : Dict = {} # early init as customers might add context before event resolution
1127
1208
self .processed_stack_frames = []
1128
1209
1129
1210
# Allow for a custom serializer or a concise json serialization
1130
1211
self ._serializer = serializer or partial (json .dumps , separators = ("," , ":" ), cls = Encoder )
1131
1212
1213
+ if self ._enable_validation :
1214
+ from aws_lambda_powertools .event_handler .middlewares .openapi_validation import OpenAPIValidationMiddleware
1215
+
1216
+ self .use ([OpenAPIValidationMiddleware ()])
1217
+
1218
+ # When using validation, we need to skip the serializer, as the middleware is doing it automatically
1219
+ # However, if the user is using a custom serializer, we need to abort
1220
+ if serializer :
1221
+ raise ValueError ("Cannot use a custom serializer when using validation" )
1222
+
1223
+ # Install a dummy serializer
1224
+ self ._serializer = lambda args : args # type: ignore
1225
+
1132
1226
def get_openapi_schema (
1133
1227
self ,
1134
1228
* ,
@@ -1706,21 +1800,28 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
1706
1800
Returns a list of fields from the routes
1707
1801
"""
1708
1802
1803
+ from aws_lambda_powertools .event_handler .openapi .compat import ModelField
1709
1804
from aws_lambda_powertools .event_handler .openapi .dependant import (
1710
1805
get_flat_params ,
1711
1806
)
1712
1807
1808
+ body_fields_from_routes : List ["ModelField" ] = []
1713
1809
responses_from_routes : List ["ModelField" ] = []
1714
1810
request_fields_from_routes : List ["ModelField" ] = []
1715
1811
1716
1812
for route in routes :
1813
+ if route .body_field :
1814
+ if not isinstance (route .body_field , ModelField ):
1815
+ raise AssertionError ("A request body myst be a Pydantic Field" )
1816
+ body_fields_from_routes .append (route .body_field )
1817
+
1717
1818
params = get_flat_params (route .dependant )
1718
1819
request_fields_from_routes .extend (params )
1719
1820
1720
1821
if route .dependant .return_param :
1721
1822
responses_from_routes .append (route .dependant .return_param )
1722
1823
1723
- flat_models = list (responses_from_routes + request_fields_from_routes )
1824
+ flat_models = list (responses_from_routes + request_fields_from_routes + body_fields_from_routes )
1724
1825
return flat_models
1725
1826
1726
1827
0 commit comments