1
+ from __future__ import annotations
2
+
1
3
import dataclasses
2
4
import json
3
5
import logging
4
6
from copy import deepcopy
5
- from typing import Any , Callable , Dict , List , Mapping , MutableMapping , Optional , Sequence , Tuple
7
+ from typing import TYPE_CHECKING , Any , Callable , Mapping , MutableMapping , Sequence
6
8
7
9
from pydantic import BaseModel
8
10
9
- from aws_lambda_powertools .event_handler import Response
10
- from aws_lambda_powertools .event_handler .api_gateway import Route
11
11
from aws_lambda_powertools .event_handler .middlewares import BaseMiddlewareHandler , NextMiddleware
12
12
from aws_lambda_powertools .event_handler .openapi .compat import (
13
13
ModelField ,
20
20
from aws_lambda_powertools .event_handler .openapi .encoders import jsonable_encoder
21
21
from aws_lambda_powertools .event_handler .openapi .exceptions import RequestValidationError
22
22
from aws_lambda_powertools .event_handler .openapi .params import Param
23
- from aws_lambda_powertools .event_handler .openapi .types import IncEx
24
- from aws_lambda_powertools .event_handler .types import EventHandlerInstance
23
+
24
+ if TYPE_CHECKING :
25
+ from aws_lambda_powertools .event_handler import Response
26
+ from aws_lambda_powertools .event_handler .api_gateway import Route
27
+ from aws_lambda_powertools .event_handler .openapi .types import IncEx
28
+ from aws_lambda_powertools .event_handler .types import EventHandlerInstance
25
29
26
30
logger = logging .getLogger (__name__ )
27
31
@@ -36,8 +40,6 @@ class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
36
40
--------
37
41
38
42
```python
39
- from typing import List
40
-
41
43
from pydantic import BaseModel
42
44
43
45
from aws_lambda_powertools.event_handler.api_gateway import (
@@ -50,12 +52,12 @@ class Todo(BaseModel):
50
52
app = APIGatewayRestResolver(enable_validation=True)
51
53
52
54
@app.get("/todos")
53
- def get_todos(): List [Todo]:
55
+ def get_todos(): list [Todo]:
54
56
return [Todo(name="hello world")]
55
57
```
56
58
"""
57
59
58
- def __init__ (self , validation_serializer : Optional [ Callable [[Any ], str ]] = None ):
60
+ def __init__ (self , validation_serializer : Callable [[Any ], str ] | None = None ):
59
61
"""
60
62
Initialize the OpenAPIValidationMiddleware.
61
63
@@ -72,8 +74,8 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
72
74
73
75
route : Route = app .context ["_route" ]
74
76
75
- values : Dict [str , Any ] = {}
76
- errors : List [Any ] = []
77
+ values : dict [str , Any ] = {}
78
+ errors : list [Any ] = []
77
79
78
80
# Process path values, which can be found on the route_args
79
81
path_values , path_errors = _request_params_to_args (
@@ -147,10 +149,10 @@ def _handle_response(self, *, route: Route, response: Response):
147
149
def _serialize_response (
148
150
self ,
149
151
* ,
150
- field : Optional [ ModelField ] = None ,
152
+ field : ModelField | None = None ,
151
153
response_content : Any ,
152
- include : Optional [ IncEx ] = None ,
153
- exclude : Optional [ IncEx ] = None ,
154
+ include : IncEx | None = None ,
155
+ exclude : IncEx | None = None ,
154
156
by_alias : bool = True ,
155
157
exclude_unset : bool = False ,
156
158
exclude_defaults : bool = False ,
@@ -160,7 +162,7 @@ def _serialize_response(
160
162
Serialize the response content according to the field type.
161
163
"""
162
164
if field :
163
- errors : List [ Dict [str , Any ]] = []
165
+ errors : list [ dict [str , Any ]] = []
164
166
# MAINTENANCE: remove this when we drop pydantic v1
165
167
if not hasattr (field , "serializable" ):
166
168
response_content = self ._prepare_response_content (
@@ -232,7 +234,7 @@ def _prepare_response_content(
232
234
return dataclasses .asdict (res )
233
235
return res
234
236
235
- def _get_body (self , app : EventHandlerInstance ) -> Dict [str , Any ]:
237
+ def _get_body (self , app : EventHandlerInstance ) -> dict [str , Any ]:
236
238
"""
237
239
Get the request body from the event, and parse it as JSON.
238
240
"""
@@ -261,7 +263,7 @@ def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
261
263
def _request_params_to_args (
262
264
required_params : Sequence [ModelField ],
263
265
received_params : Mapping [str , Any ],
264
- ) -> Tuple [ Dict [str , Any ], List [Any ]]:
266
+ ) -> tuple [ dict [str , Any ], list [Any ]]:
265
267
"""
266
268
Convert the request params to a dictionary of values using validation, and returns a list of errors.
267
269
"""
@@ -294,14 +296,14 @@ def _request_params_to_args(
294
296
295
297
296
298
def _request_body_to_args (
297
- required_params : List [ModelField ],
298
- received_body : Optional [ Dict [ str , Any ]] ,
299
- ) -> Tuple [ Dict [str , Any ], List [ Dict [str , Any ]]]:
299
+ required_params : list [ModelField ],
300
+ received_body : dict [ str , Any ] | None ,
301
+ ) -> tuple [ dict [str , Any ], list [ dict [str , Any ]]]:
300
302
"""
301
303
Convert the request body to a dictionary of values using validation, and returns a list of errors.
302
304
"""
303
- values : Dict [str , Any ] = {}
304
- errors : List [ Dict [str , Any ]] = []
305
+ values : dict [str , Any ] = {}
306
+ errors : list [ dict [str , Any ]] = []
305
307
306
308
received_body , field_alias_omitted = _get_embed_body (
307
309
field = required_params [0 ],
@@ -313,11 +315,11 @@ def _request_body_to_args(
313
315
# This sets the location to:
314
316
# { "user": { object } } if field.alias == user
315
317
# { { object } if field_alias is omitted
316
- loc : Tuple [str , ...] = ("body" , field .alias )
318
+ loc : tuple [str , ...] = ("body" , field .alias )
317
319
if field_alias_omitted :
318
320
loc = ("body" ,)
319
321
320
- value : Optional [ Any ] = None
322
+ value : Any | None = None
321
323
322
324
# Now that we know what to look for, try to get the value from the received body
323
325
if received_body is not None :
@@ -347,8 +349,8 @@ def _validate_field(
347
349
* ,
348
350
field : ModelField ,
349
351
value : Any ,
350
- loc : Tuple [str , ...],
351
- existing_errors : List [ Dict [str , Any ]],
352
+ loc : tuple [str , ...],
353
+ existing_errors : list [ dict [str , Any ]],
352
354
):
353
355
"""
354
356
Validate a field, and append any errors to the existing_errors list.
@@ -367,9 +369,9 @@ def _validate_field(
367
369
def _get_embed_body (
368
370
* ,
369
371
field : ModelField ,
370
- required_params : List [ModelField ],
371
- received_body : Optional [ Dict [ str , Any ]] ,
372
- ) -> Tuple [ Optional [ Dict [ str , Any ]] , bool ]:
372
+ required_params : list [ModelField ],
373
+ received_body : dict [ str , Any ] | None ,
374
+ ) -> tuple [ dict [ str , Any ] | None , bool ]:
373
375
field_info = field .field_info
374
376
embed = getattr (field_info , "embed" , None )
375
377
@@ -382,15 +384,15 @@ def _get_embed_body(
382
384
383
385
384
386
def _normalize_multi_query_string_with_param (
385
- query_string : Dict [str , List [str ]],
387
+ query_string : dict [str , list [str ]],
386
388
params : Sequence [ModelField ],
387
- ) -> Dict [str , Any ]:
389
+ ) -> dict [str , Any ]:
388
390
"""
389
391
Extract and normalize resolved_query_string_parameters
390
392
391
393
Parameters
392
394
----------
393
- query_string: Dict
395
+ query_string: dict
394
396
A dictionary containing the initial query string parameters.
395
397
params: Sequence[ModelField]
396
398
A sequence of ModelField objects representing parameters.
@@ -399,7 +401,7 @@ def _normalize_multi_query_string_with_param(
399
401
-------
400
402
A dictionary containing the processed multi_query_string_parameters.
401
403
"""
402
- resolved_query_string : Dict [str , Any ] = query_string
404
+ resolved_query_string : dict [str , Any ] = query_string
403
405
for param in filter (is_scalar_field , params ):
404
406
try :
405
407
# if the target parameter is a scalar, we keep the first value of the query string
@@ -416,7 +418,7 @@ def _normalize_multi_header_values_with_param(headers: MutableMapping[str, Any],
416
418
417
419
Parameters
418
420
----------
419
- headers: Dict
421
+ headers: MutableMapping[str, Any]
420
422
A dictionary containing the initial header parameters.
421
423
params: Sequence[ModelField]
422
424
A sequence of ModelField objects representing parameters.
0 commit comments