diff --git a/aws_lambda_powertools/utilities/data_classes/__init__.py b/aws_lambda_powertools/utilities/data_classes/__init__.py index d245bc35f0d..fd9294bc8bb 100644 --- a/aws_lambda_powertools/utilities/data_classes/__init__.py +++ b/aws_lambda_powertools/utilities/data_classes/__init__.py @@ -6,6 +6,7 @@ from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2 from .appsync_resolver_event import AppSyncResolverEvent from .aws_config_rule_event import AWSConfigRuleEvent +from .bedrock_agent_event import BedrockAgentEvent from .cloud_watch_custom_widget_event import CloudWatchDashboardCustomWidgetEvent from .cloud_watch_logs_event import CloudWatchLogsEvent from .code_pipeline_job_event import CodePipelineJobEvent @@ -35,6 +36,7 @@ "SecretsManagerEvent", "AppSyncResolverEvent", "ALBEvent", + "BedrockAgentEvent", "CloudWatchDashboardCustomWidgetEvent", "CloudWatchLogsEvent", "CodePipelineJobEvent", diff --git a/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py new file mode 100644 index 00000000000..b482b5b2b3e --- /dev/null +++ b/aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py @@ -0,0 +1,99 @@ +from typing import Dict, List, Optional + +from aws_lambda_powertools.utilities.data_classes.common import DictWrapper + + +class BedrockAgentInfo(DictWrapper): + @property + def name(self) -> str: + return self["name"] + + @property + def id(self) -> str: # noqa: A003 + return self["id"] + + @property + def alias(self) -> str: + return self["alias"] + + @property + def version(self) -> str: + return self["version"] + + +class BedrockAgentProperty(DictWrapper): + @property + def name(self) -> str: + return self["name"] + + @property + def type(self) -> str: # noqa: A003 + return self["type"] + + @property + def value(self) -> str: + return self["value"] + + +class BedrockAgentRequestMedia(DictWrapper): + @property + def properties(self) -> List[BedrockAgentProperty]: + return [BedrockAgentProperty(x) for x in self["properties"]] + + +class BedrockAgentRequestBody(DictWrapper): + @property + def content(self) -> Dict[str, BedrockAgentRequestMedia]: + return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()} + + +class BedrockAgentEvent(DictWrapper): + """ + Bedrock Agent input event + + See https://docs.aws.amazon.com/bedrock/latest/userguide/agents-create.html + """ + + @property + def message_version(self) -> str: + return self["messageVersion"] + + @property + def input_text(self) -> str: + return self["inputText"] + + @property + def session_id(self) -> str: + return self["sessionId"] + + @property + def action_group(self) -> str: + return self["actionGroup"] + + @property + def api_path(self) -> str: + return self["apiPath"] + + @property + def http_method(self) -> str: + return self["httpMethod"] + + @property + def parameters(self) -> Optional[List[BedrockAgentProperty]]: + return [BedrockAgentProperty(x) for x in self["parameters"]] if self.get("parameters") else None + + @property + def request_body(self) -> Optional[BedrockAgentRequestBody]: + return BedrockAgentRequestBody(self["requestBody"]) if self.get("requestBody") else None + + @property + def agent(self) -> BedrockAgentInfo: + return BedrockAgentInfo(self["agent"]) + + @property + def session_attributes(self) -> Dict[str, str]: + return self["sessionAttributes"] + + @property + def prompt_session_attributes(self) -> Dict[str, str]: + return self["promptSessionAttributes"] diff --git a/docs/utilities/data_classes.md b/docs/utilities/data_classes.md index 7cc966313fb..37d3725967c 100644 --- a/docs/utilities/data_classes.md +++ b/docs/utilities/data_classes.md @@ -85,6 +85,7 @@ Log Data Event for Troubleshooting | [AppSync Authorizer](#appsync-authorizer) | `AppSyncAuthorizerEvent` | | [AppSync Resolver](#appsync-resolver) | `AppSyncResolverEvent` | | [AWS Config Rule](#aws-config-rule) | `AWSConfigRuleEvent` | +| [Bedrock Agent](#bedrock-agent) | `BedrockAgent` | | [CloudWatch Dashboard Custom Widget](#cloudwatch-dashboard-custom-widget) | `CloudWatchDashboardCustomWidgetEvent` | | [CloudWatch Logs](#cloudwatch-logs) | `CloudWatchLogsEvent` | | [CodePipeline Job Event](#codepipeline-job) | `CodePipelineJobEvent` | @@ -484,6 +485,14 @@ In this example, we also use the new Logger `correlation_id` and built-in `corre --8<-- "examples/event_sources/src/aws_config_rule_scheduled.json" ``` +### Bedrock Agent + +=== "app.py" + + ```python hl_lines="2 8 10" + --8<-- "examples/event_sources/src/bedrock_agent_event.py" + ``` + ### CloudWatch Dashboard Custom Widget === "app.py" diff --git a/examples/event_sources/src/bedrock_agent_event.py b/examples/event_sources/src/bedrock_agent_event.py new file mode 100644 index 00000000000..b16d3c86bad --- /dev/null +++ b/examples/event_sources/src/bedrock_agent_event.py @@ -0,0 +1,25 @@ +from aws_lambda_powertools import Logger +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent, event_source +from aws_lambda_powertools.utilities.typing import LambdaContext + +logger = Logger() + + +@event_source(data_class=BedrockAgentEvent) +def lambda_handler(event: BedrockAgentEvent, context: LambdaContext) -> dict: + input_text = event.input_text + + logger.info(f"Bedrock Agent {event.action_group} invoked with input", input_text=input_text) + + return { + "message_version": "1.0", + "responses": [ + { + "action_group": event.action_group, + "api_path": event.api_path, + "http_method": event.http_method, + "http_status_code": 200, + "response_body": {"application/json": {"body": "This is the response"}}, + }, + ], + } diff --git a/tests/events/bedrockAgentEvent.json b/tests/events/bedrockAgentEvent.json new file mode 100644 index 00000000000..b7ad75b3c43 --- /dev/null +++ b/tests/events/bedrockAgentEvent.json @@ -0,0 +1,16 @@ +{ + "actionGroup": "ClaimManagementActionGroup", + "messageVersion": "1.0", + "sessionId": "12345678912345", + "sessionAttributes": {}, + "promptSessionAttributes": {}, + "inputText": "I want to claim my insurance", + "agent": { + "alias": "TSTALIASID", + "name": "test", + "version": "DRAFT", + "id": "8ZXY0W8P1H" + }, + "httpMethod": "GET", + "apiPath": "/claims" +} diff --git a/tests/events/bedrockAgentPostEvent.json b/tests/events/bedrockAgentPostEvent.json new file mode 100644 index 00000000000..f223bfcd516 --- /dev/null +++ b/tests/events/bedrockAgentPostEvent.json @@ -0,0 +1,35 @@ +{ + "actionGroup": "ClaimManagementActionGroup", + "messageVersion": "1.0", + "sessionId": "12345678912345", + "sessionAttributes": {}, + "promptSessionAttributes": {}, + "inputText": "Send reminders to all pending documents", + "agent": { + "alias": "TSTALIASID", + "name": "test", + "version": "DRAFT", + "id": "8ZXY0W8P1H" + }, + "httpMethod": "POST", + "apiPath": "/send-reminders", + "requestBody": { + "content": { + "application/json": { + "properties": [ + { + "name": "claimId", + "type": "string", + "value": "20" + }, + { + "name": "pendingDocuments", + "type": "string", + "value": "social number and vat" + } + ] + } + } + }, + "parameters": [] +} diff --git a/tests/unit/data_classes/test_bedrock_agent_event.py b/tests/unit/data_classes/test_bedrock_agent_event.py new file mode 100644 index 00000000000..c4b56695774 --- /dev/null +++ b/tests/unit/data_classes/test_bedrock_agent_event.py @@ -0,0 +1,62 @@ +from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent +from tests.functional.utils import load_event + + +def test_bedrock_agent_event(): + raw_event = load_event("bedrockAgentEvent.json") + parsed_event = BedrockAgentEvent(raw_event) + + assert parsed_event.session_id == raw_event["sessionId"] + assert parsed_event.input_text == raw_event["inputText"] + assert parsed_event.message_version == raw_event["messageVersion"] + assert parsed_event.http_method == raw_event["httpMethod"] + assert parsed_event.api_path == raw_event["apiPath"] + assert parsed_event.session_attributes == {} + assert parsed_event.prompt_session_attributes == {} + assert parsed_event.action_group == raw_event["actionGroup"] + + assert parsed_event.request_body is None + + agent = parsed_event.agent + raw_agent = raw_event["agent"] + assert agent.alias == raw_agent["alias"] + assert agent.name == raw_agent["name"] + assert agent.version == raw_agent["version"] + assert agent.id == raw_agent["id"] + + +def test_bedrock_agent_event_with_post(): + raw_event = load_event("bedrockAgentPostEvent.json") + parsed_event = BedrockAgentEvent(raw_event) + + assert parsed_event.session_id == raw_event["sessionId"] + assert parsed_event.input_text == raw_event["inputText"] + assert parsed_event.message_version == raw_event["messageVersion"] + assert parsed_event.http_method == raw_event["httpMethod"] + assert parsed_event.api_path == raw_event["apiPath"] + assert parsed_event.session_attributes == {} + assert parsed_event.prompt_session_attributes == {} + assert parsed_event.action_group == raw_event["actionGroup"] + + agent = parsed_event.agent + raw_agent = raw_event["agent"] + assert agent.alias == raw_agent["alias"] + assert agent.name == raw_agent["name"] + assert agent.version == raw_agent["version"] + assert agent.id == raw_agent["id"] + + request_body = parsed_event.request_body.content + assert "application/json" in request_body + + json_request = request_body["application/json"] + properties = json_request.properties + assert len(properties) == 2 + + raw_properties = raw_event["requestBody"]["content"]["application/json"]["properties"] + assert properties[0].name == raw_properties[0]["name"] + assert properties[0].type == raw_properties[0]["type"] + assert properties[0].value == raw_properties[0]["value"] + + assert properties[1].name == raw_properties[1]["name"] + assert properties[1].type == raw_properties[1]["type"] + assert properties[1].value == raw_properties[1]["value"]