Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit dcd0d4d

Browse files
rubenfonsecaleandrodamascenaCavalcante Damascena
authoredOct 24, 2023
feat(event_handler): generate OpenAPI specifications and validate input/output (aws-powertools#3109)
* feat: generate OpenAPI spec from event handler * fix: resolver circular dependencies * fix: rebase * fix: document the new methods * fix: linter * fix: remove unneeded code * fix: reduce duplication * fix: types and sonarcube * chore: refactor complex function * fix: typing extensions * fix: tests * fix: mypy * fix: security baseline * feat: add simultaneous support for Pydantic v2 * fix: disable mypy and ruff on openapi compat * chore: add explanation to imports * chore: add first test * fix: test * fix: test * fix: don't require pydantic to run normal things * chore: added first tests * fix: refactored tests to remove code smell * fix: customize the handler methods * fix: tests * feat: add a validation middleware * fix: uniontype * fix: types * fix: ignore unused-ignore * fix: moved things around * fix: compatibility with pydantic v2 * chore: add tests on the body request * chore: add tests for validation middleware * fix: assorted fixes * fix: make tests pass in both pydantic versions * fix: remove assert * fix: complexity * fix: move Response class back * fix: more fix * fix: more fix * fix: one more fix * fix: refactor OpenAPI validation middleware * fix: refactor dependant.py * fix: beautify encoders * fix: move things around * fix: costmetic changes * fix: add more comments * fix: format * fix: cyclomatic * fix: change method of generating operation id * fix: allow validation in all resolvers * fix: use proper resolver in tests * fix: move from flake8 to ruff * fix: customizing responses * fix: add documentation to a method * fix: more explicit comments * fix: typo * fix: add extra comment * fix: comment * fix: add comments * fix: comments * fix: typo * fix: remove leftover comment * fix: addressing comments * fix: pydantic2 models * fix: typing extension problems * Adding more tests and fixing small things * Adding more tests and fixing small things * Adding more tests and fixing small things * Removing flaky tests * fix: improve coverage of encoders * fix: mark test as pydantic v1 only * fix: make sonarcube happy * fix: improve coverage of params.py * fix: add codecov.yml file to ignore compat.py * Increasing coverage --------- Signed-off-by: Leandro Damascena <[email protected]> Co-authored-by: Leandro Damascena <[email protected]> Co-authored-by: Cavalcante Damascena <[email protected]>
1 parent 14cb407 commit dcd0d4d

22 files changed

+4733
-31
lines changed
 

‎aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 769 additions & 25 deletions
Large diffs are not rendered by default.

‎aws_lambda_powertools/event_handler/lambda_function_url.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,5 +52,13 @@ def __init__(
5252
debug: Optional[bool] = None,
5353
serializer: Optional[Callable[[Dict], str]] = None,
5454
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
55+
enable_validation: bool = False,
5556
):
56-
super().__init__(ProxyEventType.LambdaFunctionUrlEvent, cors, debug, serializer, strip_prefixes)
57+
super().__init__(
58+
ProxyEventType.LambdaFunctionUrlEvent,
59+
cors,
60+
debug,
61+
serializer,
62+
strip_prefixes,
63+
enable_validation,
64+
)
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import dataclasses
2+
import json
3+
import logging
4+
from copy import deepcopy
5+
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple
6+
7+
from pydantic import BaseModel
8+
9+
from aws_lambda_powertools.event_handler import Response
10+
from aws_lambda_powertools.event_handler.api_gateway import Route
11+
from aws_lambda_powertools.event_handler.middlewares import BaseMiddlewareHandler, NextMiddleware
12+
from aws_lambda_powertools.event_handler.openapi.compat import (
13+
ModelField,
14+
_model_dump,
15+
_normalize_errors,
16+
_regenerate_error_with_loc,
17+
get_missing_field_error,
18+
)
19+
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
21+
from aws_lambda_powertools.event_handler.openapi.params import Param
22+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
23+
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
24+
25+
logger = logging.getLogger(__name__)
26+
27+
28+
class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
29+
"""
30+
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
31+
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
32+
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
33+
34+
Examples
35+
--------
36+
37+
```python
38+
from typing import List
39+
40+
from pydantic import BaseModel
41+
42+
from aws_lambda_powertools.event_handler.api_gateway import (
43+
APIGatewayRestResolver,
44+
)
45+
46+
class Todo(BaseModel):
47+
name: str
48+
49+
app = APIGatewayRestResolver(enable_validation=True)
50+
51+
@app.get("/todos")
52+
def get_todos(): List[Todo]:
53+
return [Todo(name="hello world")]
54+
```
55+
"""
56+
57+
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
58+
logger.debug("OpenAPIValidationMiddleware handler")
59+
60+
route: Route = app.context["_route"]
61+
62+
values: Dict[str, Any] = {}
63+
errors: List[Any] = []
64+
65+
try:
66+
# Process path values, which can be found on the route_args
67+
path_values, path_errors = _request_params_to_args(
68+
route.dependant.path_params,
69+
app.context["_route_args"],
70+
)
71+
72+
# Process query values
73+
query_values, query_errors = _request_params_to_args(
74+
route.dependant.query_params,
75+
app.current_event.query_string_parameters or {},
76+
)
77+
78+
values.update(path_values)
79+
values.update(query_values)
80+
errors += path_errors + query_errors
81+
82+
# Process the request body, if it exists
83+
if route.dependant.body_params:
84+
(body_values, body_errors) = _request_body_to_args(
85+
required_params=route.dependant.body_params,
86+
received_body=self._get_body(app),
87+
)
88+
values.update(body_values)
89+
errors.extend(body_errors)
90+
91+
if errors:
92+
# Raise the validation errors
93+
raise RequestValidationError(_normalize_errors(errors))
94+
else:
95+
# Re-write the route_args with the validated values, and call the next middleware
96+
app.context["_route_args"] = values
97+
response = next_middleware(app)
98+
99+
# Process the response body if it exists
100+
raw_response = jsonable_encoder(response.body)
101+
102+
# Validate and serialize the response
103+
return self._serialize_response(field=route.dependant.return_param, response_content=raw_response)
104+
except RequestValidationError as e:
105+
return Response(
106+
status_code=422,
107+
content_type="application/json",
108+
body=json.dumps({"detail": e.errors()}),
109+
)
110+
111+
def _serialize_response(
112+
self,
113+
*,
114+
field: Optional[ModelField] = None,
115+
response_content: Any,
116+
include: Optional[IncEx] = None,
117+
exclude: Optional[IncEx] = None,
118+
by_alias: bool = True,
119+
exclude_unset: bool = False,
120+
exclude_defaults: bool = False,
121+
exclude_none: bool = False,
122+
) -> Any:
123+
"""
124+
Serialize the response content according to the field type.
125+
"""
126+
if field:
127+
errors: List[Dict[str, Any]] = []
128+
# MAINTENANCE: remove this when we drop pydantic v1
129+
if not hasattr(field, "serializable"):
130+
response_content = self._prepare_response_content(
131+
response_content,
132+
exclude_unset=exclude_unset,
133+
exclude_defaults=exclude_defaults,
134+
exclude_none=exclude_none,
135+
)
136+
137+
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
138+
if errors:
139+
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
140+
141+
if hasattr(field, "serialize"):
142+
return field.serialize(
143+
value,
144+
include=include,
145+
exclude=exclude,
146+
by_alias=by_alias,
147+
exclude_unset=exclude_unset,
148+
exclude_defaults=exclude_defaults,
149+
exclude_none=exclude_none,
150+
)
151+
152+
return jsonable_encoder(
153+
value,
154+
include=include,
155+
exclude=exclude,
156+
by_alias=by_alias,
157+
exclude_unset=exclude_unset,
158+
exclude_defaults=exclude_defaults,
159+
exclude_none=exclude_none,
160+
)
161+
else:
162+
# Just serialize the response content returned from the handler
163+
return jsonable_encoder(response_content)
164+
165+
def _prepare_response_content(
166+
self,
167+
res: Any,
168+
*,
169+
exclude_unset: bool,
170+
exclude_defaults: bool = False,
171+
exclude_none: bool = False,
172+
) -> Any:
173+
"""
174+
Prepares the response content for serialization.
175+
"""
176+
if isinstance(res, BaseModel):
177+
return _model_dump(
178+
res,
179+
by_alias=True,
180+
exclude_unset=exclude_unset,
181+
exclude_defaults=exclude_defaults,
182+
exclude_none=exclude_none,
183+
)
184+
elif isinstance(res, list):
185+
return [
186+
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
187+
for item in res
188+
]
189+
elif isinstance(res, dict):
190+
return {
191+
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
192+
for k, v in res.items()
193+
}
194+
elif dataclasses.is_dataclass(res):
195+
return dataclasses.asdict(res)
196+
return res
197+
198+
def _get_body(self, app: EventHandlerInstance) -> Dict[str, Any]:
199+
"""
200+
Get the request body from the event, and parse it as JSON.
201+
"""
202+
203+
content_type_value = app.current_event.get_header_value("content-type")
204+
if not content_type_value or content_type_value.startswith("application/json"):
205+
try:
206+
return app.current_event.json_body
207+
except json.JSONDecodeError as e:
208+
raise RequestValidationError(
209+
[
210+
{
211+
"type": "json_invalid",
212+
"loc": ("body", e.pos),
213+
"msg": "JSON decode error",
214+
"input": {},
215+
"ctx": {"error": e.msg},
216+
},
217+
],
218+
body=e.doc,
219+
) from e
220+
else:
221+
raise NotImplementedError("Only JSON body is supported")
222+
223+
224+
def _request_params_to_args(
225+
required_params: Sequence[ModelField],
226+
received_params: Mapping[str, Any],
227+
) -> Tuple[Dict[str, Any], List[Any]]:
228+
"""
229+
Convert the request params to a dictionary of values using validation, and returns a list of errors.
230+
"""
231+
values = {}
232+
errors = []
233+
234+
for field in required_params:
235+
value = received_params.get(field.alias)
236+
237+
field_info = field.field_info
238+
if not isinstance(field_info, Param):
239+
raise AssertionError(f"Expected Param field_info, got {field_info}")
240+
241+
loc = (field_info.in_.value, field.alias)
242+
243+
# If we don't have a value, see if it's required or has a default
244+
if value is None:
245+
if field.required:
246+
errors.append(get_missing_field_error(loc=loc))
247+
else:
248+
values[field.name] = deepcopy(field.default)
249+
continue
250+
251+
# Finally, validate the value
252+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
253+
254+
return values, errors
255+
256+
257+
def _request_body_to_args(
258+
required_params: List[ModelField],
259+
received_body: Optional[Dict[str, Any]],
260+
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
261+
"""
262+
Convert the request body to a dictionary of values using validation, and returns a list of errors.
263+
"""
264+
values: Dict[str, Any] = {}
265+
errors: List[Dict[str, Any]] = []
266+
267+
received_body, field_alias_omitted = _get_embed_body(
268+
field=required_params[0],
269+
required_params=required_params,
270+
received_body=received_body,
271+
)
272+
273+
for field in required_params:
274+
# This sets the location to:
275+
# { "user": { object } } if field.alias == user
276+
# { { object } if field_alias is omitted
277+
loc: Tuple[str, ...] = ("body", field.alias)
278+
if field_alias_omitted:
279+
loc = ("body",)
280+
281+
value: Optional[Any] = None
282+
283+
# Now that we know what to look for, try to get the value from the received body
284+
if received_body is not None:
285+
try:
286+
value = received_body.get(field.alias)
287+
except AttributeError:
288+
errors.append(get_missing_field_error(loc))
289+
continue
290+
291+
# Determine if the field is required
292+
if value is None:
293+
if field.required:
294+
errors.append(get_missing_field_error(loc))
295+
else:
296+
values[field.name] = deepcopy(field.default)
297+
continue
298+
299+
# MAINTENANCE: Handle byte and file fields
300+
301+
# Finally, validate the value
302+
values[field.name] = _validate_field(field=field, value=value, loc=loc, existing_errors=errors)
303+
304+
return values, errors
305+
306+
307+
def _validate_field(
308+
*,
309+
field: ModelField,
310+
value: Any,
311+
loc: Tuple[str, ...],
312+
existing_errors: List[Dict[str, Any]],
313+
):
314+
"""
315+
Validate a field, and append any errors to the existing_errors list.
316+
"""
317+
validated_value, errors = field.validate(value, value, loc=loc)
318+
319+
if isinstance(errors, list):
320+
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
321+
existing_errors.extend(processed_errors)
322+
elif errors:
323+
existing_errors.append(errors)
324+
325+
return validated_value
326+
327+
328+
def _get_embed_body(
329+
*,
330+
field: ModelField,
331+
required_params: List[ModelField],
332+
received_body: Optional[Dict[str, Any]],
333+
) -> Tuple[Optional[Dict[str, Any]], bool]:
334+
field_info = field.field_info
335+
embed = getattr(field_info, "embed", None)
336+
337+
# If the field is an embed, and the field alias is omitted, we need to wrap the received body in the field alias.
338+
field_alias_omitted = len(required_params) == 1 and not embed
339+
if field_alias_omitted:
340+
received_body = {field.alias: received_body}
341+
342+
return received_body, field_alias_omitted

‎aws_lambda_powertools/event_handler/openapi/__init__.py

Whitespace-only changes.

‎aws_lambda_powertools/event_handler/openapi/compat.py

Lines changed: 497 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 342 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,342 @@
1+
import inspect
2+
import re
3+
from typing import Any, Callable, Dict, ForwardRef, List, Optional, Set, Tuple, Type, cast
4+
5+
from pydantic import BaseModel
6+
7+
from aws_lambda_powertools.event_handler.openapi.compat import (
8+
ModelField,
9+
create_body_model,
10+
evaluate_forwardref,
11+
is_scalar_field,
12+
is_scalar_sequence_field,
13+
)
14+
from aws_lambda_powertools.event_handler.openapi.params import (
15+
Body,
16+
Dependant,
17+
File,
18+
Form,
19+
Header,
20+
Param,
21+
ParamTypes,
22+
Query,
23+
analyze_param,
24+
create_response_field,
25+
get_flat_dependant,
26+
)
27+
28+
"""
29+
This turns the opaque function signature into typed, validated models.
30+
31+
It relies on Pydantic's typing and validation to achieve this in a declarative way.
32+
This enables traits like autocompletion, validation, and declarative structure vs imperative parsing.
33+
34+
This code parses an OpenAPI operation handler function signature into Pydantic models. It uses inspect to get the
35+
signature and regex to parse path parameters. Each parameter is analyzed to extract its type annotation and generate
36+
a corresponding Pydantic field, which are added to a Dependant model. Return values are handled similarly.
37+
38+
This modeling allows for type checking, automatic parameter name/location/type extraction, and input validation -
39+
turning the opaque signature into validated models. It relies on Pydantic's typing and validation for a declarative
40+
approach over imperative parsing, enabling autocompletion, validation and structure.
41+
"""
42+
43+
44+
def add_param_to_fields(
45+
*,
46+
field: ModelField,
47+
dependant: Dependant,
48+
) -> None:
49+
"""
50+
Adds a parameter to the list of parameters in the dependant model.
51+
52+
Parameters
53+
----------
54+
field: ModelField
55+
The field to add
56+
dependant: Dependant
57+
The dependant model to add the field to
58+
59+
"""
60+
field_info = cast(Param, field.field_info)
61+
if field_info.in_ == ParamTypes.path:
62+
dependant.path_params.append(field)
63+
elif field_info.in_ == ParamTypes.query:
64+
dependant.query_params.append(field)
65+
elif field_info.in_ == ParamTypes.header:
66+
dependant.header_params.append(field)
67+
else:
68+
if field_info.in_ != ParamTypes.cookie:
69+
raise AssertionError(f"Unsupported param type: {field_info.in_}")
70+
dependant.cookie_params.append(field)
71+
72+
73+
def get_typed_annotation(annotation: Any, globalns: Dict[str, Any]) -> Any:
74+
"""
75+
Evaluates a type annotation, which can be a string or a ForwardRef.
76+
"""
77+
if isinstance(annotation, str):
78+
annotation = ForwardRef(annotation)
79+
annotation = evaluate_forwardref(annotation, globalns, globalns)
80+
return annotation
81+
82+
83+
def get_typed_signature(call: Callable[..., Any]) -> inspect.Signature:
84+
"""
85+
Returns a typed signature for a callable, resolving forward references.
86+
87+
Parameters
88+
----------
89+
call: Callable[..., Any]
90+
The callable to get the signature for
91+
92+
Returns
93+
-------
94+
inspect.Signature
95+
The typed signature
96+
"""
97+
signature = inspect.signature(call)
98+
99+
# Gets the global namespace for the call. This is used to resolve forward references.
100+
globalns = getattr(call, "__global__", {})
101+
102+
typed_params = [
103+
inspect.Parameter(
104+
name=param.name,
105+
kind=param.kind,
106+
default=param.default,
107+
annotation=get_typed_annotation(param.annotation, globalns),
108+
)
109+
for param in signature.parameters.values()
110+
]
111+
112+
# If the return annotation is not empty, add it to the signature.
113+
if signature.return_annotation is not inspect.Signature.empty:
114+
return_param = inspect.Parameter(
115+
name="Return",
116+
kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
117+
default=None,
118+
annotation=get_typed_annotation(signature.return_annotation, globalns),
119+
)
120+
return inspect.Signature(typed_params, return_annotation=return_param.annotation)
121+
else:
122+
return inspect.Signature(typed_params)
123+
124+
125+
def get_path_param_names(path: str) -> Set[str]:
126+
"""
127+
Returns the path parameter names from a path template. Those are the strings between < and >.
128+
129+
Parameters
130+
----------
131+
path: str
132+
The path template
133+
134+
Returns
135+
-------
136+
Set[str]
137+
The path parameter names
138+
139+
"""
140+
return set(re.findall("<(.*?)>", path))
141+
142+
143+
def get_dependant(
144+
*,
145+
path: str,
146+
call: Callable[..., Any],
147+
name: Optional[str] = None,
148+
) -> Dependant:
149+
"""
150+
Returns a dependant model for a handler function. A dependant model is a model that contains
151+
the parameters and return value of a handler function.
152+
153+
Parameters
154+
----------
155+
path: str
156+
The path template
157+
call: Callable[..., Any]
158+
The handler function
159+
name: str, optional
160+
The name of the handler function
161+
162+
Returns
163+
-------
164+
Dependant
165+
The dependant model for the handler function
166+
"""
167+
path_param_names = get_path_param_names(path)
168+
endpoint_signature = get_typed_signature(call)
169+
signature_params = endpoint_signature.parameters
170+
171+
dependant = Dependant(
172+
call=call,
173+
name=name,
174+
path=path,
175+
)
176+
177+
# Add each parameter to the dependant model
178+
for param_name, param in signature_params.items():
179+
# If the parameter is a path parameter, we need to set the in_ field to "path".
180+
is_path_param = param_name in path_param_names
181+
182+
# Analyze the parameter to get the Pydantic field.
183+
param_field = analyze_param(
184+
param_name=param_name,
185+
annotation=param.annotation,
186+
value=param.default,
187+
is_path_param=is_path_param,
188+
is_response_param=False,
189+
)
190+
if param_field is None:
191+
raise AssertionError(f"Parameter field is None for param: {param_name}")
192+
193+
if is_body_param(param_field=param_field, is_path_param=is_path_param):
194+
dependant.body_params.append(param_field)
195+
else:
196+
add_param_to_fields(field=param_field, dependant=dependant)
197+
198+
# If the return annotation is not empty, add it to the dependant model.
199+
return_annotation = endpoint_signature.return_annotation
200+
if return_annotation is not inspect.Signature.empty:
201+
param_field = analyze_param(
202+
param_name="return",
203+
annotation=return_annotation,
204+
value=None,
205+
is_path_param=False,
206+
is_response_param=True,
207+
)
208+
if param_field is None:
209+
raise AssertionError("Param field is None for return annotation")
210+
211+
dependant.return_param = param_field
212+
213+
return dependant
214+
215+
216+
def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
217+
"""
218+
Returns whether a parameter is a request body parameter, by checking if it is a scalar field or a body field.
219+
220+
Parameters
221+
----------
222+
param_field: ModelField
223+
The parameter field
224+
is_path_param: bool
225+
Whether the parameter is a path parameter
226+
227+
Returns
228+
-------
229+
bool
230+
Whether the parameter is a request body parameter
231+
"""
232+
if is_path_param:
233+
if not is_scalar_field(field=param_field):
234+
raise AssertionError("Path params must be of one of the supported types")
235+
return False
236+
elif is_scalar_field(field=param_field):
237+
return False
238+
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
239+
return False
240+
else:
241+
if not isinstance(param_field.field_info, Body):
242+
raise AssertionError(f"Param: {param_field.name} can only be a request body, use Body()")
243+
return True
244+
245+
246+
def get_flat_params(dependant: Dependant) -> List[ModelField]:
247+
"""
248+
Get a list of all the parameters from a Dependant object.
249+
250+
Parameters
251+
----------
252+
dependant : Dependant
253+
The Dependant object containing the parameters.
254+
255+
Returns
256+
-------
257+
List[ModelField]
258+
A list of ModelField objects containing the flat parameters from the Dependant object.
259+
260+
"""
261+
flat_dependant = get_flat_dependant(dependant)
262+
return (
263+
flat_dependant.path_params
264+
+ flat_dependant.query_params
265+
+ flat_dependant.header_params
266+
+ flat_dependant.cookie_params
267+
)
268+
269+
270+
def get_body_field(*, dependant: Dependant, name: str) -> Optional[ModelField]:
271+
"""
272+
Get the Body field for a given Dependant object.
273+
"""
274+
275+
flat_dependant = get_flat_dependant(dependant)
276+
if not flat_dependant.body_params:
277+
return None
278+
279+
first_param = flat_dependant.body_params[0]
280+
field_info = first_param.field_info
281+
282+
# Handle the case where there is only one body parameter and it is embedded
283+
embed = getattr(field_info, "embed", None)
284+
body_param_names_set = {param.name for param in flat_dependant.body_params}
285+
if len(body_param_names_set) == 1 and not embed:
286+
return first_param
287+
288+
# If one field requires to embed, all have to be embedded
289+
for param in flat_dependant.body_params:
290+
setattr(param.field_info, "embed", True) # noqa: B010
291+
292+
# Generate a custom body model for this endpoint
293+
model_name = "Body_" + name
294+
body_model = create_body_model(fields=flat_dependant.body_params, model_name=model_name)
295+
296+
required = any(True for f in flat_dependant.body_params if f.required)
297+
298+
body_field_info, body_field_info_kwargs = get_body_field_info(
299+
body_model=body_model,
300+
flat_dependant=flat_dependant,
301+
required=required,
302+
)
303+
304+
final_field = create_response_field(
305+
name="body",
306+
type_=body_model,
307+
required=required,
308+
alias="body",
309+
field_info=body_field_info(**body_field_info_kwargs),
310+
)
311+
return final_field
312+
313+
314+
def get_body_field_info(
315+
*,
316+
body_model: Type[BaseModel],
317+
flat_dependant: Dependant,
318+
required: bool,
319+
) -> Tuple[Type[Body], Dict[str, Any]]:
320+
"""
321+
Get the Body field info and kwargs for a given body model.
322+
"""
323+
324+
body_field_info_kwargs: Dict[str, Any] = {"annotation": body_model, "alias": "body"}
325+
326+
if not required:
327+
body_field_info_kwargs["default"] = None
328+
329+
if any(isinstance(f.field_info, File) for f in flat_dependant.body_params):
330+
body_field_info: Type[Body] = File
331+
elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params):
332+
body_field_info = Form
333+
else:
334+
body_field_info = Body
335+
336+
body_param_media_types = [
337+
f.field_info.media_type for f in flat_dependant.body_params if isinstance(f.field_info, Body)
338+
]
339+
if len(set(body_param_media_types)) == 1:
340+
body_field_info_kwargs["media_type"] = body_param_media_types[0]
341+
342+
return body_field_info, body_field_info_kwargs
Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
import dataclasses
2+
import datetime
3+
from collections import defaultdict, deque
4+
from decimal import Decimal
5+
from enum import Enum
6+
from pathlib import Path, PurePath
7+
from re import Pattern
8+
from types import GeneratorType
9+
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
10+
from uuid import UUID
11+
12+
from pydantic import BaseModel
13+
from pydantic.color import Color
14+
from pydantic.types import SecretBytes, SecretStr
15+
16+
from aws_lambda_powertools.event_handler.openapi.compat import _model_dump
17+
from aws_lambda_powertools.event_handler.openapi.types import IncEx
18+
19+
"""
20+
This module contains the encoders used by jsonable_encoder to convert Python objects to JSON serializable data types.
21+
"""
22+
23+
24+
def jsonable_encoder( # noqa: PLR0911
25+
obj: Any,
26+
include: Optional[IncEx] = None,
27+
exclude: Optional[IncEx] = None,
28+
by_alias: bool = True,
29+
exclude_unset: bool = False,
30+
exclude_defaults: bool = False,
31+
exclude_none: bool = False,
32+
) -> Any:
33+
"""
34+
JSON encodes an arbitrary Python object into JSON serializable data types.
35+
36+
This is a modified version of fastapi.encoders.jsonable_encoder that supports
37+
encoding of pydantic.BaseModel objects.
38+
39+
Parameters
40+
----------
41+
obj : Any
42+
The object to encode
43+
include : Optional[IncEx], optional
44+
A set or dictionary of strings that specifies which properties should be included, by default None,
45+
meaning everything is included
46+
exclude : Optional[IncEx], optional
47+
A set or dictionary of strings that specifies which properties should be excluded, by default None,
48+
meaning nothing is excluded
49+
by_alias : bool, optional
50+
Whether field aliases should be respected, by default True
51+
exclude_unset : bool, optional
52+
Whether fields that are not set should be excluded, by default False
53+
exclude_defaults : bool, optional
54+
Whether fields that are equal to their default value (as specified in the model) should be excluded,
55+
by default False
56+
exclude_none : bool, optional
57+
Whether fields that are equal to None should be excluded, by default False
58+
59+
Returns
60+
-------
61+
Any
62+
The JSON serializable data types
63+
"""
64+
if include is not None and not isinstance(include, (set, dict)):
65+
include = set(include)
66+
if exclude is not None and not isinstance(exclude, (set, dict)):
67+
exclude = set(exclude)
68+
69+
# Pydantic models
70+
if isinstance(obj, BaseModel):
71+
return _dump_base_model(
72+
obj=obj,
73+
include=include,
74+
exclude=exclude,
75+
by_alias=by_alias,
76+
exclude_unset=exclude_unset,
77+
exclude_none=exclude_none,
78+
exclude_defaults=exclude_defaults,
79+
)
80+
81+
# Dataclasses
82+
if dataclasses.is_dataclass(obj):
83+
obj_dict = dataclasses.asdict(obj)
84+
return jsonable_encoder(
85+
obj_dict,
86+
include=include,
87+
exclude=exclude,
88+
by_alias=by_alias,
89+
exclude_unset=exclude_unset,
90+
exclude_defaults=exclude_defaults,
91+
exclude_none=exclude_none,
92+
)
93+
94+
# Enums
95+
if isinstance(obj, Enum):
96+
return obj.value
97+
98+
# Paths
99+
if isinstance(obj, PurePath):
100+
return str(obj)
101+
102+
# Scalars
103+
if isinstance(obj, (str, int, float, type(None))):
104+
return obj
105+
106+
# Dictionaries
107+
if isinstance(obj, dict):
108+
return _dump_dict(
109+
obj=obj,
110+
include=include,
111+
exclude=exclude,
112+
by_alias=by_alias,
113+
exclude_none=exclude_none,
114+
exclude_unset=exclude_unset,
115+
)
116+
117+
# Sequences
118+
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple, deque)):
119+
return _dump_sequence(
120+
obj=obj,
121+
include=include,
122+
exclude=exclude,
123+
by_alias=by_alias,
124+
exclude_none=exclude_none,
125+
exclude_defaults=exclude_defaults,
126+
exclude_unset=exclude_unset,
127+
)
128+
129+
# Other types
130+
if type(obj) in ENCODERS_BY_TYPE:
131+
return ENCODERS_BY_TYPE[type(obj)](obj)
132+
133+
for encoder, classes_tuple in encoders_by_class_tuples.items():
134+
if isinstance(obj, classes_tuple):
135+
return encoder(obj)
136+
137+
# Default
138+
return _dump_other(
139+
obj=obj,
140+
include=include,
141+
exclude=exclude,
142+
by_alias=by_alias,
143+
exclude_none=exclude_none,
144+
exclude_unset=exclude_unset,
145+
exclude_defaults=exclude_defaults,
146+
)
147+
148+
149+
def _dump_base_model(
150+
*,
151+
obj: Any,
152+
include: Optional[IncEx] = None,
153+
exclude: Optional[IncEx] = None,
154+
by_alias: bool = True,
155+
exclude_unset: bool = False,
156+
exclude_none: bool = False,
157+
exclude_defaults: bool = False,
158+
):
159+
"""
160+
Dump a BaseModel object to a dict, using the same parameters as jsonable_encoder
161+
"""
162+
obj_dict = _model_dump(
163+
obj,
164+
mode="json",
165+
include=include,
166+
exclude=exclude,
167+
by_alias=by_alias,
168+
exclude_unset=exclude_unset,
169+
exclude_none=exclude_none,
170+
exclude_defaults=exclude_defaults,
171+
)
172+
if "__root__" in obj_dict:
173+
obj_dict = obj_dict["__root__"]
174+
175+
return jsonable_encoder(
176+
obj_dict,
177+
exclude_none=exclude_none,
178+
exclude_defaults=exclude_defaults,
179+
)
180+
181+
182+
def _dump_dict(
183+
*,
184+
obj: Any,
185+
include: Optional[IncEx] = None,
186+
exclude: Optional[IncEx] = None,
187+
by_alias: bool = True,
188+
exclude_unset: bool = False,
189+
exclude_none: bool = False,
190+
) -> Dict[str, Any]:
191+
"""
192+
Dump a dict to a dict, using the same parameters as jsonable_encoder
193+
"""
194+
encoded_dict = {}
195+
allowed_keys = set(obj.keys())
196+
if include is not None:
197+
allowed_keys &= set(include)
198+
if exclude is not None:
199+
allowed_keys -= set(exclude)
200+
for key, value in obj.items():
201+
if (
202+
(not isinstance(key, str) or not key.startswith("_sa"))
203+
and (value is not None or not exclude_none)
204+
and key in allowed_keys
205+
):
206+
encoded_key = jsonable_encoder(
207+
key,
208+
by_alias=by_alias,
209+
exclude_unset=exclude_unset,
210+
exclude_none=exclude_none,
211+
)
212+
encoded_value = jsonable_encoder(
213+
value,
214+
by_alias=by_alias,
215+
exclude_unset=exclude_unset,
216+
exclude_none=exclude_none,
217+
)
218+
encoded_dict[encoded_key] = encoded_value
219+
return encoded_dict
220+
221+
222+
def _dump_sequence(
223+
*,
224+
obj: Any,
225+
include: Optional[IncEx] = None,
226+
exclude: Optional[IncEx] = None,
227+
by_alias: bool = True,
228+
exclude_unset: bool = False,
229+
exclude_none: bool = False,
230+
exclude_defaults: bool = False,
231+
) -> List[Any]:
232+
"""
233+
Dump a sequence to a list, using the same parameters as jsonable_encoder
234+
"""
235+
encoded_list = []
236+
for item in obj:
237+
encoded_list.append(
238+
jsonable_encoder(
239+
item,
240+
include=include,
241+
exclude=exclude,
242+
by_alias=by_alias,
243+
exclude_unset=exclude_unset,
244+
exclude_defaults=exclude_defaults,
245+
exclude_none=exclude_none,
246+
),
247+
)
248+
return encoded_list
249+
250+
251+
def _dump_other(
252+
*,
253+
obj: Any,
254+
include: Optional[IncEx] = None,
255+
exclude: Optional[IncEx] = None,
256+
by_alias: bool = True,
257+
exclude_unset: bool = False,
258+
exclude_none: bool = False,
259+
exclude_defaults: bool = False,
260+
) -> Any:
261+
"""
262+
Dump an object to ah hashable object, using the same parameters as jsonable_encoder
263+
"""
264+
try:
265+
data = dict(obj)
266+
except Exception as e:
267+
errors: List[Exception] = [e]
268+
try:
269+
data = vars(obj)
270+
except Exception as e:
271+
errors.append(e)
272+
raise ValueError(errors) from e
273+
return jsonable_encoder(
274+
data,
275+
include=include,
276+
exclude=exclude,
277+
by_alias=by_alias,
278+
exclude_unset=exclude_unset,
279+
exclude_defaults=exclude_defaults,
280+
exclude_none=exclude_none,
281+
)
282+
283+
284+
def iso_format(o: Union[datetime.date, datetime.time]) -> str:
285+
"""
286+
ISO format for date and time
287+
"""
288+
return o.isoformat()
289+
290+
291+
def decimal_encoder(dec_value: Decimal) -> Union[int, float]:
292+
"""
293+
Encodes a Decimal as int of there's no exponent, otherwise float
294+
295+
This is useful when we use ConstrainedDecimal to represent Numeric(x,0)
296+
where an integer (but not int typed) is used. Encoding this as a float
297+
results in failed round-tripping between encode and parse.
298+
299+
>>> decimal_encoder(Decimal("1.0"))
300+
1.0
301+
302+
>>> decimal_encoder(Decimal("1"))
303+
1
304+
"""
305+
if dec_value.as_tuple().exponent >= 0: # type: ignore[operator]
306+
return int(dec_value)
307+
else:
308+
return float(dec_value)
309+
310+
311+
# Encoders for types that are not JSON serializable
312+
ENCODERS_BY_TYPE: Dict[Type[Any], Callable[[Any], Any]] = {
313+
bytes: lambda o: o.decode(),
314+
Color: str,
315+
datetime.date: iso_format,
316+
datetime.datetime: iso_format,
317+
datetime.time: iso_format,
318+
datetime.timedelta: lambda td: td.total_seconds(),
319+
Decimal: decimal_encoder,
320+
Enum: lambda o: o.value,
321+
frozenset: list,
322+
deque: list,
323+
GeneratorType: list,
324+
Path: str,
325+
Pattern: lambda o: o.pattern,
326+
SecretBytes: str,
327+
SecretStr: str,
328+
set: list,
329+
UUID: str,
330+
}
331+
332+
333+
# Generates a mapping of encoders to a tuple of classes that they can encode
334+
def generate_encoders_by_class_tuples(
335+
type_encoder_map: Dict[Any, Callable[[Any], Any]],
336+
) -> Dict[Callable[[Any], Any], Tuple[Any, ...]]:
337+
encoders: Dict[Callable[[Any], Any], Tuple[Any, ...]] = defaultdict(tuple)
338+
for type_, encoder in type_encoder_map.items():
339+
encoders[encoder] += (type_,)
340+
return encoders
341+
342+
343+
# Mapping of encoders to a tuple of classes that they can encode
344+
encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE)
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Any, Sequence
2+
3+
4+
class ValidationException(Exception):
5+
"""
6+
Base exception for all validation errors
7+
"""
8+
9+
def __init__(self, errors: Sequence[Any]) -> None:
10+
self._errors = errors
11+
12+
def errors(self) -> Sequence[Any]:
13+
return self._errors
14+
15+
16+
class RequestValidationError(ValidationException):
17+
"""
18+
Raised when the request body does not match the OpenAPI schema
19+
"""
20+
21+
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
22+
super().__init__(errors)
23+
self.body = body

