Skip to content

Commit 1e22e6f

Browse files
fix(event-handler): enable path parameters on Bedrock handler (#3312)
* fix(event-handler): enable path parameters on Bedrock handler * chore: change default openapi version to 3.0.0 * fix: unexport Form, File and Header --------- Co-authored-by: Leandro Damascena <[email protected]>
1 parent b926bd1 commit 1e22e6f

File tree

9 files changed

+85
-17
lines changed

9 files changed

+85
-17
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1829,7 +1829,8 @@ def _resolve(self) -> ResponseBuilder:
18291829
# Add matched Route reference into the Resolver context
18301830
self.append_context(_route=route, _path=path)
18311831

1832-
return self._call_route(route, match_results.groupdict()) # pass fn args
1832+
route_keys = self._convert_matches_into_route_keys(match_results)
1833+
return self._call_route(route, route_keys) # pass fn args
18331834

18341835
logger.debug(f"No match found for path {path} and method {method}")
18351836
return self._not_found(method)
@@ -1858,6 +1859,10 @@ def _remove_prefix(self, path: str) -> str:
18581859

18591860
return path
18601861

1862+
def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]:
1863+
"""Converts the regex match into a dict of route keys"""
1864+
return match.groupdict()
1865+
18611866
@staticmethod
18621867
def _path_starts_with(path: str, prefix: str):
18631868
"""Returns true if the `path` starts with a prefix plus a `/`"""

aws_lambda_powertools/event_handler/bedrock_agent.py

+10
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from re import Match
12
from typing import Any, Dict
23

34
from typing_extensions import override
@@ -75,3 +76,12 @@ def __init__(self, debug: bool = False, enable_validation: bool = True):
7576
enable_validation=enable_validation,
7677
)
7778
self._response_builder_class = BedrockResponseBuilder
79+
80+
@override
81+
def _convert_matches_into_route_keys(self, match: Match) -> Dict[str, str]:
82+
# In Bedrock Agents, all the parameters come inside the "parameters" key, not on the apiPath
83+
# So we have to search for route parameters in the parameters key
84+
parameters: Dict[str, str] = {}
85+
if match.groupdict() and self.current_event.parameters:
86+
parameters = {parameter["name"]: parameter["value"] for parameter in self.current_event.parameters}
87+
return parameters
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
DEFAULT_API_VERSION = "1.0.0"
2-
DEFAULT_OPENAPI_VERSION = "3.1.0"
2+
DEFAULT_OPENAPI_VERSION = "3.0.0"

