Skip to content

Commit d7792fd

Browse files
committed
fix: pydantic 2
1 parent 568bcc8 commit d7792fd

File tree

3 files changed

+36
-9
lines changed

3 files changed

+36
-9
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@ def dependant(self) -> "Dependant":
445445
if self._dependant is None:
446446
from aws_lambda_powertools.event_handler.openapi.dependant import get_dependant
447447

448-
self._dependant = get_dependant(path=self.openapi_path, call=self.func)
448+
self._dependant = get_dependant(path=self.openapi_path, call=self.func, responses=self.responses)
449449

450450
return self._dependant
451451

@@ -525,15 +525,15 @@ def _get_openapi_path(
525525

526526
# Case 2.1: the 'content' has a model
527527
if "model" in payload:
528-
from aws_lambda_powertools.event_handler.openapi.params import analyze_param
529-
530-
return_field = analyze_param(
531-
param_name="return",
532-
annotation=cast(OpenAPIResponseContentModel, payload)["model"],
533-
value=None,
534-
is_path_param=False,
535-
is_response_param=True,
528+
# Find the model in the dependant's extra models
529+
return_field = next(
530+
filter(
531+
lambda model: model.type_ is cast(OpenAPIResponseContentModel, payload)["model"],
532+
self.dependant.response_extra_models,
533+
),
536534
)
535+
if not return_field:
536+
raise AssertionError("Model declared in custom responses was not found")
537537

538538
new_payload = self._openapi_operation_return(
539539
param=return_field,
@@ -2151,6 +2151,9 @@ def _get_fields_from_routes(routes: Sequence[Route]) -> List["ModelField"]:
21512151
if route.dependant.return_param:
21522152
responses_from_routes.append(route.dependant.return_param)
21532153

2154+
if route.dependant.response_extra_models:
2155+
responses_from_routes.extend(route.dependant.response_extra_models)
2156+
21542157
flat_models = list(responses_from_routes + request_fields_from_routes + body_fields_from_routes)
21552158
return flat_models
21562159

aws_lambda_powertools/event_handler/openapi/dependant.py

+22
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
create_response_field,
2525
get_flat_dependant,
2626
)
27+
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel
2728

2829
"""
2930
This turns the opaque function signature into typed, validated models.
@@ -145,6 +146,7 @@ def get_dependant(
145146
path: str,
146147
call: Callable[..., Any],
147148
name: Optional[str] = None,
149+
responses: Optional[Dict[int, OpenAPIResponse]] = None,
148150
) -> Dependant:
149151
"""
150152
Returns a dependant model for a handler function. A dependant model is a model that contains
@@ -158,6 +160,8 @@ def get_dependant(
158160
The handler function
159161
name: str, optional
160162
The name of the handler function
163+
responses: List[Dict[int, OpenAPIResponse]], optional
164+
The list of extra responses for the handler function
161165
162166
Returns
163167
-------
@@ -210,6 +214,24 @@ def get_dependant(
210214

211215
dependant.return_param = param_field
212216

217+
# Also add the optional extra responses to the dependant model.
218+
if responses:
219+
for response in responses.values():
220+
if "content" in response:
221+
for schema in response["content"].values():
222+
if "model" in schema:
223+
response_field = analyze_param(
224+
param_name="return",
225+
annotation=cast(OpenAPIResponseContentModel, schema)["model"],
226+
value=None,
227+
is_path_param=False,
228+
is_response_param=True,
229+
)
230+
if response_field is None:
231+
raise AssertionError("Response field is None for response model")
232+
233+
dependant.response_extra_models.append(response_field)
234+
213235
return dependant
214236

215237

aws_lambda_powertools/event_handler/openapi/params.py

+2
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(
4949
cookie_params: Optional[List[ModelField]] = None,
5050
body_params: Optional[List[ModelField]] = None,
5151
return_param: Optional[ModelField] = None,
52+
response_extra_models: Optional[List[ModelField]] = None,
5253
name: Optional[str] = None,
5354
call: Optional[Callable[..., Any]] = None,
5455
request_param_name: Optional[str] = None,
@@ -64,6 +65,7 @@ def __init__(
6465
self.cookie_params = cookie_params or []
6566
self.body_params = body_params or []
6667
self.return_param = return_param or None
68+
self.response_extra_models = response_extra_models or []
6769
self.request_param_name = request_param_name
6870
self.websocket_param_name = websocket_param_name
6971
self.http_connection_param_name = http_connection_param_name

0 commit comments

Comments
 (0)