‎aws_lambda_powertools/event_handler/openapi/models.py

Lines changed: 583 additions & 0 deletions
Large diffs are not rendered by default.

‎aws_lambda_powertools/event_handler/openapi/params.py

Lines changed: 841 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import types
2+
from enum import Enum
3+
from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Set, Type, Union
4+
5+
if TYPE_CHECKING:
6+
from pydantic import BaseModel # noqa: F401
7+
8+
CacheKey = Optional[Callable[..., Any]]
9+
IncEx = Union[Set[int], Set[str], Dict[int, Any], Dict[str, Any]]
10+
ModelNameMap = Dict[Union[Type["BaseModel"], Type[Enum]], str]
11+
TypeModelOrEnum = Union[Type["BaseModel"], Type[Enum]]
12+
UnionType = getattr(types, "UnionType", Union)
13+
14+
15+
COMPONENT_REF_PREFIX = "#/components/schemas/"
16+
COMPONENT_REF_TEMPLATE = "#/components/schemas/{model}"
17+
METHODS_WITH_BODY = {"GET", "HEAD", "POST", "PUT", "DELETE", "PATCH"}
18+
19+
try:
20+
from pydantic.version import VERSION as PYDANTIC_VERSION
21+
22+
PYDANTIC_V2 = PYDANTIC_VERSION.startswith("2.")
23+
except ImportError:
24+
PYDANTIC_V2 = False
25+
26+
27+
validation_error_definition = {
28+
"title": "ValidationError",
29+
"type": "object",
30+
"properties": {
31+
"loc": {
32+
"title": "Location",
33+
"type": "array",
34+
"items": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
35+
},
36+
"msg": {"title": "Message", "type": "string"},
37+
"type": {"title": "Error Type", "type": "string"},
38+
},
39+
"required": ["loc", "msg", "type"],
40+
}
41+
42+
validation_error_response_definition = {
43+
"title": "HTTPValidationError",
44+
"type": "object",
45+
"properties": {
46+
"detail": {
47+
"title": "Detail",
48+
"type": "array",
49+
"items": {"$ref": COMPONENT_REF_PREFIX + "ValidationError"},
50+
},
51+
},
52+
}

