Skip to content

feat(data_classes): add support for Bedrock Agents event #3262

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Nov 2, 2023
2 changes: 2 additions & 0 deletions aws_lambda_powertools/utilities/data_classes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2
from .appsync_resolver_event import AppSyncResolverEvent
from .aws_config_rule_event import AWSConfigRuleEvent
from .bedrock_agent_event import BedrockAgentEvent
from .cloud_watch_custom_widget_event import CloudWatchDashboardCustomWidgetEvent
from .cloud_watch_logs_event import CloudWatchLogsEvent
from .code_pipeline_job_event import CodePipelineJobEvent
Expand Down Expand Up @@ -35,6 +36,7 @@
"SecretsManagerEvent",
"AppSyncResolverEvent",
"ALBEvent",
"BedrockAgentEvent",
"CloudWatchDashboardCustomWidgetEvent",
"CloudWatchLogsEvent",
"CodePipelineJobEvent",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
from typing import Dict, List, Optional

from aws_lambda_powertools.utilities.data_classes.common import DictWrapper


class BedrockAgentInfo(DictWrapper):
@property
def name(self) -> str:
return self["name"]

@property
def id(self) -> str: # noqa: A003
return self["id"]

@property
def alias(self) -> str:
return self["alias"]

@property
def version(self) -> str:
return self["version"]


class BedrockAgentProperty(DictWrapper):
@property
def name(self) -> str:
return self["name"]

@property
def type(self) -> str: # noqa: A003
return self["type"]

@property
def value(self) -> str:
return self["value"]


class BedrockAgentRequestMedia(DictWrapper):
@property
def properties(self) -> List[BedrockAgentProperty]:
return [BedrockAgentProperty(x) for x in self["properties"]]


class BedrockAgentRequestBody(DictWrapper):
@property
def content(self) -> Dict[str, BedrockAgentRequestMedia]:
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}


class BedrockAgentEvent(DictWrapper):
"""
Bedrock Agent input event

See https://docs.aws.amazon.com/bedrock/latest/userguide/agents-create.html
"""

@property
def message_version(self) -> str:
return self["messageVersion"]

@property
def input_text(self) -> str:
return self["inputText"]

@property
def session_id(self) -> str:
return self["sessionId"]

@property
def action_group(self) -> str:
return self["actionGroup"]

@property
def api_path(self) -> str:
return self["apiPath"]

@property
def http_method(self) -> str:
return self["httpMethod"]

@property
def parameters(self) -> Optional[List[BedrockAgentProperty]]:
return [BedrockAgentProperty(x) for x in self["parameters"]] if self.get("parameters") else None

@property
def request_body(self) -> Optional[BedrockAgentRequestBody]:
return BedrockAgentRequestBody(self["requestBody"]) if self.get("requestBody") else None

@property
def agent(self) -> BedrockAgentInfo:
return BedrockAgentInfo(self["agent"])

@property
def session_attributes(self) -> Dict[str, str]:
return self["sessionAttributes"]

