5
5
if TYPE_CHECKING :
6
6
from collections .abc import Callable
7
7
8
- from enum import Enum
9
-
10
8
from aws_lambda_powertools .utilities .data_classes import BedrockAgentFunctionEvent
11
9
12
10
13
- class ResponseState (Enum ):
14
- FAILURE = "FAILURE"
15
- REPROMPT = "REPROMPT"
16
-
17
-
18
- class BedrockResponse :
11
+ class BedrockFunctionResponse :
19
12
"""Response class for Bedrock Agent Functions
20
13
21
14
Parameters
@@ -26,15 +19,15 @@ class BedrockResponse:
26
19
Session attributes to include in the response
27
20
prompt_session_attributes : dict[str, str] | None
28
21
Prompt session attributes to include in the response
29
- status_code : int
30
- Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE )
22
+ response_state : str | None
23
+ Response state ("FAILURE" or " REPROMPT" )
31
24
32
25
Examples
33
26
--------
34
27
```python
35
28
@app.tool(description="Function that uses session attributes")
36
29
def test_function():
37
- return BedrockResponse (
30
+ return BedrockFunctionResponse (
38
31
body="Hello",
39
32
session_attributes={"userId": "123"},
40
33
prompt_session_attributes={"lastAction": "login"}
@@ -48,40 +41,39 @@ def __init__(
48
41
session_attributes : dict [str , str ] | None = None ,
49
42
prompt_session_attributes : dict [str , str ] | None = None ,
50
43
knowledge_bases : list [dict [str , Any ]] | None = None ,
51
- status_code : int = 200 ,
44
+ response_state : str | None = None ,
52
45
) -> None :
53
46
self .body = body
54
47
self .session_attributes = session_attributes
55
48
self .prompt_session_attributes = prompt_session_attributes
56
49
self .knowledge_bases = knowledge_bases
57
- self .status_code = status_code
50
+ self .response_state = response_state
58
51
59
52
60
53
class BedrockFunctionsResponseBuilder :
61
54
"""
62
55
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
63
56
when using Bedrock Agent Functions.
64
-
65
- Since the payload format is different from the standard API Gateway Proxy event,
66
- we override the build method.
67
57
"""
68
58
69
- def __init__ (self , result : BedrockResponse | Any , status_code : int = 200 ) -> None :
59
+ def __init__ (self , result : BedrockFunctionResponse | Any ) -> None :
70
60
self .result = result
71
- self .status_code = status_code if not isinstance (result , BedrockResponse ) else result .status_code
72
61
73
62
def build (self , event : BedrockAgentFunctionEvent ) -> dict [str , Any ]:
74
63
"""Build the full response dict to be returned by the lambda"""
75
- if isinstance (self .result , BedrockResponse ):
64
+ if isinstance (self .result , BedrockFunctionResponse ):
76
65
body = self .result .body
77
66
session_attributes = self .result .session_attributes
78
67
prompt_session_attributes = self .result .prompt_session_attributes
79
68
knowledge_bases = self .result .knowledge_bases
69
+ response_state = self .result .response_state
70
+
80
71
else :
81
72
body = self .result
82
73
session_attributes = None
83
74
prompt_session_attributes = None
84
75
knowledge_bases = None
76
+ response_state = None
85
77
86
78
response : dict [str , Any ] = {
87
79
"messageVersion" : "1.0" ,
@@ -92,11 +84,9 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
92
84
},
93
85
}
94
86
95
- # Add responseState if it's an error
96
- if self .status_code >= 400 :
97
- response ["response" ]["functionResponse" ]["responseState" ] = (
98
- ResponseState .REPROMPT .value if self .status_code == 400 else ResponseState .FAILURE .value
99
- )
87
+ # Add responseState if provided
88
+ if response_state :
89
+ response ["response" ]["functionResponse" ]["responseState" ] = response_state
100
90
101
91
# Add session attributes if provided in response or maintain from input
102
92
response .update (
@@ -186,9 +176,8 @@ def _resolve(self) -> dict[str, Any]:
186
176
187
177
if function_name not in self ._tools :
188
178
return BedrockFunctionsResponseBuilder (
189
- BedrockResponse (
179
+ BedrockFunctionResponse (
190
180
body = f"Function not found: { function_name } " ,
191
- status_code = 400 , # Using 400 to trigger REPROMPT
192
181
),
193
182
).build (self .current_event )
194
183
@@ -197,8 +186,7 @@ def _resolve(self) -> dict[str, Any]:
197
186
return BedrockFunctionsResponseBuilder (result ).build (self .current_event )
198
187
except Exception as e :
199
188
return BedrockFunctionsResponseBuilder (
200
- BedrockResponse (
189
+ BedrockFunctionResponse (
201
190
body = f"Error: { str (e )} " ,
202
- status_code = 500 , # Using 500 to trigger FAILURE
203
191
),
204
192
).build (self .current_event )
0 commit comments