1
+ from __future__ import annotations
2
+
1
3
import dataclasses
2
4
import json
3
5
import logging
4
6
from copy import deepcopy
5
- from typing import Any , Callable , Dict , List , Mapping , MutableMapping , Optional , Sequence , Tuple
7
+ from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence
6
8
7
9
from pydantic import BaseModel
8
10
9
- from aws_lambda_powertools .event_handler import Response
10
- from aws_lambda_powertools .event_handler .api_gateway import Route
11
- from aws_lambda_powertools .event_handler .middlewares import BaseMiddlewareHandler , NextMiddleware
11
+ from aws_lambda_powertools .event_handler .middlewares import BaseMiddlewareHandler
12
12
from aws_lambda_powertools .event_handler .openapi .compat import (
13
- ModelField ,
14
13
_model_dump ,
15
14
_normalize_errors ,
16
15
_regenerate_error_with_loc ,
20
19
from aws_lambda_powertools .event_handler .openapi .encoders import jsonable_encoder
21
20
from aws_lambda_powertools .event_handler .openapi .exceptions import RequestValidationError
22
21
from aws_lambda_powertools .event_handler .openapi .params import Param
23
- from aws_lambda_powertools .event_handler .openapi .types import IncEx
24
- from aws_lambda_powertools .event_handler .types import EventHandlerInstance
22
+
23
+ if TYPE_CHECKING :
24
+ from aws_lambda_powertools .event_handler import Response
25
+ from aws_lambda_powertools .event_handler .api_gateway import Route
26
+ from aws_lambda_powertools .event_handler .middlewares import NextMiddleware
27
+ from aws_lambda_powertools .event_handler .openapi .compat import ModelField
28
+ from aws_lambda_powertools .event_handler .openapi .types import IncEx
29
+ from aws_lambda_powertools .event_handler .types import EventHandlerInstance
25
30
26
31
logger = logging .getLogger (__name__ )
27
32
@@ -36,8 +41,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
36
41
--------
37
42
38
43
```python
39
- from typing import List
40
-
41
44
from pydantic import BaseModel
42
45
43
46
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +53,12 @@ class Todo(BaseModel):
50
53
app = APIGatewayRestResolver(enable_validation=True)
51
54
52
55
@app.get("/todos")
53
- def get_todos(): List [Todo]:
56
+ def get_todos(): list [Todo]:
54
57
return [Todo(name="hello world")]
55
58
```
56
59
"""
57
60
58
- def __init__ (self , validation_serializer : Optional [ Callable [[Any ], str ]] = None ):
61
+ def __init__ (self , validation_serializer : Callable [[Any ], str ] | None = None ):
59
62
"""
60
63
Initialize the OpenAPIValidationMiddleware.
61
64
@@ -72,8 +75,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
72
75
73
76
route : Route = app .context ["_route" ]
74
77
75
- values : Dict [str , Any ] = {}
76
- errors : List [Any ] = []
78
+ values : dict [str , Any ] = {}
79
+ errors : list [Any ] = []
77
80
78
81
# Process path values, which can be found on the route_args
79
82
path_values , path_errors = _request_params_to_args (
@@ -147,10 +150,10 @@ def _handle_response(self, *, route: Route, response: Response):
147
150
def _serialize_response (
148
151
self ,
149
152
* ,
150
- field : Optional [ ModelField ] = None ,
153
+ field : ModelField | None = None ,
151
154
response_content : Any ,
152
- include : Optional [ IncEx ] = None ,
153
- exclude : Optional [ IncEx ] = None ,
155
+ include : IncEx | None = None ,
156
+ exclude : IncEx | None = None ,
154
157
by_alias : bool = True ,
155
158
exclude_unset : bool = False ,
156
159
exclude_defaults : bool = False ,
@@ -160,7 +163,7 @@ def _serialize_response(
160
163
Serialize the response content according to the field type.
161
164
"""
162
165
if field :
163
- errors : List [ Dict [str , Any ]] = []
166
+ errors : list [ dict [str , Any ]] = []
164
167
# MAINTENANCE: remove this when we drop pydantic v1
165
168
if not hasattr (field , "serializable" ):
166
169
response_content = self ._prepare_response_content (
@@ -232,7 +235,7 @@ def _prepare_response_content(
232
235
return dataclasses .asdict (res )
233
236
return res
234
237
235
- def _get_body (self , app : EventHandlerInstance ) -> Dict [str , Any ]:
238
+ def _get_body (self , app : EventHandlerInstance ) -> dict [str , Any ]:
236
239
"""
237
240
Get the request body from the event, and parse it as JSON.
238
241
"""
@@ -261,7 +264,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261
264
def _request_params_to_args (
262
265
required_params : Sequence [ModelField ],
263
266
received_params : Mapping [str , Any ],
264
- ) -> Tuple [ Dict [str , Any ], List [Any ]]:
267
+ ) -> tuple [ dict [str , Any ], list [Any ]]:
265
268
"""
266
269
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267
270
"""
@@ -294,14 +297,14 @@ def _request_params_to_args(
294
297
295
298
296
299
def _request_body_to_args (
297
- required_params : List [ModelField ],
298
- received_body : Optional [ Dict [ str , Any ]] ,
299
- ) -> Tuple [ Dict [str , Any ], List [ Dict [str , Any ]]]:
300
+ required_params : list [ModelField ],
301
+ received_body : dict [ str , Any ] | None ,
302
+ ) -> tuple [ dict [str , Any ], list [ dict [str , Any ]]]:
300
303
"""
301
304
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302
305
"""
303
- values : Dict [str , Any ] = {}
304
- errors : List [ Dict [str , Any ]] = []
306
+ values : dict [str , Any ] = {}
307
+ errors : list [ dict [str , Any ]] = []
305
308
306
309
received_body , field_alias_omitted = _get_embed_body (
307
310
field = required_params [0 ],
@@ -313,11 +316,11 @@ def _request_body_to_args(
313
316
# This sets the location to:
314
317
# { "user": { object } } if field.alias == user
315
318
# { { object } if field_alias is omitted
316
- loc : Tuple [str , ...] = ("body" , field .alias )
319
+ loc : tuple [str , ...] = ("body" , field .alias )
317
320
if field_alias_omitted :
318
321
loc = ("body" ,)
319
322
320
- value : Optional [ Any ] = None
323
+ value : Any | None = None
321
324
322
325
# Now that we know what to look for, try to get the value from the received body
323
326
if received_body is not None :
@@ -347,8 +350,8 @@ def _validate_field(
347
350
* ,
348
351
field : ModelField ,
349
352
value : Any ,
350
- loc : Tuple [str , ...],
351
- existing_errors : List [ Dict [str , Any ]],
353
+ loc : tuple [str , ...],
354
+ existing_errors : list [ dict [str , Any ]],
352
355
):
353
356
"""
354
357
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +370,9 @@ def _validate_field(
367
370
def _get_embed_body (
368
371
* ,
369
372
field : ModelField ,
370
- required_params : List [ModelField ],
371
- received_body : Optional [ Dict [ str , Any ]] ,
372
- ) -> Tuple [ Optional [ Dict [ str , Any ]] , bool ]:
373
+ required_params : list [ModelField ],
374
+ received_body : dict [ str , Any ] | None ,
375
+ ) -> tuple [ dict [ str , Any ] | None , bool ]:
373
376
field_info = field .field_info
374
377
embed = getattr (field_info , "embed" , None )
375
378
@@ -382,15 +385,15 @@ def _get_embed_body(
382
385
383
386
384
387
def _normalize_multi_query_string_with_param (
385
- query_string : Dict [str , List [str ]],
388
+ query_string : dict [str , list [str ]],
386
389
params : Sequence [ModelField ],
387
- ) -> Dict [str , Any ]:
390
+ ) -> dict [str , Any ]:
388
391
"""
389
392
Extract and normalize resolved_query_string_parameters
390
393
391
394
Parameters
392
395
----------
393
- query_string: Dict
396
+ query_string: dict
394
397
A dictionary containing the initial query string parameters.
395
398
params: Sequence[ModelField]
396
399
A sequence of ModelField objects representing parameters.
@@ -399,7 +402,7 @@ def _normalize_multi_query_string_with_param(
399
402
-------
400
403
A dictionary containing the processed multi_query_string_parameters.
401
404
"""
402
- resolved_query_string : Dict [str , Any ] = query_string
405
+ resolved_query_string : dict [str , Any ] = query_string
403
406
for param in filter (is_scalar_field , params ):
404
407
try :
405
408
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +419,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416
419
417
420
Parameters
418
421
----------
419
- headers: Dict
422
+ headers: MutableMapping[str, Any]
420
423
A dictionary containing the initial header parameters.
421
424
params: Sequence[ModelField]
422
425
A sequence of ModelField objects representing parameters.
0 commit comments