aws_lambda_powertools/event_handler/openapi/dependant.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from aws_lambda_powertools.event_handler.openapi.params import (
1515
Body,
1616
Dependant,
17-
File,
18-
Form,
19-
Header,
2017
Param,
2118
ParamTypes,
2219
Query,
20+
_File,
21+
_Form,
22+
_Header,
2323
analyze_param,
2424
create_response_field,
2525
get_flat_dependant,
@@ -235,7 +235,7 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
235235
return False
236236
elif is_scalar_field(field=param_field):
237237
return False
238-
elif isinstance(param_field.field_info, (Query, Header)) and is_scalar_sequence_field(param_field):
238+
elif isinstance(param_field.field_info, (Query, _Header)) and is_scalar_sequence_field(param_field):
239239
return False
240240
else:
241241
if not isinstance(param_field.field_info, Body):
@@ -326,10 +326,12 @@ def get_body_field_info(
326326
if not required:
327327
body_field_info_kwargs["default"] = None
328328

329-
if any(isinstance(f.field_info, File) for f in flat_dependant.body_params):
330-
body_field_info: Type[Body] = File
331-
elif any(isinstance(f.field_info, Form) for f in flat_dependant.body_params):
332-
body_field_info = Form
329+
if any(isinstance(f.field_info, _File) for f in flat_dependant.body_params):
330+
# MAINTENANCE: body_field_info: Type[Body] = _File
331+
raise NotImplementedError("_File fields are not supported in request bodies")
332+
elif any(isinstance(f.field_info, _Form) for f in flat_dependant.body_params):
333+
# MAINTENANCE: body_field_info: Type[Body] = _Form
334+
raise NotImplementedError("_Form fields are not supported in request bodies")
333335
else:
334336
body_field_info = Body
335337

aws_lambda_powertools/event_handler/openapi/params.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def __init__(
308308
)
309309

310310

311-
class Header(Param):
311+
class _Header(Param):
312312
"""
313313
A class used internally to represent a header parameter in a path operation.
314314
"""
@@ -471,7 +471,7 @@ def __repr__(self) -> str:
471471
return f"{self.__class__.__name__}({self.default})"
472472

473473

474-
class Form(Body):
474+
class _Form(Body):
475475
"""
476476
A class used internally to represent a form parameter in a path operation.
477477
"""
@@ -543,7 +543,7 @@ def __init__(
543543
)
544544

545545

546-
class File(Form):
546+
class _File(_Form):
547547
"""
548548
A class used internally to represent a file parameter in a path operation.
549549
"""

aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,13 @@ def session_attributes(self) -> Dict[str, str]:
9898
def prompt_session_attributes(self) -> Dict[str, str]:
9999
return self["promptSessionAttributes"]
100100

101-
# For compatibility with BaseProxyEvent
101+
# The following methods add compatibility with BaseProxyEvent
102102
@property
103103
def path(self) -> str:
104104
return self["apiPath"]
105+
106+
@property
107+
def query_string_parameters(self) -> Optional[Dict[str, str]]:
108+
# In Bedrock Agent events, query string parameters are passed as undifferentiated parameters,
109+
# together with the other parameters. So we just return all parameters here.
110+
return {x["name"]: x["value"] for x in self["parameters"]} if self.get("parameters") else None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{
2+
"actionGroup": "ClaimManagementActionGroup",
3+
"messageVersion": "1.0",
4+
"sessionId": "12345678912345",
5+
"sessionAttributes": {},
6+
"promptSessionAttributes": {},
7+
"inputText": "I want to claim my insurance",
8+
"agent": {
9+
"alias": "TSTALIASID",
10+
"name": "test",
11+
"version": "DRAFT",
12+
"id": "8ZXY0W8P1H"
13+
},
14+
"parameters": [
15+
{
16+
"type": "string",
17+
"name": "claim_id",
18+
"value": "123"
19+
}
20+
],
21+
"httpMethod": "GET",
22+
"apiPath": "/claims/<claim_id>"
23+
}

tests/functional/event_handler/test_bedrock_agent.py

+22
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,28 @@ def claims() -> Dict[str, Any]:
3434
assert body == json.dumps({"output": claims_response})
3535

3636

37+
def test_bedrock_agent_with_path_params():
38+
# GIVEN a Bedrock Agent event
39+
app = BedrockAgentResolver()
40+
41+
@app.get("/claims/<claim_id>")
42+
def claims(claim_id: str):
43+
assert isinstance(app.current_event, BedrockAgentEvent)
44+
assert app.lambda_context == {}
45+
assert claim_id == "123"
46+
47+
# WHEN calling the event handler
48+
result = app(load_event("bedrockAgentEventWithPathParams.json"), {})
49+
50+
# THEN process event correctly
51+
# AND set the current_event type as BedrockAgentEvent
52+
assert result["messageVersion"] == "1.0"
53+
assert result["response"]["apiPath"] == "/claims/<claim_id>"
54+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
55+
assert result["response"]["httpMethod"] == "GET"
56+
assert result["response"]["httpStatusCode"] == 200
57+
58+
3759
def test_bedrock_agent_event_with_response():
3860
# GIVEN a Bedrock Agent event
3961
app = BedrockAgentResolver()

tests/functional/event_handler/test_openapi_params.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@
1414
)
1515
from aws_lambda_powertools.event_handler.openapi.params import (
1616
Body,
17-
Header,
1817
Param,
1918
ParamTypes,
2019
Query,
2120
_create_model_field,
21+
_Header,
2222
)
2323
from aws_lambda_powertools.shared.types import Annotated
2424

@@ -375,7 +375,7 @@ def secret():
375375

376376

377377
def test_create_header():
378-
header = Header(convert_underscores=True)
378+
header = _Header(convert_underscores=True)
379379
assert header.convert_underscores is True
380380

381381

@@ -400,7 +400,7 @@ def test_create_model_field_with_empty_in():
400400

401401
# Tests that when we try to create a model field with convert_underscore, we convert the field name
402402
def test_create_model_field_convert_underscore():
403-
field_info = Header(alias=None, convert_underscores=True)
403+
field_info = _Header(alias=None, convert_underscores=True)
404404

405405
result = _create_model_field(field_info, int, "user_id", False)
406406
assert result.alias == "user-id"

0 commit comments

Comments
 (0)