Skip to content

Commit 275e7e7

Browse files
committed
feat(event_handler): add Bedrock Agent event handler
1 parent 0964e75 commit 275e7e7

File tree

6 files changed

+185
-9
lines changed

6 files changed

+185
-9
lines changed

Diff for: aws_lambda_powertools/event_handler/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Response,
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
14+
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
1415
from aws_lambda_powertools.event_handler.lambda_function_url import (
1516
LambdaFunctionUrlResolver,
1617
)
@@ -22,6 +23,7 @@
2223
"APIGatewayHttpResolver",
2324
"ALBResolver",
2425
"ApiGatewayResolver",
26+
"BedrockAgentResolver",
2527
"CORSConfig",
2628
"LambdaFunctionUrlResolver",
2729
"Response",

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
ALBEvent,
4343
APIGatewayProxyEvent,
4444
APIGatewayProxyEventV2,
45+
BedrockAgentEvent,
4546
LambdaFunctionUrlEvent,
4647
VPCLatticeEvent,
4748
VPCLatticeEventV2,
@@ -83,6 +84,7 @@ class ProxyEventType(Enum):
8384
APIGatewayProxyEvent = "APIGatewayProxyEvent"
8485
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
8586
ALBEvent = "ALBEvent"
87+
BedrockAgentEvent = "BedrockAgentEvent"
8688
VPCLatticeEvent = "VPCLatticeEvent"
8789
VPCLatticeEventV2 = "VPCLatticeEventV2"
8890
LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
@@ -1291,6 +1293,7 @@ def __init__(
12911293
self._strip_prefixes = strip_prefixes
12921294
self.context: Dict = {} # early init as customers might add context before event resolution
12931295
self.processed_stack_frames = []
1296+
self.response_builder_class = ResponseBuilder
12941297

12951298
# Allow for a custom serializer or a concise json serialization
12961299
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1647,14 +1650,17 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
16471650
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
16481651
return re.compile(base_regex.format(rule_regex))
16491652

1650-
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
1653+
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911
16511654
"""Convert the event dict to the corresponding data class"""
16521655
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
16531656
logger.debug("Converting event to API Gateway REST API contract")
16541657
return APIGatewayProxyEvent(event)
16551658
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
16561659
logger.debug("Converting event to API Gateway HTTP API contract")
16571660
return APIGatewayProxyEventV2(event)
1661+
if self._proxy_type == ProxyEventType.BedrockAgentEvent:
1662+
logger.debug("Converting event to Bedrock Agent contract")
1663+
return BedrockAgentEvent(event)
16581664
if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent:
16591665
logger.debug("Converting event to Lambda Function URL contract")
16601666
return LambdaFunctionUrlEvent(event)
@@ -1732,9 +1738,9 @@ def _not_found(self, method: str) -> ResponseBuilder:
17321738

17331739
handler = self._lookup_exception_handler(NotFoundError)
17341740
if handler:
1735-
return ResponseBuilder(handler(NotFoundError()))
1741+
return self.response_builder_class(handler(NotFoundError()))
17361742

1737-
return ResponseBuilder(
1743+
return self.response_builder_class(
17381744
Response(
17391745
status_code=HTTPStatus.NOT_FOUND.value,
17401746
content_type=content_types.APPLICATION_JSON,
@@ -1749,7 +1755,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
17491755
# Reset Processed stack for Middleware (for debugging purposes)
17501756
self._reset_processed_stack()
17511757

1752-
return ResponseBuilder(
1758+
return self.response_builder_class(
17531759
self._to_response(
17541760
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
17551761
),
@@ -1766,7 +1772,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
17661772
# If the user has turned on debug mode,
17671773
# we'll let the original exception propagate, so
17681774
# they get more information about what went wrong.
1769-
return ResponseBuilder(
1775+
return self.response_builder_class(
17701776
Response(
17711777
status_code=500,
17721778
content_type=content_types.TEXT_PLAIN,
@@ -1805,12 +1811,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
18051811
handler = self._lookup_exception_handler(type(exp))
18061812
if handler:
18071813
try:
1808-
return ResponseBuilder(handler(exp), route)
1814+
return self.response_builder_class(handler(exp), route)
18091815
except ServiceError as service_error:
18101816
exp = service_error
18111817

18121818
if isinstance(exp, ServiceError):
1813-
return ResponseBuilder(
1819+
return self.response_builder_class(
18141820
Response(
18151821
status_code=exp.status_code,
18161822
content_type=content_types.APPLICATION_JSON,

Diff for: aws_lambda_powertools/event_handler/bedrock_agent.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import logging
2+
from typing import Any, Dict, Optional, cast
3+
4+
from aws_lambda_powertools.event_handler import ApiGatewayResolver
5+
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder
6+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
7+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class BedrockResponseBuilder(ResponseBuilder):
13+
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
14+
"""Build the full response dict to be returned by the lambda"""
15+
self._route(event, cors)
16+
17+
bedrock_event = cast(BedrockAgentEvent, event)
18+
19+
return {
20+
"messageVersion": "1.0",
21+
"response": {
22+
"actionGroup": bedrock_event.action_group,
23+
"apiPath": bedrock_event.api_path,
24+
"httpMethod": bedrock_event.http_method,
25+
"httpStatusCode": self.response.status_code,
26+
"responseBody": {
27+
"application/json": {
28+
"body": self.response.body,
29+
},
30+
},
31+
},
32+
}
33+
34+
35+
class BedrockAgentResolver(ApiGatewayResolver):
36+
"""Bedrock Agent Resolver
37+
38+
See https://aws.amazon.com/bedrock/agents/ for more information.
39+
40+
Examples
41+
--------
42+
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
43+
44+
```python
45+
from aws_lambda_powertools import Tracer
46+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
47+
48+
tracer = Tracer()
49+
app = BedrockAgentResolver()
50+
51+
@app.get("/claims")
52+
def simple_get():
53+
return "You have 3 claims"
54+
55+
@tracer.capture_lambda_handler
56+
def lambda_handler(event, context):
57+
return app.resolve(event, context)
58+
"""
59+
60+
current_event: BedrockAgentEvent
61+
62+
def __init__(self, debug: bool = False, enable_validation: bool = True):
63+
super().__init__(
64+
ProxyEventType.BedrockAgentEvent,
65+
None,
66+
debug,
67+
None,
68+
None,
69+
enable_validation,
70+
)
71+
self.response_builder_class = BedrockResponseBuilder

Diff for: aws_lambda_powertools/shared/headers_serializer.py

+9
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,12 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo
123123
payload["headers"][key] = values[-1]
124124

125125
return payload
126+
127+
128+
class NoopSerializer(BaseHeadersSerializer):
129+
"""
130+
Noop serializer that doesn't do anything. This is useful for resolvers that don't need to set headers or cookies.
131+
"""
132+
133+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
134+
return {}

Diff for: aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, List, Optional
22

3-
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
3+
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer, NoopSerializer
4+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
45

56

67
class BedrockAgentInfo(DictWrapper):
@@ -47,7 +48,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
4748
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
4849

4950

50-
class BedrockAgentEvent(DictWrapper):
51+
class BedrockAgentEvent(BaseProxyEvent):
5152
"""
5253
Bedrock Agent input event
5354
@@ -97,3 +98,11 @@ def session_attributes(self) -> Dict[str, str]:
9798
@property
9899
def prompt_session_attributes(self) -> Dict[str, str]:
99100
return self["promptSessionAttributes"]
101+
102+
# For compatibility with BaseProxyEvent
103+
@property
104+
def path(self) -> str:
105+
return self["apiPath"]
106+
107+
def header_serializer(self) -> BaseHeadersSerializer:
108+
return NoopSerializer()

Diff for: tests/functional/event_handler/test_bedrock_agent.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
5+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
6+
from tests.functional.utils import load_event
7+
8+
claims_response = "You have 3 claims"
9+
10+
11+
def test_bedrock_agent_event():
12+
# GIVEN a Bedrock Agent event
13+
app = BedrockAgentResolver()
14+
15+
@app.get("/claims")
16+
def claims() -> Dict[str, Any]:
17+
assert isinstance(app.current_event, BedrockAgentEvent)
18+
assert app.lambda_context == {}
19+
return {"output": claims_response}
20+
21+
# WHEN calling the event handler
22+
result = app(load_event("bedrockAgentEvent.json"), {})
23+
24+
# THEN process event correctly
25+
# AND set the current_event type as BedrockAgentEvent
26+
assert result["messageVersion"] == "1.0"
27+
assert result["response"]["apiPath"] == "/claims"
28+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
29+
assert result["response"]["httpMethod"] == "GET"
30+
assert result["response"]["httpStatusCode"] == 200
31+
32+
body = result["response"]["responseBody"]["application/json"]["body"]
33+
assert body == {"output": claims_response}
34+
35+
36+
def test_bedrock_agent_event_with_response():
37+
# GIVEN a Bedrock Agent event
38+
app = BedrockAgentResolver()
39+
output = json.dumps({"output": claims_response})
40+
41+
@app.get("/claims")
42+
def claims():
43+
assert isinstance(app.current_event, BedrockAgentEvent)
44+
assert app.lambda_context == {}
45+
return Response(200, content_types.APPLICATION_JSON, output)
46+
47+
# WHEN calling the event handler
48+
result = app(load_event("bedrockAgentEvent.json"), {})
49+
50+
# THEN process event correctly
51+
# AND set the current_event type as BedrockAgentEvent
52+
assert result["messageVersion"] == "1.0"
53+
assert result["response"]["apiPath"] == "/claims"
54+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
55+
assert result["response"]["httpMethod"] == "GET"
56+
assert result["response"]["httpStatusCode"] == 200
57+
58+
body = result["response"]["responseBody"]["application/json"]["body"]
59+
assert body == output
60+
61+
62+
def test_bedrock_agent_event_with_no_matches():
63+
# GIVEN a Bedrock Agent event
64+
app = BedrockAgentResolver()
65+
66+
@app.get("/no_match")
67+
def claims():
68+
raise RuntimeError()
69+
70+
# WHEN calling the event handler
71+
result = app(load_event("bedrockAgentEvent.json"), {})
72+
73+
# THEN process event correctly
74+
# AND return 404 because the event doesn't match any known rule
75+
assert result["messageVersion"] == "1.0"
76+
assert result["response"]["apiPath"] == "/claims"
77+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
78+
assert result["response"]["httpMethod"] == "GET"
79+
assert result["response"]["httpStatusCode"] == 404

0 commit comments

Comments
 (0)