diff --git a/aws_lambda_powertools/event_handler/__init__.py b/aws_lambda_powertools/event_handler/__init__.py index 7bdd9a97f72..ffbb2abe4ae 100644 --- a/aws_lambda_powertools/event_handler/__init__.py +++ b/aws_lambda_powertools/event_handler/__init__.py @@ -11,6 +11,7 @@ Response, ) from aws_lambda_powertools.event_handler.appsync import AppSyncResolver +from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver from aws_lambda_powertools.event_handler.lambda_function_url import ( LambdaFunctionUrlResolver, ) @@ -22,6 +23,7 @@ "APIGatewayHttpResolver", "ALBResolver", "ApiGatewayResolver", + "BedrockAgentResolver", "CORSConfig", "LambdaFunctionUrlResolver", "Response", diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 0ddf287f264..1e494fd1c0f 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -15,6 +15,7 @@ Any, Callable, Dict, + Generic, List, Match, Optional, @@ -23,6 +24,7 @@ Set, Tuple, Type, + TypeVar, Union, cast, ) @@ -45,6 +47,7 @@ ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2, + BedrockAgentEvent, LambdaFunctionUrlEvent, VPCLatticeEvent, VPCLatticeEventV2, @@ -62,6 +65,8 @@ _DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response" _ROUTE_REGEX = "^{}$" +ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent) + if TYPE_CHECKING: from aws_lambda_powertools.event_handler.openapi.compat import ( JsonSchemaValue, @@ -85,6 +90,7 @@ class ProxyEventType(Enum): APIGatewayProxyEvent = "APIGatewayProxyEvent" APIGatewayProxyEventV2 = "APIGatewayProxyEventV2" ALBEvent = "ALBEvent" + BedrockAgentEvent = "BedrockAgentEvent" VPCLatticeEvent = "VPCLatticeEvent" VPCLatticeEventV2 = "VPCLatticeEventV2" LambdaFunctionUrlEvent = "LambdaFunctionUrlEvent" @@ -208,7 +214,7 @@ def __init__( self, status_code: int, content_type: Optional[str] = None, - body: Union[str, bytes, None] = None, + body: Any = None, headers: Optional[Dict[str, Union[str, List[str]]]] = None, cookies: Optional[List[Cookie]] = None, compress: Optional[bool] = None, @@ -235,6 +241,7 @@ def __init__( self.headers: Dict[str, Union[str, List[str]]] = headers if headers else {} self.cookies = cookies or [] self.compress = compress + self.content_type = content_type if content_type: self.headers.setdefault("Content-Type", content_type) @@ -689,14 +696,14 @@ def _generate_operation_id(self) -> str: return operation_id -class ResponseBuilder: +class ResponseBuilder(Generic[ResponseEventT]): """Internally used Response builder""" def __init__(self, response: Response, route: Optional[Route] = None): self.response = response self.route = route - def _add_cors(self, event: BaseProxyEvent, cors: CORSConfig): + def _add_cors(self, event: ResponseEventT, cors: CORSConfig): """Update headers to include the configured Access-Control headers""" self.response.headers.update(cors.to_dict(event.get_header_value("Origin"))) @@ -709,7 +716,7 @@ def _add_cache_control(self, cache_control: str): def _has_compression_enabled( route_compression: bool, response_compression: Optional[bool], - event: BaseProxyEvent, + event: ResponseEventT, ) -> bool: """ Checks if compression is enabled. @@ -722,7 +729,7 @@ def _has_compression_enabled( A boolean indicating whether compression is enabled or not in the route setting. response_compression: bool, optional A boolean indicating whether compression is enabled or not in the response setting. - event: BaseProxyEvent + event: ResponseEventT The event object containing the request details. Returns @@ -752,7 +759,7 @@ def _compress(self): gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) self.response.body = gzip.compress(self.response.body) + gzip.flush() - def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): + def _route(self, event: ResponseEventT, cors: Optional[CORSConfig]): """Optionally handle any of the route's configure response handling""" if self.route is None: return @@ -767,7 +774,7 @@ def _route(self, event: BaseProxyEvent, cors: Optional[CORSConfig]): ): self._compress() - def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: + def build(self, event: ResponseEventT, cors: Optional[CORSConfig] = None) -> Dict[str, Any]: """Build the full response dict to be returned by the lambda""" self._route(event, cors) @@ -1315,6 +1322,7 @@ def __init__( self._strip_prefixes = strip_prefixes self.context: Dict = {} # early init as customers might add context before event resolution self.processed_stack_frames = [] + self._response_builder_class = ResponseBuilder[BaseProxyEvent] # Allow for a custom serializer or a concise json serialization self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder) @@ -1784,7 +1792,7 @@ def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX): rule_regex: str = re.sub(_DYNAMIC_ROUTE_PATTERN, _NAMED_GROUP_BOUNDARY_PATTERN, rule) return re.compile(base_regex.format(rule_regex)) - def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: + def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: # noqa: PLR0911 # ignore many returns """Convert the event dict to the corresponding data class""" if self._proxy_type == ProxyEventType.APIGatewayProxyEvent: logger.debug("Converting event to API Gateway REST API contract") @@ -1792,6 +1800,9 @@ def _to_proxy_event(self, event: Dict) -> BaseProxyEvent: if self._proxy_type == ProxyEventType.APIGatewayProxyEventV2: logger.debug("Converting event to API Gateway HTTP API contract") return APIGatewayProxyEventV2(event) + if self._proxy_type == ProxyEventType.BedrockAgentEvent: + logger.debug("Converting event to Bedrock Agent contract") + return BedrockAgentEvent(event) if self._proxy_type == ProxyEventType.LambdaFunctionUrlEvent: logger.debug("Converting event to Lambda Function URL contract") return LambdaFunctionUrlEvent(event) @@ -1869,9 +1880,9 @@ def _not_found(self, method: str) -> ResponseBuilder: handler = self._lookup_exception_handler(NotFoundError) if handler: - return ResponseBuilder(handler(NotFoundError())) + return self._response_builder_class(handler(NotFoundError())) - return ResponseBuilder( + return self._response_builder_class( Response( status_code=HTTPStatus.NOT_FOUND.value, content_type=content_types.APPLICATION_JSON, @@ -1886,7 +1897,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # Reset Processed stack for Middleware (for debugging purposes) self._reset_processed_stack() - return ResponseBuilder( + return self._response_builder_class( self._to_response( route(router_middlewares=self._router_middlewares, app=self, route_arguments=route_arguments), ), @@ -1903,7 +1914,7 @@ def _call_route(self, route: Route, route_arguments: Dict[str, str]) -> Response # If the user has turned on debug mode, # we'll let the original exception propagate, so # they get more information about what went wrong. - return ResponseBuilder( + return self._response_builder_class( Response( status_code=500, content_type=content_types.TEXT_PLAIN, @@ -1942,12 +1953,12 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> Optional[Resp handler = self._lookup_exception_handler(type(exp)) if handler: try: - return ResponseBuilder(handler(exp), route) + return self._response_builder_class(handler(exp), route) except ServiceError as service_error: exp = service_error if isinstance(exp, ServiceError): - return ResponseBuilder( + return self._response_builder_class( Response( status_code=exp.status_code, content_type=content_types.APPLICATION_JSON, diff --git a/aws_lambda_powertools/event_handler/bedrock_agent.py b/aws_lambda_powertools/event_handler/bedrock_agent.py new file mode 100644 index 00000000000..258fc7dcaee --- /dev/null +++ b/aws_lambda_powertools/event_handler/bedrock_agent.py @@ -0,0 +1,77 @@ +from typing import Any, Dict + +from typing_extensions import override + +from aws_lambda_powertools.event_handler import ApiGatewayResolver +from aws_lambda_powertools.event_handler.api_gateway import ( + ProxyEventType, + ResponseBuilder, +) +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent + + +class BedrockResponseBuilder(ResponseBuilder): + """ + Bedrock Response Builder. This builds the response dict to be returned by Lambda when using Bedrock Agents. + + Since the payload format is different from the standard API Gateway Proxy event, we override the build method. + """ + + @override + def build(self, event: BedrockAgentEvent, *args) -> Dict[str, Any]: + """Build the full response dict to be returned by the lambda""" + self._route(event, None) + + return { + "messageVersion": "1.0", + "response": { + "actionGroup": event.action_group, + "apiPath": event.api_path, + "httpMethod": event.http_method, + "httpStatusCode": self.response.status_code, + "responseBody": { + self.response.content_type: { + "body": self.response.body, + }, + }, + }, + } + + +class BedrockAgentResolver(ApiGatewayResolver): + """Bedrock Agent Resolver + + See https://aws.amazon.com/bedrock/agents/ for more information. + + Examples + -------- + Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator + + ```python + from aws_lambda_powertools import Tracer + from aws_lambda_powertools.event_handler import BedrockAgentResolver + + tracer = Tracer() + app = BedrockAgentResolver() + + @app.get("/claims") + def simple_get(): + return "You have 3 claims" + + @tracer.capture_lambda_handler + def lambda_handler(event, context): + return app.resolve(event, context) + """ + + current_event: BedrockAgentEvent + + def __init__(self, debug: bool = False, enable_validation: bool = True): + super().__init__( + proxy_type=ProxyEventType.BedrockAgentEvent, + cors=None, + debug=debug, + serializer=None, + strip_prefixes=None, + enable_validation=enable_validation, + ) + self._response_builder_class = BedrockResponseBuilder diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py index b482b5b2b3e..1577ad62895 100644 --- a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from aws_lambda_powertools.utilities.data_classes.common import DictWrapper +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent, DictWrapper class BedrockAgentInfo(DictWrapper): @@ -47,7 +47,7 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]: return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()} -class BedrockAgentEvent(DictWrapper): +class BedrockAgentEvent(BaseProxyEvent): """ Bedrock Agent input event @@ -97,3 +97,8 @@ def session_attributes(self) -> Dict[str, str]: @property def prompt_session_attributes(self) -> Dict[str, str]: return self["promptSessionAttributes"] + + # For compatibility with BaseProxyEvent + @property + def path(self) -> str: + return self["apiPath"] diff --git a/tests/functional/event_handler/test_bedrock_agent.py b/tests/functional/event_handler/test_bedrock_agent.py new file mode 100644 index 00000000000..dcdca460d25 --- /dev/null +++ b/tests/functional/event_handler/test_bedrock_agent.py @@ -0,0 +1,137 @@ +import json +from typing import Any, Dict + +from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types +from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2 +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent +from tests.functional.utils import load_event + +claims_response = "You have 3 claims" + + +def test_bedrock_agent_event(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims") + def claims() -> Dict[str, Any]: + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return {"output": claims_response} + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 200 + + body = result["response"]["responseBody"]["application/json"]["body"] + assert body == json.dumps({"output": claims_response}) + + +def test_bedrock_agent_event_with_response(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + output = {"output": claims_response} + + @app.get("/claims") + def claims(): + assert isinstance(app.current_event, BedrockAgentEvent) + assert app.lambda_context == {} + return Response(200, content_types.APPLICATION_JSON, output) + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 200 + + body = result["response"]["responseBody"]["application/json"]["body"] + assert body == json.dumps(output) + + +def test_bedrock_agent_event_with_no_matches(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/no_match") + def claims(): + raise RuntimeError() + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND return 404 because the event doesn't match any known rule + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 404 + + +def test_bedrock_agent_event_with_validation_error(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.get("/claims") + def claims() -> Dict[str, Any]: + return "oh no, this is not a dict" # type: ignore + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as BedrockAgentEvent + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 422 + + body = result["response"]["responseBody"]["application/json"]["body"] + if PYDANTIC_V2: + assert "should be a valid dictionary" in body + else: + assert "value is not a valid dict" in body + + +def test_bedrock_agent_event_with_exception(): + # GIVEN a Bedrock Agent event + app = BedrockAgentResolver() + + @app.exception_handler(RuntimeError) + def handle_runtime_error(ex: RuntimeError): + return Response( + status_code=500, + content_type=content_types.TEXT_PLAIN, + body="Something went wrong", + ) + + @app.get("/claims") + def claims(): + raise RuntimeError() + + # WHEN calling the event handler + result = app(load_event("bedrockAgentEvent.json"), {}) + + # THEN process the exception correctly + # AND return 500 because of the internal server error + assert result["messageVersion"] == "1.0" + assert result["response"]["apiPath"] == "/claims" + assert result["response"]["actionGroup"] == "ClaimManagementActionGroup" + assert result["response"]["httpMethod"] == "GET" + assert result["response"]["httpStatusCode"] == 500 + + body = result["response"]["responseBody"]["text/plain"]["body"] + assert body == "Something went wrong"