‎aws_lambda_powertools/event_handler/vpc_lattice.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,10 @@ def __init__(
4848
debug: Optional[bool] = None,
4949
serializer: Optional[Callable[[Dict], str]] = None,
5050
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
51+
enable_validation: bool = False,
5152
):
5253
"""Amazon VPC Lattice resolver"""
53-
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes)
54+
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)
5455

5556

5657
class VPCLatticeV2Resolver(ApiGatewayResolver):
@@ -93,6 +94,7 @@ def __init__(
9394
debug: Optional[bool] = None,
9495
serializer: Optional[Callable[[Dict], str]] = None,
9596
strip_prefixes: Optional[List[Union[str, Pattern]]] = None,
97+
enable_validation: bool = False,
9698
):
9799
"""Amazon VPC Lattice resolver"""
98-
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes)
100+
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)

‎aws_lambda_powertools/shared/types.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,26 @@
66
else:
77
from typing_extensions import Literal, Protocol, TypedDict
88

9+
if sys.version_info >= (3, 9):
10+
from typing import Annotated
11+
else:
12+
from typing_extensions import Annotated
913

1014
if sys.version_info >= (3, 11):
1115
from typing import NotRequired
1216
else:
1317
from typing_extensions import NotRequired
1418

1519

20+
# Even though `get_args` and `get_origin` were added in Python 3.8, they only handle Annotated correctly on 3.10.
21+
# So for python < 3.10 we use the backport from typing_extensions.
1622
if sys.version_info >= (3, 10):
17-
from typing import TypeAlias
23+
from typing import TypeAlias, get_args, get_origin
1824
else:
19-
from typing_extensions import TypeAlias
25+
from typing_extensions import TypeAlias, get_args, get_origin
2026

