Skip to content

Commit 0964e75

Browse files
feat(data_classes): add support for Bedrock Agents event (aws-powertools#3262)
* feat(data_classes): add support for Bedrock Agents events * fix: add tests * fix: add docs * fix: changed to the new payload schema * fix: example * fix: add more tests * fix: remove response * fix: part 2 * fix: comments * fix: remove response example --------- Signed-off-by: Leandro Damascena <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent c0d5224 commit 0964e75

File tree

7 files changed

+248
-0
lines changed

7 files changed

+248
-0
lines changed

aws_lambda_powertools/utilities/data_classes/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .api_gateway_proxy_event import APIGatewayProxyEvent, APIGatewayProxyEventV2
77
from .appsync_resolver_event import AppSyncResolverEvent
88
from .aws_config_rule_event import AWSConfigRuleEvent
9+
from .bedrock_agent_event import BedrockAgentEvent
910
from .cloud_watch_custom_widget_event import CloudWatchDashboardCustomWidgetEvent
1011
from .cloud_watch_logs_event import CloudWatchLogsEvent
1112
from .code_pipeline_job_event import CodePipelineJobEvent
@@ -35,6 +36,7 @@
3536
"SecretsManagerEvent",
3637
"AppSyncResolverEvent",
3738
"ALBEvent",
39+
"BedrockAgentEvent",
3840
"CloudWatchDashboardCustomWidgetEvent",
3941
"CloudWatchLogsEvent",
4042
"CodePipelineJobEvent",
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
from typing import Dict, List, Optional
2+
3+
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
4+
5+
6+
class BedrockAgentInfo(DictWrapper):
7+
@property
8+
def name(self) -> str:
9+
return self["name"]
10+
11+
@property
12+
def id(self) -> str: # noqa: A003
13+
return self["id"]
14+
15+
@property
16+
def alias(self) -> str:
17+
return self["alias"]
18+
19+
@property
20+
def version(self) -> str:
21+
return self["version"]
22+
23+
24+
class BedrockAgentProperty(DictWrapper):
25+
@property
26+
def name(self) -> str:
27+
return self["name"]
28+
29+
@property
30+
def type(self) -> str: # noqa: A003
31+
return self["type"]
32+
33+
@property
34+
def value(self) -> str:
35+
return self["value"]
36+
37+
38+
class BedrockAgentRequestMedia(DictWrapper):
39+
@property
40+
def properties(self) -> List[BedrockAgentProperty]:
41+
return [BedrockAgentProperty(x) for x in self["properties"]]
42+
43+
44+
class BedrockAgentRequestBody(DictWrapper):
45+
@property
46+
def content(self) -> Dict[str, BedrockAgentRequestMedia]:
47+
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
48+
49+
50+
class BedrockAgentEvent(DictWrapper):
51+
"""
52+
Bedrock Agent input event
53+
54+
See https://docs.aws.amazon.com/bedrock/latest/userguide/agents-create.html
55+
"""
56+
57+
@property
58+
def message_version(self) -> str:
59+
return self["messageVersion"]
60+
61+
@property
62+
def input_text(self) -> str:
63+
return self["inputText"]
64+
65+
@property
66+
def session_id(self) -> str:
67+
return self["sessionId"]
68+
69+
@property
70+
def action_group(self) -> str:
71+
return self["actionGroup"]
72+
73+
@property
74+
def api_path(self) -> str:
75+
return self["apiPath"]
76+
77+
@property
78+
def http_method(self) -> str:
79+
return self["httpMethod"]
80+
81+
@property
82+
def parameters(self) -> Optional[List[BedrockAgentProperty]]:
83+
return [BedrockAgentProperty(x) for x in self["parameters"]] if self.get("parameters") else None
84+
85+
@property
86+
def request_body(self) -> Optional[BedrockAgentRequestBody]:
87+
return BedrockAgentRequestBody(self["requestBody"]) if self.get("requestBody") else None
88+
89+
@property
90+
def agent(self) -> BedrockAgentInfo:
91+
return BedrockAgentInfo(self["agent"])
92+
93+
@property
94+
def session_attributes(self) -> Dict[str, str]:
95+
return self["sessionAttributes"]
96+
97+
@property
98+
def prompt_session_attributes(self) -> Dict[str, str]:
99+
return self["promptSessionAttributes"]

docs/utilities/data_classes.md

+9
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ Log Data Event for Troubleshooting
8585
| [AppSync Authorizer](#appsync-authorizer) | `AppSyncAuthorizerEvent` |
8686
| [AppSync Resolver](#appsync-resolver) | `AppSyncResolverEvent` |
8787
| [AWS Config Rule](#aws-config-rule) | `AWSConfigRuleEvent` |
88+
| [Bedrock Agent](#bedrock-agent) | `BedrockAgent` |
8889
| [CloudWatch Dashboard Custom Widget](#cloudwatch-dashboard-custom-widget) | `CloudWatchDashboardCustomWidgetEvent` |
8990
| [CloudWatch Logs](#cloudwatch-logs) | `CloudWatchLogsEvent` |
9091
| [CodePipeline Job Event](#codepipeline-job) | `CodePipelineJobEvent` |
@@ -484,6 +485,14 @@ In this example, we also use the new Logger `correlation_id` and built-in `corre
484485
--8<-- "examples/event_sources/src/aws_config_rule_scheduled.json"
485486
```
486487

488+
### Bedrock Agent
489+
490+
=== "app.py"
491+
492+
```python hl_lines="2 8 10"
493+
--8<-- "examples/event_sources/src/bedrock_agent_event.py"
494+
```
495+
487496
### CloudWatch Dashboard Custom Widget
488497

489498
=== "app.py"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from aws_lambda_powertools import Logger
2+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent, event_source
3+
from aws_lambda_powertools.utilities.typing import LambdaContext
4+
5+
logger = Logger()
6+
7+
8+
@event_source(data_class=BedrockAgentEvent)
9+
def lambda_handler(event: BedrockAgentEvent, context: LambdaContext) -> dict:
10+
input_text = event.input_text
11+
12+
logger.info(f"Bedrock Agent {event.action_group} invoked with input", input_text=input_text)
13+
14+
return {
15+
"message_version": "1.0",
16+
"responses": [
17+
{
18+
"action_group": event.action_group,
19+
"api_path": event.api_path,
20+
"http_method": event.http_method,
21+
"http_status_code": 200,
22+
"response_body": {"application/json": {"body": "This is the response"}},
23+
},
24+
],
25+
}

tests/events/bedrockAgentEvent.json

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"actionGroup": "ClaimManagementActionGroup",
3+
"messageVersion": "1.0",
4+
"sessionId": "12345678912345",
5+
"sessionAttributes": {},
6+
"promptSessionAttributes": {},
7+
"inputText": "I want to claim my insurance",
8+
"agent": {
9+
"alias": "TSTALIASID",
10+
"name": "test",
11+
"version": "DRAFT",
12+
"id": "8ZXY0W8P1H"
13+
},
14+
"httpMethod": "GET",
15+
"apiPath": "/claims"
16+
}
+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
{
2+
"actionGroup": "ClaimManagementActionGroup",
3+
"messageVersion": "1.0",
4+
"sessionId": "12345678912345",
5+
"sessionAttributes": {},
6+
"promptSessionAttributes": {},
7+
"inputText": "Send reminders to all pending documents",
8+
"agent": {
9+
"alias": "TSTALIASID",
10+
"name": "test",
11+
"version": "DRAFT",
12+
"id": "8ZXY0W8P1H"
13+
},
14+
"httpMethod": "POST",
15+
"apiPath": "/send-reminders",
16+
"requestBody": {
17+
"content": {
18+
"application/json": {
19+
"properties": [
20+
{
21+
"name": "claimId",
22+
"type": "string",
23+
"value": "20"
24+
},
25+
{
26+
"name": "pendingDocuments",
27+
"type": "string",
28+
"value": "social number and vat"
29+
}
30+
]
31+
}
32+
}
33+
},
34+
"parameters": []
35+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
2+
from tests.functional.utils import load_event
3+
4+
5+
def test_bedrock_agent_event():
6+
raw_event = load_event("bedrockAgentEvent.json")
7+
parsed_event = BedrockAgentEvent(raw_event)
8+
9+
assert parsed_event.session_id == raw_event["sessionId"]
10+
assert parsed_event.input_text == raw_event["inputText"]
11+
assert parsed_event.message_version == raw_event["messageVersion"]
12+
assert parsed_event.http_method == raw_event["httpMethod"]
13+
assert parsed_event.api_path == raw_event["apiPath"]
14+
assert parsed_event.session_attributes == {}
15+
assert parsed_event.prompt_session_attributes == {}
16+
assert parsed_event.action_group == raw_event["actionGroup"]
17+
18+
assert parsed_event.request_body is None
19+
20+
agent = parsed_event.agent
21+
raw_agent = raw_event["agent"]
22+
assert agent.alias == raw_agent["alias"]
23+
assert agent.name == raw_agent["name"]
24+
assert agent.version == raw_agent["version"]
25+
assert agent.id == raw_agent["id"]
26+
27+
28+
def test_bedrock_agent_event_with_post():
29+
raw_event = load_event("bedrockAgentPostEvent.json")
30+
parsed_event = BedrockAgentEvent(raw_event)
31+
32+
assert parsed_event.session_id == raw_event["sessionId"]
33+
assert parsed_event.input_text == raw_event["inputText"]
34+
assert parsed_event.message_version == raw_event["messageVersion"]
35+
assert parsed_event.http_method == raw_event["httpMethod"]
36+
assert parsed_event.api_path == raw_event["apiPath"]
37+
assert parsed_event.session_attributes == {}
38+
assert parsed_event.prompt_session_attributes == {}
39+
assert parsed_event.action_group == raw_event["actionGroup"]
40+
41+
agent = parsed_event.agent
42+
raw_agent = raw_event["agent"]
43+
assert agent.alias == raw_agent["alias"]
44+
assert agent.name == raw_agent["name"]
45+
assert agent.version == raw_agent["version"]
46+
assert agent.id == raw_agent["id"]
47+
48+
request_body = parsed_event.request_body.content
49+
assert "application/json" in request_body
50+
51+
json_request = request_body["application/json"]
52+
properties = json_request.properties
53+
assert len(properties) == 2
54+
55+
raw_properties = raw_event["requestBody"]["content"]["application/json"]["properties"]
56+
assert properties[0].name == raw_properties[0]["name"]
57+
assert properties[0].type == raw_properties[0]["type"]
58+
assert properties[0].value == raw_properties[0]["value"]
59+
60+
assert properties[1].name == raw_properties[1]["name"]
61+
assert properties[1].type == raw_properties[1]["type"]
62+
assert properties[1].value == raw_properties[1]["value"]

0 commit comments

Comments
 (0)