@property
def prompt_session_attributes(self) -> Dict[str, str]:
return self["promptSessionAttributes"]
9 changes: 9 additions & 0 deletions docs/utilities/data_classes.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ Log Data Event for Troubleshooting
| [AppSync Authorizer](#appsync-authorizer) | `AppSyncAuthorizerEvent` |
| [AppSync Resolver](#appsync-resolver) | `AppSyncResolverEvent` |
| [AWS Config Rule](#aws-config-rule) | `AWSConfigRuleEvent` |
| [Bedrock Agent](#bedrock-agent) | `BedrockAgent` |
| [CloudWatch Dashboard Custom Widget](#cloudwatch-dashboard-custom-widget) | `CloudWatchDashboardCustomWidgetEvent` |
| [CloudWatch Logs](#cloudwatch-logs) | `CloudWatchLogsEvent` |
| [CodePipeline Job Event](#codepipeline-job) | `CodePipelineJobEvent` |
Expand Down Expand Up @@ -484,6 +485,14 @@ In this example, we also use the new Logger `correlation_id` and built-in `corre
--8<-- "examples/event_sources/src/aws_config_rule_scheduled.json"
```

### Bedrock Agent

=== "app.py"

```python hl_lines="2 8 10"
--8<-- "examples/event_sources/src/bedrock_agent_event.py"
```

### CloudWatch Dashboard Custom Widget

=== "app.py"
Expand Down
25 changes: 25 additions & 0 deletions examples/event_sources/src/bedrock_agent_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from aws_lambda_powertools import Logger
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent, event_source
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()


@event_source(data_class=BedrockAgentEvent)
def lambda_handler(event: BedrockAgentEvent, context: LambdaContext) -> dict:
input_text = event.input_text

logger.info(f"Bedrock Agent {event.action_group} invoked with input", input_text=input_text)

return {
"message_version": "1.0",
"responses": [
{
"action_group": event.action_group,
"api_path": event.api_path,
"http_method": event.http_method,
"http_status_code": 200,
"response_body": {"application/json": {"body": "This is the response"}},
},
],
}
16 changes: 16 additions & 0 deletions tests/events/bedrockAgentEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"actionGroup": "ClaimManagementActionGroup",
"messageVersion": "1.0",
"sessionId": "12345678912345",
"sessionAttributes": {},
"promptSessionAttributes": {},
"inputText": "I want to claim my insurance",
"agent": {
"alias": "TSTALIASID",
"name": "test",
"version": "DRAFT",
"id": "8ZXY0W8P1H"
},
"httpMethod": "GET",
"apiPath": "/claims"
}
35 changes: 35 additions & 0 deletions tests/events/bedrockAgentPostEvent.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
{
"actionGroup": "ClaimManagementActionGroup",
"messageVersion": "1.0",
"sessionId": "12345678912345",
"sessionAttributes": {},
"promptSessionAttributes": {},
"inputText": "Send reminders to all pending documents",
"agent": {
"alias": "TSTALIASID",
"name": "test",
"version": "DRAFT",
"id": "8ZXY0W8P1H"
},
"httpMethod": "POST",
"apiPath": "/send-reminders",
"requestBody": {
"content": {
"application/json": {
"properties": [
{
"name": "claimId",
"type": "string",
"value": "20"
},
{
"name": "pendingDocuments",
"type": "string",
"value": "social number and vat"
}
]
}
}
},
"parameters": []
}
62 changes: 62 additions & 0 deletions tests/unit/data_classes/test_bedrock_agent_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
from tests.functional.utils import load_event


def test_bedrock_agent_event():
raw_event = load_event("bedrockAgentEvent.json")
parsed_event = BedrockAgentEvent(raw_event)

assert parsed_event.session_id == raw_event["sessionId"]
assert parsed_event.input_text == raw_event["inputText"]
assert parsed_event.message_version == raw_event["messageVersion"]
assert parsed_event.http_method == raw_event["httpMethod"]
assert parsed_event.api_path == raw_event["apiPath"]
assert parsed_event.session_attributes == {}
assert parsed_event.prompt_session_attributes == {}
assert parsed_event.action_group == raw_event["actionGroup"]

assert parsed_event.request_body is None

agent = parsed_event.agent
raw_agent = raw_event["agent"]
assert agent.alias == raw_agent["alias"]
assert agent.name == raw_agent["name"]
assert agent.version == raw_agent["version"]
assert agent.id == raw_agent["id"]


def test_bedrock_agent_event_with_post():
raw_event = load_event("bedrockAgentPostEvent.json")
parsed_event = BedrockAgentEvent(raw_event)

assert parsed_event.session_id == raw_event["sessionId"]
assert parsed_event.input_text == raw_event["inputText"]
assert parsed_event.message_version == raw_event["messageVersion"]
assert parsed_event.http_method == raw_event["httpMethod"]
assert parsed_event.api_path == raw_event["apiPath"]
assert parsed_event.session_attributes == {}
assert parsed_event.prompt_session_attributes == {}
assert parsed_event.action_group == raw_event["actionGroup"]

agent = parsed_event.agent
raw_agent = raw_event["agent"]
assert agent.alias == raw_agent["alias"]
assert agent.name == raw_agent["name"]
assert agent.version == raw_agent["version"]
assert agent.id == raw_agent["id"]

request_body = parsed_event.request_body.content
assert "application/json" in request_body

json_request = request_body["application/json"]
properties = json_request.properties
assert len(properties) == 2

raw_properties = raw_event["requestBody"]["content"]["application/json"]["properties"]
assert properties[0].name == raw_properties[0]["name"]
assert properties[0].type == raw_properties[0]["type"]
assert properties[0].value == raw_properties[0]["value"]

assert properties[1].name == raw_properties[1]["name"]
assert properties[1].type == raw_properties[1]["type"]
assert properties[1].value == raw_properties[1]["value"]