-
Notifications
You must be signed in to change notification settings - Fork 433
feat(bedrock_agent): add new Amazon Bedrock Agents Functions Resolver #6564
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 6 commits
Commits
Show all changes
32 commits
Select commit
Hold shift + click to select a range
41bc401
feat(bedrock_agent): create bedrock agents functions data class
anafalcao bed8f3f
create resolver
anafalcao a3765f0
mypy
anafalcao 44d80f8
add response
anafalcao abbc100
add name param to tool
anafalcao e42ceff
add response optional fields
anafalcao 86c7ab7
bedrockfunctionresponse and response state
anafalcao 34948d7
remove body message
anafalcao 24978cb
add parser
anafalcao 45f85f6
add test for required fields
anafalcao b420a90
Merge branch 'develop' into feat/bedrock_functions
anafalcao 84bb6b0
add more tests for parser and resolver
anafalcao 20bbe9f
Merge branch 'feat/bedrock_functions' of https://github.com/aws-power…
anafalcao d463304
add validation response state
anafalcao b4ab6b9
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 39e0d36
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 54a7edf
params injection
anafalcao fdde207
doc event handler, parser and data class
anafalcao c8b1b2f
fix doc typo
anafalcao db7d6b9
fix doc typo
anafalcao 266ebcb
mypy
anafalcao 4211b72
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 20215ed
Small refactor + documentation
leandrodamascena ef31cb5
Small refactor + documentation
leandrodamascena c200a3a
Small refactor + documentation
leandrodamascena ca700c8
Small refactor + documentation
leandrodamascena c992463
Aligning Python implementation with TS
leandrodamascena 214d061
Adding custom serializer
leandrodamascena 9914ab5
Adding custom serializer
leandrodamascena df3f29b
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena 3694cb2
More documentation
leandrodamascena 13d5569
Merge branch 'develop' into feat/bedrock_functions
leandrodamascena File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
204 changes: 204 additions & 0 deletions
204
aws_lambda_powertools/event_handler/bedrock_agent_function.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,204 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING, Any | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable | ||
|
||
from enum import Enum | ||
|
||
from aws_lambda_powertools.utilities.data_classes import BedrockAgentFunctionEvent | ||
|
||
|
||
class ResponseState(Enum): | ||
FAILURE = "FAILURE" | ||
REPROMPT = "REPROMPT" | ||
|
||
|
||
class BedrockResponse: | ||
"""Response class for Bedrock Agent Functions | ||
|
||
Parameters | ||
---------- | ||
body : Any, optional | ||
Response body | ||
session_attributes : dict[str, str] | None | ||
Session attributes to include in the response | ||
prompt_session_attributes : dict[str, str] | None | ||
Prompt session attributes to include in the response | ||
status_code : int | ||
Status code to determine responseState (400 for REPROMPT, >=500 for FAILURE) | ||
|
||
Examples | ||
-------- | ||
```python | ||
@app.tool(description="Function that uses session attributes") | ||
def test_function(): | ||
return BedrockResponse( | ||
body="Hello", | ||
session_attributes={"userId": "123"}, | ||
prompt_session_attributes={"lastAction": "login"} | ||
) | ||
``` | ||
""" | ||
|
||
def __init__( | ||
self, | ||
body: Any = None, | ||
session_attributes: dict[str, str] | None = None, | ||
prompt_session_attributes: dict[str, str] | None = None, | ||
knowledge_bases: list[dict[str, Any]] | None = None, | ||
status_code: int = 200, | ||
) -> None: | ||
self.body = body | ||
self.session_attributes = session_attributes | ||
self.prompt_session_attributes = prompt_session_attributes | ||
self.knowledge_bases = knowledge_bases | ||
self.status_code = status_code | ||
|
||
|
||
class BedrockFunctionsResponseBuilder: | ||
""" | ||
Bedrock Functions Response Builder. This builds the response dict to be returned by Lambda | ||
when using Bedrock Agent Functions. | ||
|
||
Since the payload format is different from the standard API Gateway Proxy event, | ||
we override the build method. | ||
""" | ||
|
||
def __init__(self, result: BedrockResponse | Any, status_code: int = 200) -> None: | ||
self.result = result | ||
self.status_code = status_code if not isinstance(result, BedrockResponse) else result.status_code | ||
|
||
def build(self, event: BedrockAgentFunctionEvent) -> dict[str, Any]: | ||
"""Build the full response dict to be returned by the lambda""" | ||
if isinstance(self.result, BedrockResponse): | ||
body = self.result.body | ||
session_attributes = self.result.session_attributes | ||
prompt_session_attributes = self.result.prompt_session_attributes | ||
knowledge_bases = self.result.knowledge_bases | ||
else: | ||
body = self.result | ||
session_attributes = None | ||
prompt_session_attributes = None | ||
knowledge_bases = None | ||
|
||
response: dict[str, Any] = { | ||
"messageVersion": "1.0", | ||
"response": { | ||
"actionGroup": event.action_group, | ||
"function": event.function, | ||
"functionResponse": {"responseBody": {"TEXT": {"body": str(body if body is not None else "")}}}, | ||
}, | ||
} | ||
|
||
# Add responseState if it's an error | ||
if self.status_code >= 400: | ||
response["response"]["functionResponse"]["responseState"] = ( | ||
ResponseState.REPROMPT.value if self.status_code == 400 else ResponseState.FAILURE.value | ||
) | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
# Add session attributes if provided in response or maintain from input | ||
response.update( | ||
{ | ||
"sessionAttributes": session_attributes or event.session_attributes or {}, | ||
"promptSessionAttributes": prompt_session_attributes or event.prompt_session_attributes or {}, | ||
}, | ||
) | ||
|
||
# Add knowledge bases configuration if provided | ||
if knowledge_bases: | ||
response["knowledgeBasesConfiguration"] = knowledge_bases | ||
|
||
return response | ||
|
||
|
||
class BedrockAgentFunctionResolver: | ||
"""Bedrock Agent Function resolver that handles function definitions | ||
|
||
Examples | ||
-------- | ||
```python | ||
from aws_lambda_powertools.event_handler import BedrockAgentFunctionResolver | ||
|
||
app = BedrockAgentFunctionResolver() | ||
|
||
@app.tool(description="Gets the current UTC time") | ||
def get_current_time(): | ||
from datetime import datetime | ||
return datetime.utcnow().isoformat() | ||
|
||
def lambda_handler(event, context): | ||
return app.resolve(event, context) | ||
``` | ||
""" | ||
|
||
def __init__(self) -> None: | ||
self._tools: dict[str, dict[str, Any]] = {} | ||
self.current_event: BedrockAgentFunctionEvent | None = None | ||
self._response_builder_class = BedrockFunctionsResponseBuilder | ||
|
||
def tool( | ||
self, | ||
description: str | None = None, | ||
name: str | None = None, | ||
) -> Callable: | ||
"""Decorator to register a tool function | ||
|
||
Parameters | ||
---------- | ||
description : str | None | ||
Description of what the tool does | ||
name : str | None | ||
Custom name for the tool. If not provided, uses the function name | ||
""" | ||
|
||
def decorator(func: Callable) -> Callable: | ||
if not description: | ||
raise ValueError("Tool description is required") | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
function_name = name or func.__name__ | ||
if function_name in self._tools: | ||
raise ValueError(f"Tool '{function_name}' already registered") | ||
anafalcao marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self._tools[function_name] = { | ||
"function": func, | ||
"description": description, | ||
} | ||
return func | ||
|
||
return decorator | ||
|
||
def resolve(self, event: dict[str, Any], context: Any) -> dict[str, Any]: | ||
"""Resolves the function call from Bedrock Agent event""" | ||
try: | ||
self.current_event = BedrockAgentFunctionEvent(event) | ||
return self._resolve() | ||
except KeyError as e: | ||
raise ValueError(f"Missing required field: {str(e)}") | ||
|
||
def _resolve(self) -> dict[str, Any]: | ||
"""Internal resolution logic""" | ||
if self.current_event is None: | ||
raise ValueError("No event to process") | ||
|
||
function_name = self.current_event.function | ||
|
||
if function_name not in self._tools: | ||
return BedrockFunctionsResponseBuilder( | ||
BedrockResponse( | ||
body=f"Function not found: {function_name}", | ||
status_code=400, # Using 400 to trigger REPROMPT | ||
), | ||
).build(self.current_event) | ||
|
||
try: | ||
result = self._tools[function_name]["function"]() | ||
leandrodamascena marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return BedrockFunctionsResponseBuilder(result).build(self.current_event) | ||
except Exception as e: | ||
return BedrockFunctionsResponseBuilder( | ||
BedrockResponse( | ||
body=f"Error: {str(e)}", | ||
status_code=500, # Using 500 to trigger FAILURE | ||
), | ||
).build(self.current_event) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
109 changes: 109 additions & 0 deletions
109
aws_lambda_powertools/utilities/data_classes/bedrock_agent_function_event.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import annotations | ||
|
||
from typing import Any | ||
|
||
from aws_lambda_powertools.utilities.data_classes.common import DictWrapper | ||
|
||
|
||
class BedrockAgentInfo(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def id(self) -> str: # noqa: A003 | ||
return self["id"] | ||
|
||
@property | ||
def alias(self) -> str: | ||
return self["alias"] | ||
|
||
@property | ||
def version(self) -> str: | ||
return self["version"] | ||
|
||
|
||
class BedrockAgentFunctionParameter(DictWrapper): | ||
@property | ||
def name(self) -> str: | ||
return self["name"] | ||
|
||
@property | ||
def type(self) -> str: # noqa: A003 | ||
return self["type"] | ||
|
||
@property | ||
def value(self) -> str: | ||
return self["value"] | ||
|
||
|
||
class BedrockAgentFunctionEvent(DictWrapper): | ||
""" | ||
Bedrock Agent Function input event | ||
|
||
Documentation: | ||
https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html | ||
""" | ||
|
||
@classmethod | ||
def validate_required_fields(cls, data: dict[str, Any]) -> None: | ||
required_fields = { | ||
"messageVersion": str, | ||
"agent": dict, | ||
"inputText": str, | ||
"sessionId": str, | ||
"actionGroup": str, | ||
"function": str, | ||
} | ||
|
||
for field, field_type in required_fields.items(): | ||
if field not in data: | ||
raise ValueError(f"Missing required field: {field}") | ||
if not isinstance(data[field], field_type): | ||
raise TypeError(f"Field {field} must be of type {field_type}") | ||
|
||
# Validate agent structure | ||
required_agent_fields = {"name", "id", "alias", "version"} | ||
if not all(field in data["agent"] for field in required_agent_fields): | ||
raise ValueError("Agent object missing required fields") | ||
|
||
def __init__(self, data: dict[str, Any]) -> None: | ||
super().__init__(data) | ||
self.validate_required_fields(data) | ||
|
||
@property | ||
def message_version(self) -> str: | ||
return self["messageVersion"] | ||
|
||
@property | ||
def input_text(self) -> str: | ||
return self["inputText"] | ||
|
||
@property | ||
def session_id(self) -> str: | ||
return self["sessionId"] | ||
|
||
@property | ||
def action_group(self) -> str: | ||
return self["actionGroup"] | ||
|
||
@property | ||
def function(self) -> str: | ||
return self["function"] | ||
|
||
@property | ||
def parameters(self) -> list[BedrockAgentFunctionParameter]: | ||
parameters = self.get("parameters") or [] | ||
return [BedrockAgentFunctionParameter(x) for x in parameters] | ||
|
||
@property | ||
def agent(self) -> BedrockAgentInfo: | ||
return BedrockAgentInfo(self["agent"]) | ||
|
||
@property | ||
def session_attributes(self) -> dict[str, str]: | ||
return self.get("sessionAttributes", {}) or {} | ||
|
||
@property | ||
def prompt_session_attributes(self) -> dict[str, str]: | ||
return self.get("promptSessionAttributes", {}) or {} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
{ | ||
"messageVersion": "1.0", | ||
"agent": { | ||
"alias": "PROD", | ||
"name": "hr-assistant-function-def", | ||
"version": "1", | ||
"id": "1234abcd" | ||
}, | ||
"sessionId": "123456789123458", | ||
"sessionAttributes": { | ||
"employeeId": "EMP123" | ||
}, | ||
"promptSessionAttributes": { | ||
"lastInteraction": "2024-02-01T15:30:00Z", | ||
"requestType": "vacation" | ||
}, | ||
"inputText": "I want to request vacation from March 15 to March 20", | ||
"actionGroup": "VacationsActionGroup", | ||
"function": "submitVacationRequest", | ||
"parameters": [ | ||
{ | ||
"name": "startDate", | ||
"type": "string", | ||
"value": "2024-03-15" | ||
}, | ||
{ | ||
"name": "endDate", | ||
"type": "string", | ||
"value": "2024-03-20" | ||
} | ||
] | ||
} |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.