2127
AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001
2228
# JSON primitives only, mypy doesn't support recursive tho
2329
JSONType = Union[str, int, float, bool, None, Dict[str, Any], List[Any]]
2430

25-
__all__ = ["Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"]
31+
__all__ = ["get_args", "get_origin", "Annotated", "Protocol", "TypedDict", "Literal", "NotRequired", "TypeAlias"]

‎codecov.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
ignore:
2+
- "aws_lambda_powertools/event_handler/openapi/compat.py"

‎pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ exclude = '''
169169
| buck-out
170170
| build
171171
| dist
172+
| aws_lambda_powertools/event_handler/openapi/compat.py
172173
)/
173174
| example
174175
)

‎ruff.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,5 +87,6 @@ split-on-trailing-comma = true
8787
"tests/e2e/utils/data_fetcher/__init__.py" = ["F401"]
8888
"aws_lambda_powertools/utilities/data_classes/s3_event.py" = ["A003"]
8989
"aws_lambda_powertools/utilities/parser/models/__init__.py" = ["E402"]
90+
"aws_lambda_powertools/event_handler/openapi/compat.py" = ["F401"]
9091
# Maintenance: we're keeping EphemeralMetrics code in case of Hyrum's law so we can quickly revert it
9192
"aws_lambda_powertools/metrics/metrics.py" = ["ERA001"]
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import math
2+
from dataclasses import dataclass
3+
from typing import List
4+
5+
import pytest
6+
from pydantic import BaseModel
7+
from pydantic.color import Color
8+
9+
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
10+
11+
12+
@pytest.fixture
13+
def pydanticv1_only():
14+
from pydantic import __version__
15+
16+
version = __version__.split(".")
17+
if version[0] != "1":
18+
pytest.skip("pydanticv1 test only")
19+
20+
21+
def test_openapi_encode_include():
22+
class User(BaseModel):
23+
name: str
24+
age: int
25+
26+
result = jsonable_encoder(User(name="John", age=20), include=["name"])
27+
assert result == {"name": "John"}
28+
29+
30+
def test_openapi_encode_exclude():
31+
class User(BaseModel):
32+
name: str
33+
age: int
34+
35+
result = jsonable_encoder(User(name="John", age=20), exclude=["age"])
36+
assert result == {"name": "John"}
37+
38+
39+
def test_openapi_encode_pydantic():
40+
class Order(BaseModel):
41+
quantity: int
42+
43+
class User(BaseModel):
44+
name: str
45+
order: Order
46+
47+
result = jsonable_encoder(User(name="John", order=Order(quantity=2)))
48+
assert result == {"name": "John", "order": {"quantity": 2}}
49+
50+
51+
@pytest.mark.usefixtures("pydanticv1_only")
52+
def test_openapi_encode_pydantic_root_types():
53+
class User(BaseModel):
54+
__root__: List[str]
55+
56+
result = jsonable_encoder(User(__root__=["John", "Jane"]))
57+
assert result == ["John", "Jane"]
58+
59+
60+
def test_openapi_encode_dataclass():
61+
@dataclass
62+
class Order:
63+
quantity: int
64+
65+
@dataclass
66+
class User:
67+
name: str
68+
order: Order
69+
70+
result = jsonable_encoder(User(name="John", order=Order(quantity=2)))
71+
assert result == {"name": "John", "order": {"quantity": 2}}
72+
73+
74+
def test_openapi_encode_enum():
75+
from enum import Enum
76+
77+
class Color(Enum):
78+
RED = "red"
79+
GREEN = "green"
80+
BLUE = "blue"
81+
82+
result = jsonable_encoder(Color.RED)
83+
assert result == "red"
84+
85+
86+
def test_openapi_encode_purepath():
87+
from pathlib import PurePath
88+
89+
result = jsonable_encoder(PurePath("/foo/bar"))
90+
assert result == "/foo/bar"
91+
92+
93+
def test_openapi_encode_scalars():
94+
result = jsonable_encoder("foo")
95+
assert result == "foo"
96+
97+
result = jsonable_encoder(1)
98+
assert result == 1
99+
100+
result = jsonable_encoder(1.0)
101+
assert math.isclose(result, 1.0)
102+
103+
result = jsonable_encoder(True)
104+
assert result is True
105+
106+
result = jsonable_encoder(None)
107+
assert result is None
108+
109+
110+
def test_openapi_encode_dict():
111+
result = jsonable_encoder({"foo": "bar"})
112+
assert result == {"foo": "bar"}
113+
114+
115+
def test_openapi_encode_dict_with_include():
116+
result = jsonable_encoder({"foo": "bar", "bar": "foo"}, include=["foo"])
117+
assert result == {"foo": "bar"}
118+
119+
120+
def test_openapi_encode_dict_with_exclude():
121+
result = jsonable_encoder({"foo": "bar", "bar": "foo"}, exclude=["bar"])
122+
assert result == {"foo": "bar"}
123+
124+
125+
def test_openapi_encode_sequences():
126+
result = jsonable_encoder(["foo", "bar"])
127+
assert result == ["foo", "bar"]
128+
129+
result = jsonable_encoder(("foo", "bar"))
130+
assert result == ["foo", "bar"]
131+
132+
result = jsonable_encoder({"foo", "bar"})
133+
assert set(result) == {"foo", "bar"}
134+
135+
result = jsonable_encoder(frozenset(("foo", "bar")))
136+
assert set(result) == {"foo", "bar"}
137+
138+
139+
def test_openapi_encode_bytes():
140+
result = jsonable_encoder(b"foo")
141+
assert result == "foo"
142+
143+
144+
def test_openapi_encode_timedelta():
145+
from datetime import timedelta
146+
147+
result = jsonable_encoder(timedelta(seconds=1))
148+
assert result == 1
149+
150+
151+
def test_openapi_encode_decimal():
152+
from decimal import Decimal
153+
154+
result = jsonable_encoder(Decimal("1.0"))
155+
assert math.isclose(result, 1.0)
156+
157+
result = jsonable_encoder(Decimal("1"))
158+
assert result == 1
159+
160+
161+
def test_openapi_encode_uuid():
162+
from uuid import UUID
163+
164+
result = jsonable_encoder(UUID("123e4567-e89b-12d3-a456-426614174000"))
165+
assert result == "123e4567-e89b-12d3-a456-426614174000"
166+
167+
168+
def test_openapi_encode_encodable():
169+
from datetime import date, datetime, time
170+
171+
result = jsonable_encoder(date(2021, 1, 1))
172+
assert result == "2021-01-01"
173+
174+
result = jsonable_encoder(datetime(2021, 1, 1, 0, 0, 0))
175+
assert result == "2021-01-01T00:00:00"
176+
177+
result = jsonable_encoder(time(0, 0, 0))
178+
assert result == "00:00:00"
179+
180+
181+
def test_openapi_encode_subclasses():
182+
class MyColor(Color):
183+
pass
184+
185+
result = jsonable_encoder(MyColor("red"))
186+
assert result == "red"
187+
188+
189+
def test_openapi_encode_other():
190+
class User:
191+
def __init__(self, name: str):
192+
self.name = name
193+
194+
result = jsonable_encoder(User(name="John"))
195+
assert result == {"name": "John"}
Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
from dataclasses import dataclass
2+
from datetime import datetime
3+
from typing import List
4+
5+
from pydantic import BaseModel
6+
7+
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
8+
from aws_lambda_powertools.event_handler.openapi.models import (
9+
Example,
10+
Parameter,
11+
ParameterInType,
12+
Schema,
13+
)
14+
from aws_lambda_powertools.event_handler.openapi.params import (
15+
Body,
16+
Header,
17+
Param,
18+
ParamTypes,
19+
Query,
20+
_create_model_field,
21+
)
22+
from aws_lambda_powertools.shared.types import Annotated
23+
24+
JSON_CONTENT_TYPE = "application/json"
25+
26+
27+
def test_openapi_no_params():
28+
app = APIGatewayRestResolver()
29+
30+
@app.get("/")
31+
def handler():
32+
raise NotImplementedError()
33+
34+
schema = app.get_openapi_schema()
35+
assert schema.info.title == "Powertools API"
36+
assert schema.info.version == "1.0.0"
37+
38+
assert len(schema.paths.keys()) == 1
39+
assert "/" in schema.paths
40+
41+
path = schema.paths["/"]
42+
assert path.get
43+
44+
get = path.get
45+
assert get.summary == "GET /"
46+
assert get.operationId == "handler__get"
47+
48+
assert get.responses is not None
49+
assert 200 in get.responses.keys()
50+
response = get.responses[200]
51+
assert response.description == "Successful Response"
52+
53+
assert JSON_CONTENT_TYPE in response.content
54+
json_response = response.content[JSON_CONTENT_TYPE]
55+
assert json_response.schema_ == Schema()
56+
assert not json_response.examples
57+
assert not json_response.encoding
58+
59+
60+
def test_openapi_with_scalar_params():
61+
app = APIGatewayRestResolver()
62+
63+
@app.get("/users/<user_id>")
64+
def handler(user_id: str, include_extra: bool = False):
65+
raise NotImplementedError()
66+
67+
schema = app.get_openapi_schema(title="My API", version="0.2.2")
68+
assert schema.info.title == "My API"
69+
assert schema.info.version == "0.2.2"
70+
71+
assert len(schema.paths.keys()) == 1
72+
assert "/users/<user_id>" in schema.paths
73+
74+
path = schema.paths["/users/<user_id>"]
75+
assert path.get
76+
77+
get = path.get
78+
assert get.summary == "GET /users/<user_id>"
79+
assert get.operationId == "handler_users__user_id__get"
80+
assert len(get.parameters) == 2
81+
82+
parameter = get.parameters[0]
83+
assert isinstance(parameter, Parameter)
84+
assert parameter.in_ == ParameterInType.path
85+
assert parameter.name == "user_id"
86+
assert parameter.required is True
87+
assert parameter.schema_.default is None
88+
assert parameter.schema_.type == "string"
89+
assert parameter.schema_.title == "User Id"
90+
91+
parameter = get.parameters[1]
92+
assert isinstance(parameter, Parameter)
93+
assert parameter.in_ == ParameterInType.query
94+
assert parameter.name == "include_extra"
95+
assert parameter.required is False
96+
assert parameter.schema_.default is False
97+
assert parameter.schema_.type == "boolean"
98+
assert parameter.schema_.title == "Include Extra"
99+
100+
101+
def test_openapi_with_custom_params():
102+
app = APIGatewayRestResolver()
103+
104+
@app.get("/users", summary="Get Users", operation_id="GetUsers", description="Get paginated users", tags=["Users"])
105+
def handler(
106+
count: Annotated[
107+
int,
108+
Query(gt=0, lt=100, examples=[Example(summary="Example 1", value=10)]),
109+
] = 1,
110+
):
111+
print(count)
112+
raise NotImplementedError()
113+
114+
schema = app.get_openapi_schema()
115+
116+
get = schema.paths["/users"].get
117+
assert len(get.parameters) == 1
118+
assert get.summary == "Get Users"
119+
assert get.operationId == "GetUsers"
120+
assert get.description == "Get paginated users"
121+
assert get.tags == ["Users"]
122+
123+
parameter = get.parameters[0]
124+
assert parameter.required is False
125+
assert parameter.name == "count"
126+
assert parameter.in_ == ParameterInType.query
127+
assert parameter.schema_.type == "integer"
128+
assert parameter.schema_.default == 1
129+
assert parameter.schema_.title == "Count"
130+
assert parameter.schema_.exclusiveMinimum == 0
131+
assert parameter.schema_.exclusiveMaximum == 100
132+
assert len(parameter.schema_.examples) == 1
133+
assert parameter.schema_.examples[0].summary == "Example 1"
134+
assert parameter.schema_.examples[0].value == 10
135+
136+
137+
def test_openapi_with_scalar_returns():
138+
app = APIGatewayRestResolver()
139+
140+
@app.get("/")
141+
def handler() -> str:
142+
return "Hello, world"
143+
144+
schema = app.get_openapi_schema()
145+
assert len(schema.paths.keys()) == 1
146+
147+
get = schema.paths["/"].get
148+
assert get.parameters is None
149+
150+
response = get.responses[200].content[JSON_CONTENT_TYPE]
151+
assert response.schema_.title == "Return"
152+
assert response.schema_.type == "string"
153+
154+
155+
def test_openapi_with_pydantic_returns():
156+
app = APIGatewayRestResolver()
157+
158+
class User(BaseModel):
159+
name: str
160+
161+
@app.get("/")
162+
def handler() -> User:
163+
return User(name="Ruben Fonseca")
164+
165+
schema = app.get_openapi_schema()
166+
assert len(schema.paths.keys()) == 1
167+
168+
get = schema.paths["/"].get
169+
assert get.parameters is None
170+
171+
response = get.responses[200].content[JSON_CONTENT_TYPE]
172+
reference = response.schema_
173+
assert reference.ref == "#/components/schemas/User"
174+
175+
assert "User" in schema.components.schemas
176+
user_schema = schema.components.schemas["User"]
177+
assert isinstance(user_schema, Schema)
178+
assert user_schema.title == "User"
179+
assert "name" in user_schema.properties
180+
181+
182+
def test_openapi_with_pydantic_nested_returns():
183+
app = APIGatewayRestResolver()
184+
185+
class Order(BaseModel):
186+
date: datetime
187+
188+
class User(BaseModel):
189+
name: str
190+
orders: List[Order]
191+
192+
@app.get("/")
193+
def handler() -> User:
194+
return User(name="Ruben Fonseca", orders=[Order(date=datetime.now())])
195+
196+
schema = app.get_openapi_schema()
197+
assert len(schema.paths.keys()) == 1
198+
199+
assert "User" in schema.components.schemas
200+
assert "Order" in schema.components.schemas
201+
202+
user_schema = schema.components.schemas["User"]
203+
assert "orders" in user_schema.properties
204+
assert user_schema.properties["orders"].type == "array"
205+
206+
207+
def test_openapi_with_dataclass_return():
208+
app = APIGatewayRestResolver()
209+
210+
@dataclass
211+
class User:
212+
surname: str
213+
214+
@app.get("/")
215+
def handler() -> User:
216+
return User(surname="Fonseca")
217+
218+
schema = app.get_openapi_schema()
219+
assert len(schema.paths.keys()) == 1
220+
221+
get = schema.paths["/"].get
222+
assert get.parameters is None
223+
224+
response = get.responses[200].content[JSON_CONTENT_TYPE]
225+
reference = response.schema_
226+
assert reference.ref == "#/components/schemas/User"
227+
228+
assert "User" in schema.components.schemas
229+
user_schema = schema.components.schemas["User"]
230+
assert isinstance(user_schema, Schema)
231+
assert user_schema.title == "User"
232+
assert "surname" in user_schema.properties
233+
234+
235+
def test_openapi_with_body_param():
236+
app = APIGatewayRestResolver()
237+
238+
class User(BaseModel):
239+
name: str
240+
241+
@app.post("/users")
242+
def handler(user: User):
243+
print(user)
244+
245+
schema = app.get_openapi_schema()
246+
assert len(schema.paths.keys()) == 1
247+
248+
post = schema.paths["/users"].post
249+
assert post.parameters is None
250+
assert post.requestBody is not None
251+
252+
request_body = post.requestBody
253+
assert request_body.required is True
254+
assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/User"
255+
256+
257+
def test_openapi_with_embed_body_param():
258+
app = APIGatewayRestResolver()
259+
260+
class User(BaseModel):
261+
name: str
262+
263+
@app.post("/users")
264+
def handler(user: Annotated[User, Body(embed=True)]):
265+
print(user)
266+
267+
schema = app.get_openapi_schema()
268+
assert len(schema.paths.keys()) == 1
269+
270+
post = schema.paths["/users"].post
271+
assert post.parameters is None
272+
assert post.requestBody is not None
273+
274+
request_body = post.requestBody
275+
assert request_body.required is True
276+
# Notice here we craft a specific schema for the embedded user
277+
assert request_body.content[JSON_CONTENT_TYPE].schema_.ref == "#/components/schemas/Body_handler_users_post"
278+
279+
# Ensure that the custom body schema actually points to the real user class
280+
components = schema.components
281+
assert "Body_handler_users_post" in components.schemas
282+
body_post_handler_schema = components.schemas["Body_handler_users_post"]
283+
assert body_post_handler_schema.properties["user"].ref == "#/components/schemas/User"
284+
285+
286+
def test_create_header():
287+
header = Header(convert_underscores=True)
288+
assert header.convert_underscores is True
289+
290+
291+
def test_create_body():
292+
body = Body(embed=True, examples=[Example(summary="Example 1", value=10)])
293+
assert body.embed is True
294+
295+
296+
# Tests that when we try to create a model without a field type, we return None
297+
def test_create_empty_model_field():
298+
result = _create_model_field(None, int, "name", False)
299+
assert result is None
300+
301+
302+
# Tests that when we try to crate a param model without a source, we default to "query"
303+
def test_create_model_field_with_empty_in():
304+
field_info = Param()
305+
306+
result = _create_model_field(field_info, int, "name", False)
307+
assert result.field_info.in_ == ParamTypes.query
308+
309+
310+
# Tests that when we try to create a model field with convert_underscore, we convert the field name
311+
def test_create_model_field_convert_underscore():
312+
field_info = Header(alias=None, convert_underscores=True)
313+
314+
result = _create_model_field(field_info, int, "user_id", False)
315+
assert result.alias == "user-id"
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
2+
3+
4+
def test_openapi_default_response():
5+
app = APIGatewayRestResolver(enable_validation=True)
6+
7+
@app.get("/")
8+
def handler():
9+
pass
10+
11+
schema = app.get_openapi_schema()
12+
responses = schema.paths["/"].get.responses
13+
assert 200 in responses.keys()
14+
assert responses[200].description == "Successful Response"
15+
16+
assert 422 in responses.keys()
17+
assert responses[422].description == "Validation Error"
18+
19+
20+
def test_openapi_200_response_with_description():
21+
app = APIGatewayRestResolver(enable_validation=True)
22+
23+
@app.get("/", response_description="Custom response")
24+
def handler():
25+
return {"message": "hello world"}
26+
27+
schema = app.get_openapi_schema()
28+
responses = schema.paths["/"].get.responses
29+
assert 200 in responses.keys()
30+
assert responses[200].description == "Custom response"
31+
32+
assert 422 in responses.keys()
33+
assert responses[422].description == "Validation Error"
34+
35+
36+
def test_openapi_200_custom_response():
37+
app = APIGatewayRestResolver(enable_validation=True)
38+
39+
@app.get("/", responses={202: {"description": "Custom response"}})
40+
def handler():
41+
return {"message": "hello world"}
42+
43+
schema = app.get_openapi_schema()
44+
responses = schema.paths["/"].get.responses
45+
assert 202 in responses.keys()
46+
assert responses[202].description == "Custom response"
47+
48+
assert 200 not in responses.keys()
49+
assert 422 not in responses.keys()
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import json
2+
from typing import Dict
3+
4+
import pytest
5+
6+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
7+
8+
9+
def test_openapi_duplicated_serialization():
10+
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True
11+
app = APIGatewayRestResolver(enable_validation=True)
12+
13+
# WHEN we have duplicated operations
14+
@app.get("/")
15+
def handler():
16+
pass
17+
18+
@app.get("/")
19+
def handler(): # noqa: F811
20+
pass
21+
22+
# THEN we should get a warning
23+
with pytest.warns(UserWarning, match="Duplicate Operation*"):
24+
app.get_openapi_schema()
25+
26+
27+
def test_openapi_serialize_json():
28+
# GIVEN APIGatewayRestResolver is initialized with enable_validation=True
29+
app = APIGatewayRestResolver(enable_validation=True)
30+
31+
@app.get("/")
32+
def handler():
33+
pass
34+
35+
# WHEN we serialize as json_schema
36+
schema = json.loads(app.get_openapi_json_schema())
37+
38+
# THEN we should get a dictionary
39+
assert isinstance(schema, Dict)
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
2+
from aws_lambda_powertools.event_handler.openapi.models import Server
3+
4+
5+
def test_openapi_schema_default_server():
6+
app = APIGatewayRestResolver()
7+
8+
schema = app.get_openapi_schema(title="Hello API", version="1.0.0")
9+
assert schema.servers
10+
assert len(schema.servers) == 1
11+
assert schema.servers[0].url == "/"
12+
13+
14+
def test_openapi_schema_custom_server():
15+
app = APIGatewayRestResolver()
16+
17+
schema = app.get_openapi_schema(
18+
title="Hello API",
19+
version="1.0.0",
20+
servers=[Server(url="https://example.org/", description="Example website")],
21+
)
22+
23+
assert schema.servers
24+
assert len(schema.servers) == 1
25+
assert str(schema.servers[0].url) == "https://example.org/"
26+
assert schema.servers[0].description == "Example website"
Lines changed: 290 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,290 @@
1+
import json
2+
from dataclasses import dataclass
3+
from enum import Enum
4+
from pathlib import PurePath
5+
from typing import Tuple
6+
7+
from pydantic import BaseModel
8+
9+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
10+
from aws_lambda_powertools.event_handler.openapi.params import Body
11+
from aws_lambda_powertools.shared.types import Annotated
12+
from tests.functional.utils import load_event
13+
14+
LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json")
15+
16+
17+
def test_validate_scalars():
18+
# GIVEN an APIGatewayRestResolver with validation enabled
19+
app = APIGatewayRestResolver(enable_validation=True)
20+
21+
# WHEN a handler is defined with a scalar parameter
22+
@app.get("/users/<user_id>")
23+
def handler(user_id: int):
24+
print(user_id)
25+
26+
# sending a number
27+
LOAD_GW_EVENT["path"] = "/users/123"
28+
29+
# THEN the handler should be invoked and return 200
30+
result = app(LOAD_GW_EVENT, {})
31+
assert result["statusCode"] == 200
32+
33+
# sending a string
34+
LOAD_GW_EVENT["path"] = "/users/abc"
35+
36+
# THEN the handler should be invoked and return 422
37+
result = app(LOAD_GW_EVENT, {})
38+
assert result["statusCode"] == 422
39+
assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
40+
41+
42+
def test_validate_scalars_with_default():
43+
# GIVEN an APIGatewayRestResolver with validation enabled
44+
app = APIGatewayRestResolver(enable_validation=True)
45+
46+
# WHEN a handler is defined with a default scalar parameter
47+
@app.get("/users/<user_id>")
48+
def handler(user_id: int = 123):
49+
print(user_id)
50+
51+
# sending a number
52+
LOAD_GW_EVENT["path"] = "/users/123"
53+
54+
# THEN the handler should be invoked and return 200
55+
result = app(LOAD_GW_EVENT, {})
56+
assert result["statusCode"] == 200
57+
58+
# sending a string
59+
LOAD_GW_EVENT["path"] = "/users/abc"
60+
61+
# THEN the handler should be invoked and return 422
62+
result = app(LOAD_GW_EVENT, {})
63+
assert result["statusCode"] == 422
64+
assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
65+
66+
67+
def test_validate_scalars_with_default_and_optional():
68+
# GIVEN an APIGatewayRestResolver with validation enabled
69+
app = APIGatewayRestResolver(enable_validation=True)
70+
71+
# WHEN a handler is defined with a default scalar parameter
72+
@app.get("/users/<user_id>")
73+
def handler(user_id: int = 123, include_extra: bool = False):
74+
print(user_id)
75+
76+
# sending a number
77+
LOAD_GW_EVENT["path"] = "/users/123"
78+
79+
# THEN the handler should be invoked and return 200
80+
result = app(LOAD_GW_EVENT, {})
81+
assert result["statusCode"] == 200
82+
83+
# sending a string
84+
LOAD_GW_EVENT["path"] = "/users/abc"
85+
86+
# THEN the handler should be invoked and return 422
87+
result = app(LOAD_GW_EVENT, {})
88+
assert result["statusCode"] == 422
89+
assert any(text in result["body"] for text in ["type_error.integer", "int_parsing"])
90+
91+
92+
def test_validate_return_type():
93+
# GIVEN an APIGatewayRestResolver with validation enabled
94+
app = APIGatewayRestResolver(enable_validation=True)
95+
96+
# WHEN a handler is defined with a return type
97+
@app.get("/")
98+
def handler() -> int:
99+
return 123
100+
101+
LOAD_GW_EVENT["path"] = "/"
102+
103+
# THEN the handler should be invoked and return 200
104+
# THEN the body must be 123
105+
result = app(LOAD_GW_EVENT, {})
106+
assert result["statusCode"] == 200
107+
assert result["body"] == 123
108+
109+
110+
def test_validate_return_tuple():
111+
# GIVEN an APIGatewayRestResolver with validation enabled
112+
app = APIGatewayRestResolver(enable_validation=True)
113+
114+
sample_tuple = (1, 2, 3)
115+
116+
# WHEN a handler is defined with a return type as Tuple
117+
@app.get("/")
118+
def handler() -> Tuple:
119+
return sample_tuple
120+
121+
LOAD_GW_EVENT["path"] = "/"
122+
123+
# THEN the handler should be invoked and return 200
124+
# THEN the body must be a tuple
125+
result = app(LOAD_GW_EVENT, {})
126+
assert result["statusCode"] == 200
127+
assert result["body"] == list(sample_tuple)
128+
129+
130+
def test_validate_return_purepath():
131+
# GIVEN an APIGatewayRestResolver with validation enabled
132+
app = APIGatewayRestResolver(enable_validation=True)
133+
134+
sample_path = PurePath(__file__)
135+
136+
# WHEN a handler is defined with a return type as string
137+
# WHEN return value is a PurePath
138+
@app.get("/")
139+
def handler() -> str:
140+
return sample_path
141+
142+
LOAD_GW_EVENT["path"] = "/"
143+
144+
# THEN the handler should be invoked and return 200
145+
# THEN the body must be a string
146+
result = app(LOAD_GW_EVENT, {})
147+
assert result["statusCode"] == 200
148+
assert result["body"] == sample_path.as_posix()
149+
150+
151+
def test_validate_return_enum():
152+
# GIVEN an APIGatewayRestResolver with validation enabled
153+
app = APIGatewayRestResolver(enable_validation=True)
154+
155+
class Model(Enum):
156+
name = "powertools"
157+
158+
# WHEN a handler is defined with a return type as Enum
159+
@app.get("/")
160+
def handler() -> Model:
161+
return Model.name.value
162+
163+
LOAD_GW_EVENT["path"] = "/"
164+
165+
# THEN the handler should be invoked and return 200
166+
# THEN the body must be a string
167+
result = app(LOAD_GW_EVENT, {})
168+
assert result["statusCode"] == 200
169+
assert result["body"] == "powertools"
170+
171+
172+
def test_validate_return_dataclass():
173+
# GIVEN an APIGatewayRestResolver with validation enabled
174+
app = APIGatewayRestResolver(enable_validation=True)
175+
176+
@dataclass
177+
class Model:
178+
name: str
179+
age: int
180+
181+
# WHEN a handler is defined with a return type as dataclass
182+
@app.get("/")
183+
def handler() -> Model:
184+
return Model(name="John", age=30)
185+
186+
LOAD_GW_EVENT["path"] = "/"
187+
188+
# THEN the handler should be invoked and return 200
189+
# THEN the body must be a dict
190+
result = app(LOAD_GW_EVENT, {})
191+
assert result["statusCode"] == 200
192+
assert result["body"] == {"name": "John", "age": 30}
193+
194+
195+
def test_validate_return_model():
196+
# GIVEN an APIGatewayRestResolver with validation enabled
197+
app = APIGatewayRestResolver(enable_validation=True)
198+
199+
class Model(BaseModel):
200+
name: str
201+
age: int
202+
203+
# WHEN a handler is defined with a return type as Pydantic model
204+
@app.get("/")
205+
def handler() -> Model:
206+
return Model(name="John", age=30)
207+
208+
LOAD_GW_EVENT["path"] = "/"
209+
210+
# THEN the handler should be invoked and return 200
211+
# THEN the body must be a dict
212+
result = app(LOAD_GW_EVENT, {})
213+
assert result["statusCode"] == 200
214+
assert result["body"] == {"name": "John", "age": 30}
215+
216+
217+
def test_validate_invalid_return_model():
218+
# GIVEN an APIGatewayRestResolver with validation enabled
219+
app = APIGatewayRestResolver(enable_validation=True)
220+
221+
class Model(BaseModel):
222+
name: str
223+
age: int
224+
225+
# WHEN a handler is defined with a return type as Pydantic model
226+
@app.get("/")
227+
def handler() -> Model:
228+
return {"name": "John"} # type: ignore
229+
230+
LOAD_GW_EVENT["path"] = "/"
231+
232+
# THEN the handler should be invoked and return 422
233+
# THEN the body must be a dict
234+
result = app(LOAD_GW_EVENT, {})
235+
assert result["statusCode"] == 422
236+
assert "missing" in result["body"]
237+
238+
239+
def test_validate_body_param():
240+
# GIVEN an APIGatewayRestResolver with validation enabled
241+
app = APIGatewayRestResolver(enable_validation=True)
242+
243+
class Model(BaseModel):
244+
name: str
245+
age: int
246+
247+
# WHEN a handler is defined with a body parameter
248+
@app.post("/")
249+
def handler(user: Model) -> Model:
250+
return user
251+
252+
LOAD_GW_EVENT["httpMethod"] = "POST"
253+
LOAD_GW_EVENT["path"] = "/"
254+
LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30})
255+
256+
# THEN the handler should be invoked and return 200
257+
# THEN the body must be a dict
258+
result = app(LOAD_GW_EVENT, {})
259+
assert result["statusCode"] == 200
260+
assert result["body"] == {"name": "John", "age": 30}
261+
262+
263+
def test_validate_embed_body_param():
264+
# GIVEN an APIGatewayRestResolver with validation enabled
265+
app = APIGatewayRestResolver(enable_validation=True)
266+
267+
class Model(BaseModel):
268+
name: str
269+
age: int
270+
271+
# WHEN a handler is defined with a body parameter
272+
@app.post("/")
273+
def handler(user: Annotated[Model, Body(embed=True)]) -> Model:
274+
return user
275+
276+
LOAD_GW_EVENT["httpMethod"] = "POST"
277+
LOAD_GW_EVENT["path"] = "/"
278+
LOAD_GW_EVENT["body"] = json.dumps({"name": "John", "age": 30})
279+
280+
# THEN the handler should be invoked and return 422
281+
# THEN the body must be a dict
282+
result = app(LOAD_GW_EVENT, {})
283+
assert result["statusCode"] == 422
284+
assert "missing" in result["body"]
285+
286+
# THEN the handler should be invoked and return 200
287+
# THEN the body must be a dict
288+
LOAD_GW_EVENT["body"] = json.dumps({"user": {"name": "John", "age": 30}})
289+
result = app(LOAD_GW_EVENT, {})
290+
assert result["statusCode"] == 200

0 commit comments

Comments
 (0)
Please sign in to comment.