Skip to content

Commit ed3f07b

Browse files
feat(event_handler): add Bedrock Agent event handler (#3285)
Co-authored-by: Heitor Lessa <[email protected]>
1 parent f6a699f commit ed3f07b

File tree

5 files changed

+248
-16
lines changed

5 files changed

+248
-16
lines changed

Diff for: aws_lambda_powertools/event_handler/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Response,
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
14+
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
1415
from aws_lambda_powertools.event_handler.lambda_function_url import (
1516
LambdaFunctionUrlResolver,
1617
)
@@ -22,6 +23,7 @@
2223
"APIGatewayHttpResolver",
2324
"ALBResolver",
2425
"ApiGatewayResolver",
26+
"BedrockAgentResolver",
2527
"CORSConfig",
2628
"LambdaFunctionUrlResolver",
2729
"Response",

Diff for: aws_lambda_powertools/event_handler/api_gateway.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Any,
1616
Callable,
1717
Dict,
18+
Generic,
1819
List,
1920
Match,
2021
Optional,
@@ -23,6 +24,7 @@
2324
Set,
2425
Tuple,
2526
Type,
27+
TypeVar,
2628
Union,
2729
cast,
2830
)
@@ -45,6 +47,7 @@
4547
ALBEvent,
4648
APIGatewayProxyEvent,
4749
APIGatewayProxyEventV2,
50+
BedrockAgentEvent,
4851
LambdaFunctionUrlEvent,
4952
VPCLatticeEvent,
5053
VPCLatticeEventV2,
@@ -62,6 +65,8 @@
6265
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
6366
_ROUTE_REGEX = "^{}$"
6467

68+
ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
69+
6570
if TYPE_CHECKING:
6671
from aws_lambda_powertools.event_handler.openapi.compat import (
6772
JsonSchemaValue,
@@ -85,6 +90,7 @@ class ProxyEventType(Enum):
8590
APIGatewayProxyEvent = "APIGatewayProxyEvent"
8691
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
8792
ALBEvent = "ALBEvent"
93+
BedrockAgentEvent = "BedrockAgentEvent"
8894
VPCLatticeEvent = "VPCLatticeEvent"
8995
VPCLatticeEventV2 = "VPCLatticeEventV2"
9096
LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
@@ -208,7 +214,7 @@ def __init__(
208214
self,
209215
status_code: int,
210216
content_type: Optional[str] = None,
211-
body: Union[str, bytes, None] = None,
217+
body: Any = None,
212218
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
213219
cookies: Optional[List[Cookie]] = None,
214220
compress: Optional[bool] = None,
@@ -235,6 +241,7 @@ def __init__(
235241
self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {}
236242
self.cookies = cookies or []
237243
self.compress = compress
244+
self.content_type = content_type
238245
if content_type:
239246
self.headers.setdefault("Content-Type", content_type)
240247

@@ -689,14 +696,14 @@ def _generate_operation_id(self) -> str:
689696
return operation_id
690697

691698

692-
class ResponseBuilder:
699+
class ResponseBuilder(Generic[ResponseEventT]):
693700
"""Internally used Response builder"""
694701

695702
def __init__(self, response: Response, route: Optional[Route] = None):
696703
self.response = response
697704
self.route = route
698705

699-
def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
706+
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
700707
"""Update headers to include the configured Access-Control headers"""
701708
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))
702709

@@ -709,7 +716,7 @@ def _add_cache_control(self, cache_control: str):
709716
def _has_compression_enabled(
710717
route_compression: bool,
711718
response_compression: Optional[bool],
712-
event: BaseProxyEvent,
719+
event: ResponseEventT,
713720
) -> bool:
714721
"""
715722
Checks if compression is enabled.
@@ -722,7 +729,7 @@ def _has_compression_enabled(
722729
A boolean indicating whether compression is enabled or not in the route setting.
723730
response_compression: bool, optional
724731
A boolean indicating whether compression is enabled or not in the response setting.
725-
event: BaseProxyEvent
732+
event: ResponseEventT
726733
The event object containing the request details.
727734
728735
Returns
@@ -752,7 +759,7 @@ def _compress(self):
752759
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
753760
self.response.body = gzip.compress(self.response.body) + gzip.flush()
754761

755-
def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
762+
def _route(self, event: ResponseEventT, cors: Optional[CORSConfig]):
756763
"""Optionally handle any of the route's configure response handling"""
757764
if self.route is None:
758765
return
@@ -767,7 +774,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
767774
):
768775
self._compress()
769776

770-
def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
777+
def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
771778
"""Build the full response dict to be returned by the lambda"""
772779
self._route(event, cors)
773780

