Skip to content

Commit 9a288aa

Browse files
committed
feat(event_handler): add Bedrock Agent event handler
1 parent e9c280b commit 9a288aa

File tree

6 files changed

+185
-9
lines changed

6 files changed

+185
-9
lines changed

Diff for: aws_lambda_powertools/event_handler/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Response,
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
14+
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
1415
from aws_lambda_powertools.event_handler.lambda_function_url import (
1516
LambdaFunctionUrlResolver,
1617
)
@@ -22,6 +23,7 @@
2223
"APIGatewayHttpResolver",
2324
"ALBResolver",
2425
"ApiGatewayResolver",
26+
"BedrockAgentResolver",
2527
"CORSConfig",
2628
"LambdaFunctionUrlResolver",
2729
"Response",

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+13-7
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
ALBEvent,
4646
APIGatewayProxyEvent,
4747
APIGatewayProxyEventV2,
48+
BedrockAgentEvent,
4849
LambdaFunctionUrlEvent,
4950
VPCLatticeEvent,
5051
VPCLatticeEventV2,
@@ -85,6 +86,7 @@ class ProxyEventType(Enum):
8586
APIGatewayProxyEvent = "APIGatewayProxyEvent"
8687
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
8788
ALBEvent = "ALBEvent"
89+
BedrockAgentEvent = "BedrockAgentEvent"
8890
VPCLatticeEvent = "VPCLatticeEvent"
8991
VPCLatticeEventV2 = "VPCLatticeEventV2"
9092
LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
@@ -1315,6 +1317,7 @@ def __init__(
13151317
self._strip_prefixes = strip_prefixes
13161318
self.context: Dict = {} # early init as customers might add context before event resolution
13171319
self.processed_stack_frames = []
1320+
self.response_builder_class = ResponseBuilder
13181321

13191322
# Allow for a custom serializer or a concise json serialization
13201323
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1784,14 +1787,17 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
17841787
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
17851788
return re.compile(base_regex.format(rule_regex))
17861789

1787-
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
1790+
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911
17881791
"""Convert the event dict to the corresponding data class"""
17891792
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
17901793
logger.debug("Converting event to API Gateway REST API contract")
17911794
return APIGatewayProxyEvent(event)
17921795
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
17931796
logger.debug("Converting event to API Gateway HTTP API contract")
17941797
return APIGatewayProxyEventV2(event)
1798+
if self._proxy_type == ProxyEventType.BedrockAgentEvent:
1799+
logger.debug("Converting event to Bedrock Agent contract")
1800+
return BedrockAgentEvent(event)
17951801
if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent:
17961802
logger.debug("Converting event to Lambda Function URL contract")
17971803
return LambdaFunctionUrlEvent(event)
@@ -1869,9 +1875,9 @@ def _not_found(self, method: str) -> ResponseBuilder:
18691875

18701876
handler = self._lookup_exception_handler(NotFoundError)
18711877
if handler:
1872-
return ResponseBuilder(handler(NotFoundError()))
1878+
return self.response_builder_class(handler(NotFoundError()))
18731879

1874-
return ResponseBuilder(
1880+
return self.response_builder_class(
18751881
Response(
18761882
status_code=HTTPStatus.NOT_FOUND.value,
18771883
content_type=content_types.APPLICATION_JSON,
@@ -1886,7 +1892,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
18861892
# Reset Processed stack for Middleware (for debugging purposes)
18871893
self._reset_processed_stack()
18881894

1889-
return ResponseBuilder(
1895+
return self.response_builder_class(
18901896
self._to_response(
18911897
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
18921898
),
@@ -1903,7 +1909,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
19031909
# If the user has turned on debug mode,
19041910
# we'll let the original exception propagate, so
19051911
# they get more information about what went wrong.
1906-
return ResponseBuilder(
1912+
return self.response_builder_class(
19071913
Response(
19081914
status_code=500,
19091915
content_type=content_types.TEXT_PLAIN,
@@ -1942,12 +1948,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
19421948
handler = self._lookup_exception_handler(type(exp))
19431949
if handler:
19441950
try:
1945-
return ResponseBuilder(handler(exp), route)
1951+
return self.response_builder_class(handler(exp), route)
19461952
except ServiceError as service_error:
19471953
exp = service_error
19481954

19491955
if isinstance(exp, ServiceError):
1950-
return ResponseBuilder(
1956+
return self.response_builder_class(
19511957
Response(
19521958
status_code=exp.status_code,
19531959
content_type=content_types.APPLICATION_JSON,

Diff for: aws_lambda_powertools/event_handler/bedrock_agent.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import logging
2+
from typing import Any, Dict, Optional, cast
3+
4+
from aws_lambda_powertools.event_handler import ApiGatewayResolver
5+
from aws_lambda_powertools.event_handler.api_gateway import CORSConfig, ProxyEventType, ResponseBuilder
6+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
7+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class BedrockResponseBuilder(ResponseBuilder):
13+
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
14+
"""Build the full response dict to be returned by the lambda"""
15+
self._route(event, cors)
16+
17+
bedrock_event = cast(BedrockAgentEvent, event)
18+
19+
return {
20+
"messageVersion": "1.0",
21+
"response": {
22+
"actionGroup": bedrock_event.action_group,
23+
"apiPath": bedrock_event.api_path,
24+
"httpMethod": bedrock_event.http_method,
25+
"httpStatusCode": self.response.status_code,
26+
"responseBody": {
27+
"application/json": {
28+
"body": self.response.body,
29+
},
30+
},
31+
},
32+
}
33+
34+
35+
class BedrockAgentResolver(ApiGatewayResolver):
36+
"""Bedrock Agent Resolver
37+
38+
See https://aws.amazon.com/bedrock/agents/ for more information.
39+
40+
Examples
41+
--------
42+
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
43+
44+
```python
45+
from aws_lambda_powertools import Tracer
46+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
47+
48+
tracer = Tracer()
49+
app = BedrockAgentResolver()
50+
51+
@app.get("/claims")
52+
def simple_get():
53+
return "You have 3 claims"
54+
55+
@tracer.capture_lambda_handler
56+
def lambda_handler(event, context):
57+
return app.resolve(event, context)
58+
"""
59+
60+
current_event: BedrockAgentEvent
61+
62+
def __init__(self, debug: bool = False, enable_validation: bool = True):
63+
super().__init__(
64+
ProxyEventType.BedrockAgentEvent,
65+
None,
66+
debug,
67+
None,
68+
None,
69+
enable_validation,
70+
)
71+
self.response_builder_class = BedrockResponseBuilder

Diff for: aws_lambda_powertools/shared/headers_serializer.py

+9
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,12 @@ def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Coo
123123
payload["headers"][key] = values[-1]
124124

125125
return payload
126+
127+
128+
class NoopSerializer(BaseHeadersSerializer):
129+
"""
130+
Noop serializer that doesn't do anything. This is useful for resolvers that don't need to set headers or cookies.
131+
"""
132+
133+
def serialize(self, headers: Dict[str, Union[str, List[str]]], cookies: List[Cookie]) -> Dict[str, Any]:
134+
return {}

Diff for: aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Dict, List, Optional
22

3-
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
3+
from aws_lambda_powertools.shared.headers_serializer import BaseHeadersSerializer, NoopSerializer
4+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
45

56

67
class BedrockAgentInfo(DictWrapper):
@@ -47,7 +48,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
4748
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
4849

4950

50-
class BedrockAgentEvent(DictWrapper):
51+
class BedrockAgentEvent(BaseProxyEvent):
5152
"""
5253
Bedrock Agent input event
5354
@@ -97,3 +98,11 @@ def session_attributes(self) -> Dict[str, str]:
9798
@property
9899
def prompt_session_attributes(self) -> Dict[str, str]:
99100
return self["promptSessionAttributes"]
101+
102+
# For compatibility with BaseProxyEvent
103+
@property
104+
def path(self) -> str:
105+
return self["apiPath"]
106+
107+
def header_serializer(self) -> BaseHeadersSerializer:
108+
return NoopSerializer()

Diff for: tests/functional/event_handler/test_bedrock_agent.py

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import json
2+
from typing import Any, Dict
3+
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
5+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
6+
from tests.functional.utils import load_event
7+
8+
claims_response = "You have 3 claims"
9+
10+
11+
def test_bedrock_agent_event():
12+
# GIVEN a Bedrock Agent event
13+
app = BedrockAgentResolver()
14+
15+
@app.get("/claims")
16+
def claims() -> Dict[str, Any]:
17+
assert isinstance(app.current_event, BedrockAgentEvent)
18+
assert app.lambda_context == {}
19+
return {"output": claims_response}
20+
21+
# WHEN calling the event handler
22+
result = app(load_event("bedrockAgentEvent.json"), {})
23+
24+
# THEN process event correctly
25+
# AND set the current_event type as BedrockAgentEvent
26+
assert result["messageVersion"] == "1.0"
27+
assert result["response"]["apiPath"] == "/claims"
28+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
29+
assert result["response"]["httpMethod"] == "GET"
30+
assert result["response"]["httpStatusCode"] == 200
31+
32+
body = result["response"]["responseBody"]["application/json"]["body"]
33+
assert body == {"output": claims_response}
34+
35+
36+
def test_bedrock_agent_event_with_response():
37+
# GIVEN a Bedrock Agent event
38+
app = BedrockAgentResolver()
39+
output = json.dumps({"output": claims_response})
40+
41+
@app.get("/claims")
42+
def claims():
43+
assert isinstance(app.current_event, BedrockAgentEvent)
44+
assert app.lambda_context == {}
45+
return Response(200, content_types.APPLICATION_JSON, output)
46+
47+
# WHEN calling the event handler
48+
result = app(load_event("bedrockAgentEvent.json"), {})
49+
50+
# THEN process event correctly
51+
# AND set the current_event type as BedrockAgentEvent
52+
assert result["messageVersion"] == "1.0"
53+
assert result["response"]["apiPath"] == "/claims"
54+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
55+
assert result["response"]["httpMethod"] == "GET"
56+
assert result["response"]["httpStatusCode"] == 200
57+
58+
body = result["response"]["responseBody"]["application/json"]["body"]
59+
assert body == output
60+
61+
62+
def test_bedrock_agent_event_with_no_matches():
63+
# GIVEN a Bedrock Agent event
64+
app = BedrockAgentResolver()
65+
66+
@app.get("/no_match")
67+
def claims():
68+
raise RuntimeError()
69+
70+
# WHEN calling the event handler
71+
result = app(load_event("bedrockAgentEvent.json"), {})
72+
73+
# THEN process event correctly
74+
# AND return 404 because the event doesn't match any known rule
75+
assert result["messageVersion"] == "1.0"
76+
assert result["response"]["apiPath"] == "/claims"
77+
assert result["response"]["actionGroup"] == "ClaimManagementActionGroup"
78+
assert result["response"]["httpMethod"] == "GET"
79+
assert result["response"]["httpStatusCode"] == 404

0 commit comments

Comments
 (0)