-
Notifications
You must be signed in to change notification settings - Fork 420
/
Copy pathtest_bedrock_agent.py
159 lines (122 loc) · 5.56 KB
/
test_bedrock_agent.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import json
from typing import Any, Dict
from aws_lambda_powertools.event_handler import BedrockAgentResolver, Response, content_types
from aws_lambda_powertools.event_handler.openapi.types import PYDANTIC_V2
from aws_lambda_powertools.utilities.data_classes import BedrockAgentEvent
from tests.functional.utils import load_event
# Text the /claims handlers return as their "output" value; kept at module level so the
# tests can compare the decoded response body against the same string.
claims_response = "You have 3 claims"
def test_bedrock_agent_event():
    """A dict returned from a GET /claims handler comes back as a 200 JSON response."""
    # GIVEN a resolver exposing a single GET /claims route
    app = BedrockAgentResolver()

    @app.get("/claims")
    def claims() -> Dict[str, Any]:
        # Inside the handler the resolver exposes the typed event and the context it was given
        assert isinstance(app.current_event, BedrockAgentEvent)
        assert app.lambda_context == {}
        return {"output": claims_response}

    # WHEN the resolver is invoked with the sample event and an empty context
    event = load_event("bedrockAgentEvent.json")
    result = app(event, {})

    # THEN the response carries the message version and the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"
    assert response["httpStatusCode"] == 200

    # AND the JSON body decodes back to what the handler returned
    payload = json.loads(response["responseBody"]["application/json"]["body"])
    assert payload == {"output": claims_response}
def test_bedrock_agent_with_path_params():
    """A dynamic path segment is handed to the handler as the ``claim_id`` argument."""
    # GIVEN a resolver whose route declares a <claim_id> path parameter
    app = BedrockAgentResolver()

    @app.get("/claims/<claim_id>")
    def claims(claim_id: str):
        # The handler sees the typed event, the supplied context and the extracted parameter
        assert isinstance(app.current_event, BedrockAgentEvent)
        assert app.lambda_context == {}
        assert claim_id == "123"

    # WHEN the resolver is invoked with the path-parameter sample event
    event = load_event("bedrockAgentEventWithPathParams.json")
    result = app(event, {})

    # THEN the call succeeds and the response reports the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims/<claim_id>"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"
    assert response["httpStatusCode"] == 200
def test_bedrock_agent_event_with_response():
    """A handler may return an explicit ``Response`` object instead of a bare dict."""
    # GIVEN a resolver whose handler wraps its payload in a Response
    app = BedrockAgentResolver()
    output = {"output": claims_response}

    @app.get("/claims")
    def claims():
        assert isinstance(app.current_event, BedrockAgentEvent)
        assert app.lambda_context == {}
        return Response(
            status_code=200,
            content_type=content_types.APPLICATION_JSON,
            body=output,
        )

    # WHEN the resolver is invoked with the sample event and an empty context
    event = load_event("bedrockAgentEvent.json")
    result = app(event, {})

    # THEN the response carries the message version and the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"
    assert response["httpStatusCode"] == 200

    # AND the JSON body decodes back to the payload given to Response
    payload = json.loads(response["responseBody"]["application/json"]["body"])
    assert payload == output
def test_bedrock_agent_event_with_no_matches():
    """An event whose path matches no registered route yields a 404 response."""
    # GIVEN a resolver whose only route is GET /no_match
    app = BedrockAgentResolver()

    @app.get("/no_match")
    def claims():
        # Raises if called, so an unexpected dispatch to this handler would surface
        raise RuntimeError()

    # WHEN the resolver is invoked with an event the test expects to carry /claims
    event = load_event("bedrockAgentEvent.json")
    result = app(event, {})

    # THEN the response still carries the message version and the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"

    # AND the status is 404 because the event doesn't match any known rule
    assert response["httpStatusCode"] == 404
def test_bedrock_agent_event_with_validation_error():
    """A return value that contradicts the handler's annotation produces a 422 response."""
    # GIVEN a handler annotated as returning a dict that actually returns a string
    app = BedrockAgentResolver()

    @app.get("/claims")
    def claims() -> Dict[str, Any]:
        return "oh no, this is not a dict"  # type: ignore

    # WHEN the resolver is invoked with the sample event and an empty context
    event = load_event("bedrockAgentEvent.json")
    result = app(event, {})

    # THEN the response carries the message version and the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"

    # AND the status is 422 with a body mentioning the dict validation failure
    assert response["httpStatusCode"] == 422
    body = response["responseBody"]["application/json"]["body"]

    # The expected wording depends on which pydantic major version is in use
    expected_fragment = "should be a valid dictionary" if PYDANTIC_V2 else "value is not a valid dict"
    assert expected_fragment in body
def test_bedrock_agent_event_with_exception():
    """A registered exception handler turns a handler's RuntimeError into a custom response."""
    # GIVEN a resolver with a RuntimeError exception handler
    app = BedrockAgentResolver()

    @app.exception_handler(RuntimeError)
    def handle_runtime_error(ex: RuntimeError):
        return Response(status_code=500, content_type=content_types.TEXT_PLAIN, body="Something went wrong")

    # AND a GET /claims route that always raises RuntimeError
    @app.get("/claims")
    def claims():
        raise RuntimeError()

    # WHEN the resolver is invoked with the sample event and an empty context
    event = load_event("bedrockAgentEvent.json")
    result = app(event, {})

    # THEN the response carries the message version and the request details
    assert result["messageVersion"] == "1.0"
    response = result["response"]
    assert response["apiPath"] == "/claims"
    assert response["actionGroup"] == "ClaimManagementActionGroup"
    assert response["httpMethod"] == "GET"

    # AND the status and plain-text body are the ones built by the exception handler
    assert response["httpStatusCode"] == 500
    plain_text = response["responseBody"]["text/plain"]["body"]
    assert plain_text == "Something went wrong"