@@ -1315,6 +1322,7 @@ def __init__(
13151322
self._strip_prefixes = strip_prefixes
13161323
self.context: Dict = {} # early init as customers might add context before event resolution
13171324
self.processed_stack_frames = []
1325+
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
13181326

13191327
# Allow for a custom serializer or a concise json serialization
13201328
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1784,14 +1792,17 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
17841792
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
17851793
return re.compile(base_regex.format(rule_regex))
17861794

1787-
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
1795+
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 # ignore many returns
17881796
"""Convert the event dict to the corresponding data class"""
17891797
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
17901798
logger.debug("Converting event to API Gateway REST API contract")
17911799
return APIGatewayProxyEvent(event)
17921800
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
17931801
logger.debug("Converting event to API Gateway HTTP API contract")
17941802
return APIGatewayProxyEventV2(event)
1803+
if self._proxy_type == ProxyEventType.BedrockAgentEvent:
1804+
logger.debug("Converting event to Bedrock Agent contract")
1805+
return BedrockAgentEvent(event)
17951806
if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent:
17961807
logger.debug("Converting event to Lambda Function URL contract")
17971808
return LambdaFunctionUrlEvent(event)
@@ -1869,9 +1880,9 @@ def _not_found(self, method: str) -> ResponseBuilder:
18691880

18701881
handler = self._lookup_exception_handler(NotFoundError)
18711882
if handler:
1872-
return ResponseBuilder(handler(NotFoundError()))
1883+
return self._response_builder_class(handler(NotFoundError()))
18731884

1874-
return ResponseBuilder(
1885+
return self._response_builder_class(
18751886
Response(
18761887
status_code=HTTPStatus.NOT_FOUND.value,
18771888
content_type=content_types.APPLICATION_JSON,
@@ -1886,7 +1897,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
18861897
# Reset Processed stack for Middleware (for debugging purposes)
18871898
self._reset_processed_stack()
18881899

1889-
return ResponseBuilder(
1900+
return self._response_builder_class(
18901901
self._to_response(
18911902
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
18921903
),
@@ -1903,7 +1914,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
19031914
# If the user has turned on debug mode,
19041915
# we'll let the original exception propagate, so
19051916
# they get more information about what went wrong.
1906-
return ResponseBuilder(
1917+
return self._response_builder_class(
19071918
Response(
19081919
status_code=500,
19091920
content_type=content_types.TEXT_PLAIN,
@@ -1942,12 +1953,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
19421953
handler = self._lookup_exception_handler(type(exp))
19431954
if handler:
19441955
try:
1945-
return ResponseBuilder(handler(exp), route)
1956+
return self._response_builder_class(handler(exp), route)
19461957
except ServiceError as service_error:
19471958
exp = service_error
19481959

19491960
if isinstance(exp, ServiceError):
1950-
return ResponseBuilder(
1961+
return self._response_builder_class(
19511962
Response(
19521963
status_code=exp.status_code,
19531964
content_type=content_types.APPLICATION_JSON,

Diff for: aws_lambda_powertools/event_handler/bedrock_agent.py

+77
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from typing import Any, Dict
2+
3+
from typing_extensions import override
4+
5+
from aws_lambda_powertools.event_handler import ApiGatewayResolver
6+
from aws_lambda_powertools.event_handler.api_gateway import (
7+
ProxyEventType,
8+
ResponseBuilder,
9+
)
10+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
11+
12+
13+
class BedrockResponseBuilder(ResponseBuilder):
14+
"""
15+
Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents.
16+
17+
Since the payload format is different from the standard API Gateway Proxy event, we override the build method.
18+
"""
19+
20+
@override
21+
def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
22+
"""Build the full response dict to be returned by the lambda"""
23+
self._route(event, None)
24+
25+
return {
26+
"messageVersion": "1.0",
27+
"response": {
28+
"actionGroup": event.action_group,
29+
"apiPath": event.api_path,
30+
"httpMethod": event.http_method,
31+
"httpStatusCode": self.response.status_code,
32+
"responseBody": {
33+
self.response.content_type: {
34+
"body": self.response.body,
35+
},
36+
},
37+
},
38+
}
39+
40+
41+
class BedrockAgentResolver(ApiGatewayResolver):
42+
"""Bedrock Agent Resolver
43+
44+
See https://aws.amazon.com/bedrock/agents/ for more information.
45+
46+
Examples
47+
--------
48+
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
49+
50+
```python
51+
from aws_lambda_powertools import Tracer
52+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
53+
54+
tracer = Tracer()
55+
app = BedrockAgentResolver()
56+
57+
@app.get("/claims")
58+
def simple_get():
59+
return "You have 3 claims"
60+
61+
@tracer.capture_lambda_handler
62+
def lambda_handler(event, context):
63+
return app.resolve(event, context)
64+
"""
65+
66+
current_event: BedrockAgentEvent
67+
68+
def __init__(self, debug: bool = False, enable_validation: bool = True):
69+
super().__init__(
70+
proxy_type=ProxyEventType.BedrockAgentEvent,
71+
cors=None,
72+
debug=debug,
73+
serializer=None,
74+
strip_prefixes=None,
75+
enable_validation=enable_validation,
76+
)
77+
self._response_builder_class = BedrockResponseBuilder

Diff for: aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import Dict, List, Optional
22

3-
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
3+
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper
44

55

66
class BedrockAgentInfo(DictWrapper):
@@ -47,7 +47,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
4747
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
4848

4949

50-
class BedrockAgentEvent(DictWrapper):
50+
class BedrockAgentEvent(BaseProxyEvent):
5151
"""
5252
Bedrock Agent input event
5353
@@ -97,3 +97,8 @@ def session_attributes(self) -> Dict[str, str]:
9797
@property
9898
def prompt_session_attributes(self) -> Dict[str, str]:
9999
return self["promptSessionAttributes"]
100+
101+
# For compatibility with BaseProxyEvent
102+
@property
103+
def path(self) -> str:
104+
return self["apiPath"]

0 commit comments

Comments
 (0)