Skip to content

feat(event_handler): add Bedrock Agent event handler #3285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Nov 8, 2023
2 changes: 2 additions & 0 deletions aws_lambda_powertools/event_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Response,
)
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
from aws_lambda_powertools.event_handler.lambda_function_url import (
LambdaFunctionUrlResolver,
)
Expand All @@ -22,6 +23,7 @@
"APIGatewayHttpResolver",
"ALBResolver",
"ApiGatewayResolver",
"BedrockAgentResolver",
"CORSConfig",
"LambdaFunctionUrlResolver",
"Response",
Expand Down
39 changes: 25 additions & 14 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Any,
Callable,
Dict,
Generic,
List,
Match,
Optional,
Expand All @@ -23,6 +24,7 @@
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
)
Expand All @@ -45,6 +47,7 @@
ALBEvent,
APIGatewayProxyEvent,
APIGatewayProxyEventV2,
BedrockAgentEvent,
LambdaFunctionUrlEvent,
VPCLatticeEvent,
VPCLatticeEventV2,
Expand All @@ -62,6 +65,8 @@
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
_ROUTE_REGEX = "^{}$"

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.openapi.compat import (
JsonSchemaValue,
Expand All @@ -85,6 +90,7 @@ class ProxyEventType(Enum):
APIGatewayProxyEvent = "APIGatewayProxyEvent"
APIGatewayProxyEventV2 = "APIGatewayProxyEventV2"
ALBEvent = "ALBEvent"
BedrockAgentEvent = "BedrockAgentEvent"
VPCLatticeEvent = "VPCLatticeEvent"
VPCLatticeEventV2 = "VPCLatticeEventV2"
LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent"
Expand Down Expand Up @@ -208,7 +214,7 @@ def __init__(
self,
status_code: int,
content_type: Optional[str] = None,
body: Union[str, bytes, None] = None,
body: Any = None,
headers: Optional[Dict[str, Union[str, List[str]]]] = None,
cookies: Optional[List[Cookie]] = None,
compress: Optional[bool] = None,
Expand All @@ -235,6 +241,7 @@ def __init__(
self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {}
self.cookies = cookies or []
self.compress = compress
self.content_type = content_type
if content_type:
self.headers.setdefault("Content-Type", content_type)

Expand Down Expand Up @@ -689,14 +696,14 @@ def _generate_operation_id(self) -> str:
return operation_id


class ResponseBuilder:
class ResponseBuilder(Generic[ResponseEventT]):
"""Internally used Response builder"""

def __init__(self, response: Response, route: Optional[Route] = None):
self.response = response
self.route = route

def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig):
def _add_cors(self, event: ResponseEventT, cors: CORSConfig):
"""Update headers to include the configured Access-Control headers"""
self.response.headers.update(cors.to_dict(event.get_header_value("Origin")))

Expand All @@ -709,7 +716,7 @@ def _add_cache_control(self, cache_control: str):
def _has_compression_enabled(
route_compression: bool,
response_compression: Optional[bool],
event: BaseProxyEvent,
event: ResponseEventT,
) -> bool:
"""
Checks if compression is enabled.
Expand All @@ -722,7 +729,7 @@ def _has_compression_enabled(
A boolean indicating whether compression is enabled or not in the route setting.
response_compression: bool, optional
A boolean indicating whether compression is enabled or not in the response setting.
event: BaseProxyEvent
event: ResponseEventT
The event object containing the request details.
Returns
Expand Down Expand Up @@ -752,7 +759,7 @@ def _compress(self):
gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16)
self.response.body = gzip.compress(self.response.body) + gzip.flush()

def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
def _route(self, event: ResponseEventT, cors: Optional[CORSConfig]):
"""Optionally handle any of the route's configure response handling"""
if self.route is None:
return
Expand All @@ -767,7 +774,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]):
):
self._compress()

def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, cors)

Expand Down Expand Up @@ -1315,6 +1322,7 @@ def __init__(
self._strip_prefixes = strip_prefixes
self.context: Dict = {} # early init as customers might add context before event resolution
self.processed_stack_frames = []
self._response_builder_class = ResponseBuilder[BaseProxyEvent]

# Allow for a custom serializer or a concise json serialization
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
Expand Down Expand Up @@ -1784,14 +1792,17 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):
rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule)
return re.compile(base_regex.format(rule_regex))

