Skip to content

feat(bedrock): add openapi_extensions in BedrockAgentResolver #6510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ def get( # type: ignore[override]
tags: list[str] | None = None,
operation_id: str | None = None,
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
openapi_extensions = None
security = None

return super().get(
Expand Down Expand Up @@ -151,11 +151,11 @@ def post( # type: ignore[override]
tags: list[str] | None = None,
operation_id: str | None = None,
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
security = None

return super().post(
Expand Down Expand Up @@ -192,11 +192,11 @@ def put( # type: ignore[override]
tags: list[str] | None = None,
operation_id: str | None = None,
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
security = None

return super().put(
Expand Down Expand Up @@ -233,11 +233,11 @@ def patch( # type: ignore[override]
tags: list[str] | None = None,
operation_id: str | None = None,
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable] | None = None,
):
openapi_extensions = None
security = None

return super().patch(
Expand Down Expand Up @@ -274,11 +274,11 @@ def delete( # type: ignore[override]
tags: list[str] | None = None,
operation_id: str | None = None,
include_in_schema: bool = True,
openapi_extensions: dict[str, Any] | None = None,
deprecated: bool = False,
custom_response_validation_http_code: int | HTTPStatus | None = None,
middlewares: list[Callable[..., Any]] | None = None,
):
openapi_extensions = None
security = None

return super().delete(
Expand Down Expand Up @@ -325,6 +325,7 @@ def get_openapi_json_schema( # type: ignore[override]
license_info: License | None = None,
security_schemes: dict[str, SecurityScheme] | None = None,
security: list[dict[str, list[str]]] | None = None,
openapi_extensions: dict[str, Any] | None = None,
) -> str:
"""
Returns the OpenAPI schema as a JSON serializable dict.
Expand Down Expand Up @@ -365,8 +366,6 @@ def get_openapi_json_schema( # type: ignore[override]
"""
from aws_lambda_powertools.event_handler.openapi.compat import model_json

openapi_extensions = None

schema = super().get_openapi_schema(
title=title,
version=version,
Expand Down
10 changes: 10 additions & 0 deletions docs/core/event_handler/bedrock_agents.md
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,16 @@ To implement these customizations, include extra parameters when defining your r
--8<-- "examples/event_handler_bedrock_agents/src/customizing_bedrock_api_operations.py"
```

#### Enabling user confirmation

You can enable user confirmation with Bedrock Agents to have your application ask for explicit user approval before invoking an action.

```python hl_lines="14" title="enabling_user_confirmation.py" title="Enabling user confirmation"
--8<-- "examples/event_handler_bedrock_agents/src/enabling_user_confirmation.py"
```

1. Add an openapi extension

## Testing your code

Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from time import time

from aws_lambda_powertools import Logger
from aws_lambda_powertools.event_handler import BedrockAgentResolver
from aws_lambda_powertools.utilities.typing import LambdaContext

logger = Logger()
app = BedrockAgentResolver()


@app.get(
"/current_time",
description="Gets the current time in seconds",
openapi_extensions={"x-requireConfirmation": "ENABLED"}, # (1)!
)
def current_time() -> int:
return int(time())


@logger.inject_lambda_context
def lambda_handler(event: dict, context: LambdaContext):
return app.resolve(event, context)


if __name__ == "__main__":
print(app.get_openapi_json_schema())
16 changes: 16 additions & 0 deletions tests/functional/event_handler/_pydantic/test_bedrock_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,3 +200,19 @@ def handler() -> Optional[Dict]:
# THEN the schema must be a valid 3.0.3 version
assert openapi30_schema(schema)
assert schema.get("openapi") == "3.0.3"


def test_bedrock_resolver_with_openapi_extensions():
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
app = BedrockAgentResolver(enable_validation=True)

# WHEN we have a simple handler with openapi extension
@app.get("/", description="Testing", openapi_extensions={"x-requireConfirmation": "ENABLED"})
def handler() -> Optional[Dict]:
pass

# WHEN we get the schema
schema = json.loads(app.get_openapi_json_schema())

# THEN the OpenAPI schema must contain the "x-requireConfirmation" extension at the operation level
assert schema["paths"]["/"]["get"]["x-requireConfirmation"] == "ENABLED"
Loading