Skip to content

Commit 86c7ab7

Browse files
committed
bedrockfunctionresponse and response state
1 parent e42ceff commit 86c7ab7

File tree

5 files changed

+59
-213
lines changed

5 files changed

+59
-213
lines changed

aws_lambda_powertools/event_handler/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@
1212
)
1313
from aws_lambda_powertools.event_handler.appsync import AppSyncResolver
1414
from aws_lambda_powertools.event_handler.bedrock_agent import BedrockAgentResolver
15-
from aws_lambda_powertools.event_handler.bedrock_agent_function import BedrockAgentFunctionResolver, BedrockResponse
15+
from aws_lambda_powertools.event_handler.bedrock_agent_function import (
16+
BedrockAgentFunctionResolver,
17+
BedrockFunctionResponse,
18+
)
1619
from aws_lambda_powertools.event_handler.events_appsync.appsync_events import AppSyncEventsResolver
1720
from aws_lambda_powertools.event_handler.lambda_function_url import (
1821
LambdaFunctionUrlResolver,
@@ -31,7 +34,7 @@
3134
"CORSConfig",
3235
"LambdaFunctionUrlResolver",
3336
"Response",
34-
"BedrockResponse",
37+
"BedrockFunctionResponse",
3538
"VPCLatticeResolver",
3639
"VPCLatticeV2Resolver",
3740
]

aws_lambda_powertools/event_handler/bedrock_agent_function.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,10 @@
55
if TYPE_CHECKING:
66
from collections.abc import Callable
77

8-
from enum import Enum
9-
108
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent
119

1210

13-
class ResponseState(Enum):
14-
FAILURE = "FAILURE"
15-
REPROMPT = "REPROMPT"
16-
17-
18-
class BedrockResponse:
11+
class BedrockFunctionResponse:
1912
"""Response class for Bedrock Agent Functions
2013
2114
Parameters
@@ -26,15 +19,15 @@ class BedrockResponse:
2619
Session attributes to include in the response
2720
prompt_session_attributes : dict[str, str] | None
2821
Prompt session attributes to include in the response
29-
status_code : int
30-
Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE)
22+
response_state : str | None
23+
Response state ("FAILURE" or "REPROMPT")
3124
3225
Examples
3326
--------
3427
```python
3528
@app.tool(description="Function that uses session attributes")
3629
def test_function():
37-
return BedrockResponse(
30+
return BedrockFunctionResponse(
3831
body="Hello",
3932
session_attributes={"userId": "123"},
4033
prompt_session_attributes={"lastAction": "login"}
@@ -48,40 +41,39 @@ def __init__(
4841
session_attributes: dict[str, str] | None = None,
4942
prompt_session_attributes: dict[str, str] | None = None,
5043
knowledge_bases: list[dict[str, Any]] | None = None,
51-
status_code: int = 200,
44+
response_state: str | None = None,
5245
) -> None:
5346
self.body = body
5447
self.session_attributes = session_attributes
5548
self.prompt_session_attributes = prompt_session_attributes
5649
self.knowledge_bases = knowledge_bases
57-
self.status_code = status_code
50+
self.response_state = response_state
5851

5952

6053
class BedrockFunctionsResponseBuilder:
6154
"""
6255
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda
6356
when using Bedrock Agent Functions.
64-
65-
Since the payload format is different from the standard API Gateway Proxy event,
66-
we override the build method.
6757
"""
6858

69-
def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None:
59+
def __init__(self, result: BedrockFunctionResponse | Any) -> None:
7060
self.result = result
71-
self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code
7261

7362
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
7463
"""Build the full response dict to be returned by the lambda"""
75-
if isinstance(self.result, BedrockResponse):
64+
if isinstance(self.result, BedrockFunctionResponse):
7665
body = self.result.body
7766
session_attributes = self.result.session_attributes
7867
prompt_session_attributes = self.result.prompt_session_attributes
7968
knowledge_bases = self.result.knowledge_bases
69+
response_state = self.result.response_state
70+
8071
else:
8172
body = self.result
8273
session_attributes = None
8374
prompt_session_attributes = None
8475
knowledge_bases = None
76+
response_state = None
8577

8678
response: dict[str, Any] = {
8779
"messageVersion": "1.0",
@@ -92,11 +84,9 @@ def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]:
9284
},
9385
}
9486

95-
# Add responseState if it's an error
96-
if self.status_code >= 400:
97-
response["response"]["functionResponse"]["responseState"] = (
98-
ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value
99-
)
87+
# Add responseState if provided
88+
if response_state:
89+
response["response"]["functionResponse"]["responseState"] = response_state
10090

10191
# Add session attributes if provided in response or maintain from input
10292
response.update(
@@ -186,9 +176,8 @@ def _resolve(self) -> dict[str, Any]:
186176

187177
if function_name not in self._tools:
188178
return BedrockFunctionsResponseBuilder(
189-
BedrockResponse(
179+
BedrockFunctionResponse(
190180
body=f"Function not found: {function_name}",
191-
status_code=400, # Using 400 to trigger REPROMPT
192181
),
193182
).build(self.current_event)
194183

@@ -197,8 +186,7 @@ def _resolve(self) -> dict[str, Any]:
197186
return BedrockFunctionsResponseBuilder(result).build(self.current_event)
198187
except Exception as e:
199188
return BedrockFunctionsResponseBuilder(
200-
BedrockResponse(
189+
BedrockFunctionResponse(
201190
body=f"Error: {str(e)}",
202-
status_code=500, # Using 500 to trigger FAILURE
203191
),
204192
).build(self.current_event)

aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py

Lines changed: 0 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
from __future__ import annotations
22

3-
from typing import Any
4-
53
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper
64

75

@@ -45,32 +43,6 @@ class BedrockAgentFunctionEvent(DictWrapper):
4543
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html
4644
"""
4745

48-
@classmethod
49-
def validate_required_fields(cls, data: dict[str, Any]) -> None:
50-
required_fields = {
51-
"messageVersion": str,
52-
"agent": dict,
53-
"inputText": str,
54-
"sessionId": str,
55-
"actionGroup": str,
56-
"function": str,
57-
}
58-
59-
for field, field_type in required_fields.items():
60-
if field not in data:
61-
raise ValueError(f"Missing required field: {field}")
62-
if not isinstance(data[field], field_type):
63-
raise TypeError(f"Field {field} must be of type {field_type}")
64-
65-
# Validate agent structure
66-
required_agent_fields = {"name", "id", "alias", "version"}
67-
if not all(field in data["agent"] for field in required_agent_fields):
68-
raise ValueError("Agent object missing required fields")
69-
70-
def __init__(self, data: dict[str, Any]) -> None:
71-
super().__init__(data)
72-
self.validate_required_fields(data)
73-
7446
@property
7547
def message_version(self) -> str:
7648
return self["messageVersion"]

0 commit comments

Comments
 (0)