Skip to content

Commit a7eb9e2

Browse files
committed
fix: changed to the new payload schema
1 parent 2eacca1 commit a7eb9e2

File tree

4 files changed

+51
-81
lines changed

4 files changed

+51
-81
lines changed

aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+24-26
Original file line numberDiff line numberDiff line change
@@ -47,28 +47,6 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
4747
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
4848

4949

50-
class BedrockAgentActionGroup(DictWrapper):
51-
@property
52-
def action_group(self) -> str:
53-
return self["actionGroup"]
54-
55-
@property
56-
def api_path(self) -> str:
57-
return self["apiPath"]
58-
59-
@property
60-
def http_method(self) -> str:
61-
return self["httpMethod"]
62-
63-
@property
64-
def parameters(self) -> List[BedrockAgentProperty]:
65-
return [BedrockAgentProperty(x) for x in self["parameters"]]
66-
67-
@property
68-
def request_body(self) -> BedrockAgentRequestBody:
69-
return BedrockAgentRequestBody(self["requestBody"])
70-
71-
7250
class BedrockAgentEvent(DictWrapper):
7351
"""
7452
Bedrock Agent input event
@@ -89,8 +67,24 @@ def session_id(self) -> str:
8967
return self["sessionId"]
9068

9169
@property
92-
def action_groups(self) -> List[BedrockAgentActionGroup]:
93-
return [BedrockAgentActionGroup(x) for x in self["actionGroups"]]
70+
def action_group(self) -> str:
71+
return self["actionGroup"]
72+
73+
@property
74+
def api_path(self) -> str:
75+
return self["apiPath"]
76+
77+
@property
78+
def http_method(self) -> str:
79+
return self["httpMethod"]
80+
81+
@property
82+
def parameters(self) -> List[BedrockAgentProperty]:
83+
return [BedrockAgentProperty(x) for x in self["parameters"]]
84+
85+
@property
86+
def request_body(self) -> BedrockAgentRequestBody:
87+
return BedrockAgentRequestBody(self["requestBody"])
9488

9589
@property
9690
def agent(self) -> BedrockAgentInfo:
@@ -100,6 +94,10 @@ def agent(self) -> BedrockAgentInfo:
10094
def session_attributes(self) -> Dict[str, str]:
10195
return self["sessionAttributes"]
10296

97+
@property
98+
def prompt_session_attributes(self) -> Dict[str, str]:
99+
return self["promptSessionAttributes"]
100+
103101

104102
class BedrockAgentResponseMedia(DictWrapper):
105103
@property
@@ -141,5 +139,5 @@ def message_version(self) -> str:
141139
return self["messageVersion"]
142140

143141
@property
144-
def responses(self) -> List[BedrockAgentResponsePayload]:
145-
return [BedrockAgentResponsePayload(x) for x in self["response"]]
142+
def response(self) -> BedrockAgentResponsePayload:
143+
return BedrockAgentResponsePayload(self["response"])

tests/events/bedrockAgentEvent.json

+7-17
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,16 @@
11
{
2+
"actionGroup": "ClaimManagementActionGroup",
3+
"messageVersion": "1.0",
24
"sessionId": "12345678912345",
5+
"sessionAttributes": {},
6+
"promptSessionAttributes": {},
7+
"inputText": "I want to claim my insurance",
38
"agent": {
49
"alias": "TSTALIASID",
510
"name": "test",
611
"version": "DRAFT",
712
"id": "8ZXY0W8P1H"
813
},
9-
"inputText": "ABC12345",
10-
"sessionAttributes": {},
11-
"messageVersion": "1.0",
12-
"actionGroups": [
13-
{
14-
"httpMethod": "GET",
15-
"apiPath": "/claims/{claimId}/identify-missing-documents",
16-
"actionGroup": "ClaimManagementActionGroup",
17-
"parameters": [
18-
{
19-
"name": "claimId",
20-
"type": "string",
21-
"value": "ABC12345"
22-
}
23-
]
24-
}
25-
]
14+
"httpMethod": "GET",
15+
"apiPath": "/claims"
2616
}
+9-11
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,14 @@
11
{
22
"messageVersion": "1.0",
3-
"response": [
4-
{
5-
"actionGroup": "ClaimManagementActionGroup",
6-
"apiPath": "/claims/{claimId}/identify-missing-documents",
7-
"httpMethod": "GET",
8-
"httpStatusCode": 200,
9-
"responseBody": {
10-
"application/json": {
11-
"body": "This is the response"
12-
}
3+
"response": {
4+
"actionGroup": "ClaimManagementActionGroup",
5+
"apiPath": "/claims/{claimId}/identify-missing-documents",
6+
"httpMethod": "GET",
7+
"httpStatusCode": 200,
8+
"responseBody": {
9+
"application/json": {
10+
"body": "This is the response"
1311
}
1412
}
15-
]
13+
}
1614
}

tests/unit/data_classes/test_bedrock_agent_event.py

+11-27
Original file line numberDiff line numberDiff line change
@@ -7,49 +7,33 @@ def test_bedrock_agent_event():
77
parsed_event = BedrockAgentEvent(raw_event)
88

99
assert parsed_event.session_id == "12345678912345"
10-
assert parsed_event.input_text == "ABC12345"
11-
assert parsed_event.session_attributes == {}
10+
assert parsed_event.input_text == "I want to claim my insurance"
1211
assert parsed_event.message_version == "1.0"
12+
assert parsed_event.http_method == "GET"
13+
assert parsed_event.api_path == "/claims"
14+
assert parsed_event.session_attributes == {}
15+
assert parsed_event.prompt_session_attributes == {}
16+
assert parsed_event.action_group == "ClaimManagementActionGroup"
1317

1418
agent = parsed_event.agent
1519
assert agent.alias == "TSTALIASID"
1620
assert agent.name == "test"
1721
assert agent.version == "DRAFT"
1822
assert agent.id == "8ZXY0W8P1H"
1923

20-
action_groups = parsed_event.action_groups
21-
assert len(action_groups) == 1
22-
23-
action_group = action_groups[0]
24-
assert action_group.http_method == "GET"
25-
assert action_group.api_path == "/claims/{claimId}/identify-missing-documents"
26-
assert action_group.action_group == "ClaimManagementActionGroup"
27-
28-
parameters = action_group.parameters
29-
assert len(parameters) == 1
30-
31-
parameter = parameters[0]
32-
assert parameter.name == "claimId"
33-
assert parameter.type == "string"
34-
assert parameter.value == "ABC12345"
35-
3624

3725
def test_bedrock_agent_response_event():
3826
raw_event = load_event("bedrockAgentResponseEvent.json")
3927
parsed_event = BedrockAgentResponseEvent(raw_event)
4028

4129
assert parsed_event.message_version == "1.0"
4230

43-
responses = parsed_event.responses
44-
assert len(responses) == 1
45-
46-
response = responses[0]
47-
assert response.action_group == "ClaimManagementActionGroup"
48-
assert response.api_path == "/claims/{claimId}/identify-missing-documents"
49-
assert response.http_method == "GET"
50-
assert response.http_status_code == 200
31+
assert parsed_event.response.action_group == "ClaimManagementActionGroup"
32+
assert parsed_event.response.api_path == "/claims/{claimId}/identify-missing-documents"
33+
assert parsed_event.response.http_method == "GET"
34+
assert parsed_event.response.http_status_code == 200
5135

52-
response_body = response.response_body
36+
response_body = parsed_event.response.response_body
5337
assert "application/json" in response_body
5438

5539
json_response = response_body["application/json"]

0 commit comments

Comments
 (0)