Skip to content

Commit 2b4cf35

Browse files
committed
fix: add tests
1 parent 030aed6 commit 2b4cf35

File tree

4 files changed

+138
-14
lines changed

4 files changed

+138
-14
lines changed

Diff for: aws_lambda_powertools/utilities/data_classes/bedrock_agent_event.py

+40-14
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,35 @@ def content(self) -> Dict[str, BedrockAgentRequestMedia]:
4747
return {k: BedrockAgentRequestMedia(v) for k, v in self["content"].items()}
4848

4949

50+
class BedrockAgentActionGroup(DictWrapper):
51+
@property
52+
def action_group(self) -> str:
53+
return self["actionGroup"]
54+
55+
@property
56+
def api_path(self) -> str:
57+
return self["apiPath"]
58+
59+
@property
60+
def http_method(self) -> str:
61+
return self["httpMethod"]
62+
63+
@property
64+
def parameters(self) -> List[BedrockAgentProperty]:
65+
return [BedrockAgentProperty(x) for x in self["parameters"]]
66+
67+
@property
68+
def request_body(self) -> BedrockAgentRequestBody:
69+
return BedrockAgentRequestBody(self["requestBody"])
70+
71+
5072
class BedrockAgentEvent(DictWrapper):
73+
"""
74+
Bedrock Agent input event
75+
76+
See https://docs.aws.amazon.com/bedrock/latest/userguide/agents-create.html
77+
"""
78+
5179
@property
5280
def message_version(self) -> str:
5381
return self["messageVersion"]
@@ -61,20 +89,12 @@ def session_id(self) -> str:
6189
return self["sessionId"]
6290

6391
@property
64-
def action_group(self) -> str:
65-
return self["actionGroup"]
92+
def action_groups(self) -> List[BedrockAgentActionGroup]:
93+
return [BedrockAgentActionGroup(x) for x in self["actionGroups"]]
6694

6795
@property
68-
def api_path(self) -> str:
69-
return self["apiPath"]
70-
71-
@property
72-
def http_method(self) -> str:
73-
return self["httpMethod"]
74-
75-
@property
76-
def parameters(self) -> List[BedrockAgentProperty]:
77-
return [BedrockAgentProperty(x) for x in self["parameters"]]
96+
def agent(self) -> BedrockAgentInfo:
97+
return BedrockAgentInfo(self["agent"])
7898

7999
@property
80100
def session_attributes(self) -> Dict[str, str]:
@@ -110,10 +130,16 @@ def response_body(self) -> Dict[str, BedrockAgentResponseMedia]:
110130

111131

112132
class BedrockAgentResponseEvent(DictWrapper):
133+
"""
134+
Bedrock Agent output event
135+
136+
See: https://docs.aws.amazon.com/bedrock/latest/userguide/agents-create.html
137+
"""
138+
113139
@property
114140
def message_version(self) -> str:
115141
return self["messageVersion"]
116142

117143
@property
118-
def response(self) -> BedrockAgentResponsePayload:
119-
return BedrockAgentResponsePayload(self["response"])
144+
def responses(self) -> List[BedrockAgentResponsePayload]:
145+
return [BedrockAgentResponsePayload(x) for x in self["response"]]

Diff for: tests/events/bedrockAgentEvent.json

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
{
2+
"sessionId": "12345678912345",
3+
"agent": {
4+
"alias": "TSTALIASID",
5+
"name": "test",
6+
"version": "DRAFT",
7+
"id": "8ZXY0W8P1H"
8+
},
9+
"inputText": "ABC12345",
10+
"sessionAttributes": {},
11+
"messageVersion": "1.0",
12+
"actionGroups": [
13+
{
14+
"httpMethod": "GET",
15+
"apiPath": "/claims/{claimId}/identify-missing-documents",
16+
"actionGroup": "ClaimManagementActionGroup",
17+
"parameters": [
18+
{
19+
"name": "claimId",
20+
"type": "string",
21+
"value": "ABC12345"
22+
}
23+
]
24+
}
25+
]
26+
}

Diff for: tests/events/bedrockAgentResponseEvent.json

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
{
2+
"messageVersion": "1.0",
3+
"response": [
4+
{
5+
"actionGroup": "ClaimManagementActionGroup",
6+
"apiPath": "/claims/{claimId}/identify-missing-documents",
7+
"httpMethod": "GET",
8+
"httpStatusCode": 200,
9+
"responseBody": {
10+
"application/json": {
11+
"body": "This is the response"
12+
}
13+
}
14+
}
15+
]
16+
}

Diff for: tests/unit/data_classes/test_bedrock_agent_event.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent, BedrockAgentResponseEvent
2+
from tests.functional.utils import load_event
3+
4+
5+
def test_bedrock_agent_event():
6+
raw_event = load_event("bedrockAgentEvent.json")
7+
parsed_event = BedrockAgentEvent(raw_event)
8+
9+
assert parsed_event.session_id == "12345678912345"
10+
assert parsed_event.input_text == "ABC12345"
11+
assert parsed_event.session_attributes == {}
12+
assert parsed_event.message_version == "1.0"
13+
14+
agent = parsed_event.agent
15+
assert agent.alias == "TSTALIASID"
16+
assert agent.name == "test"
17+
assert agent.version == "DRAFT"
18+
assert agent.id == "8ZXY0W8P1H"
19+
20+
action_groups = parsed_event.action_groups
21+
assert len(action_groups) == 1
22+
23+
action_group = action_groups[0]
24+
assert action_group.http_method == "GET"
25+
assert action_group.api_path == "/claims/{claimId}/identify-missing-documents"
26+
assert action_group.action_group == "ClaimManagementActionGroup"
27+
28+
parameters = action_group.parameters
29+
assert len(parameters) == 1
30+
31+
parameter = parameters[0]
32+
assert parameter.name == "claimId"
33+
assert parameter.type == "string"
34+
assert parameter.value == "ABC12345"
35+
36+
37+
def test_bedrock_agent_response_event():
38+
raw_event = load_event("bedrockAgentResponseEvent.json")
39+
parsed_event = BedrockAgentResponseEvent(raw_event)
40+
41+
assert parsed_event.message_version == "1.0"
42+
43+
responses = parsed_event.responses
44+
assert len(responses) == 1
45+
46+
response = responses[0]
47+
assert response.action_group == "ClaimManagementActionGroup"
48+
assert response.api_path == "/claims/{claimId}/identify-missing-documents"
49+
assert response.http_method == "GET"
50+
assert response.http_status_code == 200
51+
52+
response_body = response.response_body
53+
assert "application/json" in response_body
54+
55+
json_response = response_body["application/json"]
56+
assert json_response.body == "This is the response"

0 commit comments

Comments
 (0)