def _to_proxy_event(self, event: Dict) -> BaseProxyEvent:
def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 # ignore many returns
"""Convert the event dict to the corresponding data class"""
if self._proxy_type == ProxyEventType.APIGatewayProxyEvent:
logger.debug("Converting event to API Gateway REST API contract")
return APIGatewayProxyEvent(event)
if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2:
logger.debug("Converting event to API Gateway HTTP API contract")
return APIGatewayProxyEventV2(event)
if self._proxy_type == ProxyEventType.BedrockAgentEvent:
logger.debug("Converting event to Bedrock Agent contract")
return BedrockAgentEvent(event)
if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent:
logger.debug("Converting event to Lambda Function URL contract")
return LambdaFunctionUrlEvent(event)
Expand Down Expand Up @@ -1869,9 +1880,9 @@ def _not_found(self, method: str) -> ResponseBuilder:

handler = self._lookup_exception_handler(NotFoundError)
if handler:
return ResponseBuilder(handler(NotFoundError()))
return self._response_builder_class(handler(NotFoundError()))

return ResponseBuilder(
return self._response_builder_class(
Response(
status_code=HTTPStatus.NOT_FOUND.value,
content_type=content_types.APPLICATION_JSON,
Expand All @@ -1886,7 +1897,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
# Reset Processed stack for Middleware (for debugging purposes)
self._reset_processed_stack()

return ResponseBuilder(
return self._response_builder_class(
self._to_response(
route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments),
),
Expand All @@ -1903,7 +1914,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response
# If the user has turned on debug mode,
# we'll let the original exception propagate, so
# they get more information about what went wrong.
return ResponseBuilder(
return self._response_builder_class(
Response(
status_code=500,
content_type=content_types.TEXT_PLAIN,
Expand Down Expand Up @@ -1942,12 +1953,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp
handler = self._lookup_exception_handler(type(exp))
if handler:
try:
return ResponseBuilder(handler(exp), route)
return self._response_builder_class(handler(exp), route)
except ServiceError as service_error:
exp = service_error

if isinstance(exp, ServiceError):
return ResponseBuilder(
return self._response_builder_class(
Response(
status_code=exp.status_code,
content_type=content_types.APPLICATION_JSON,
Expand Down
77 changes: 77 additions & 0 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Any, Dict

from typing_extensions import override

from aws_lambda_powertools.event_handler import ApiGatewayResolver
from aws_lambda_powertools.event_handler.api_gateway import (
ProxyEventType,
ResponseBuilder,
)
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent


class BedrockResponseBuilder(ResponseBuilder):
"""
Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents.
Since the payload format is different from the standard API Gateway Proxy event, we override the build method.
"""

@override
def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]:
"""Build the full response dict to be returned by the lambda"""
self._route(event, None)

return {
"messageVersion": "1.0",
"response": {
"actionGroup": event.action_group,
"apiPath": event.api_path,
"httpMethod": event.http_method,
"httpStatusCode": self.response.status_code,
"responseBody": {
self.response.content_type: {
"body": self.response.body,
},
},
},
}


class BedrockAgentResolver(ApiGatewayResolver):
"""Bedrock Agent Resolver
See https://aws.amazon.com/bedrock/agents/ for more information.
Examples
--------
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
```python
from aws_lambda_powertools import Tracer
from aws_lambda_powertools.event_handler import BedrockAgentResolver
tracer = Tracer()
app = BedrockAgentResolver()
@app.get("/claims")
def simple_get():
return "You have 3 claims"
@tracer.capture_lambda_handler
def lambda_handler(event, context):
return app.resolve(event, context)
"""

current_event: BedrockAgentEvent

def __init__(self, debug: bool = False, enable_validation: bool = True):
super().__init__(
proxy_type=ProxyEventType.BedrockAgentEvent,
cors=None,
debug=debug,
serializer=None,
strip_prefixes=None,
enable_validation=enable_validation,
)
self._response_builder_class = BedrockResponseBuilder
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper


class BedrockAgentInfo(DictWrapper):
Expand Down Expand Up @@ -47,7 +47,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}


class BedrockAgentEvent(DictWrapper):
class BedrockAgentEvent(BaseProxyEvent):
"""
Bedrock Agent input event
Expand Down Expand Up @@ -97,3 +97,8 @@ def session_attributes(self) -> Dict[str, str]:
@property
def prompt_session_attributes(self) -> Dict[str, str]:
return self["promptSessionAttributes"]

# For compatibility with BaseProxyEvent
@property
def path(self) -> str:
return self["apiPath"]
Loading