Skip to content

Commit e9cb5e5

Browse files
feat(bedrock): add openapi_extensions in BedrockAgentResolver (#6510)
* Adding OpenAPI extensions to Bedrock * Adding OpenAPI extensions to Bedrock * Adding OpenAPI extensions to Bedrock * Adding OpenAPI extensions to Bedrock
1 parent 2c2d8e8 commit e9cb5e5

File tree

4 files changed

+58
-7
lines changed

4 files changed

+58
-7
lines changed

aws_lambda_powertools/event_handler/bedrock_agent.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@ def get( # type: ignore[override]
110110
tags: list[str] | None = None,
111111
operation_id: str | None = None,
112112
include_in_schema: bool = True,
113+
openapi_extensions: dict[str, Any] | None = None,
113114
deprecated: bool = False,
114115
custom_response_validation_http_code: int | HTTPStatus | None = None,
115116
middlewares: list[Callable[..., Any]] | None = None,
116117
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
117-
openapi_extensions = None
118118
security = None
119119

120120
return super().get(
@@ -151,11 +151,11 @@ def post( # type: ignore[override]
151151
tags: list[str] | None = None,
152152
operation_id: str | None = None,
153153
include_in_schema: bool = True,
154+
openapi_extensions: dict[str, Any] | None = None,
154155
deprecated: bool = False,
155156
custom_response_validation_http_code: int | HTTPStatus | None = None,
156157
middlewares: list[Callable[..., Any]] | None = None,
157158
):
158-
openapi_extensions = None
159159
security = None
160160

161161
return super().post(
@@ -192,11 +192,11 @@ def put( # type: ignore[override]
192192
tags: list[str] | None = None,
193193
operation_id: str | None = None,
194194
include_in_schema: bool = True,
195+
openapi_extensions: dict[str, Any] | None = None,
195196
deprecated: bool = False,
196197
custom_response_validation_http_code: int | HTTPStatus | None = None,
197198
middlewares: list[Callable[..., Any]] | None = None,
198199
):
199-
openapi_extensions = None
200200
security = None
201201

202202
return super().put(
@@ -233,11 +233,11 @@ def patch( # type: ignore[override]
233233
tags: list[str] | None = None,
234234
operation_id: str | None = None,
235235
include_in_schema: bool = True,
236+
openapi_extensions: dict[str, Any] | None = None,
236237
deprecated: bool = False,
237238
custom_response_validation_http_code: int | HTTPStatus | None = None,
238239
middlewares: list[Callable] | None = None,
239240
):
240-
openapi_extensions = None
241241
security = None
242242

243243
return super().patch(
@@ -274,11 +274,11 @@ def delete( # type: ignore[override]
274274
tags: list[str] | None = None,
275275
operation_id: str | None = None,
276276
include_in_schema: bool = True,
277+
openapi_extensions: dict[str, Any] | None = None,
277278
deprecated: bool = False,
278279
custom_response_validation_http_code: int | HTTPStatus | None = None,
279280
middlewares: list[Callable[..., Any]] | None = None,
280281
):
281-
openapi_extensions = None
282282
security = None
283283

284284
return super().delete(
@@ -325,6 +325,7 @@ def get_openapi_json_schema( # type: ignore[override]
325325
license_info: License | None = None,
326326
security_schemes: dict[str, SecurityScheme] | None = None,
327327
security: list[dict[str, list[str]]] | None = None,
328+
openapi_extensions: dict[str, Any] | None = None,
328329
) -> str:
329330
"""
330331
Returns the OpenAPI schema as a JSON serializable dict.
@@ -365,8 +366,6 @@ def get_openapi_json_schema( # type: ignore[override]
365366
"""
366367
from aws_lambda_powertools.event_handler.openapi.compat import model_json
367368

368-
openapi_extensions = None
369-
370369
schema = super().get_openapi_schema(
371370
title=title,
372371
version=version,

docs/core/event_handler/bedrock_agents.md

+10
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ To implement these customizations, include extra parameters when defining your r
313313
--8<-- "examples/event_handler_bedrock_agents/src/customizing_bedrock_api_operations.py"
314314
```
315315

316+
#### Enabling user confirmation
317+
318+
You can enable user confirmation with Bedrock Agents to have your application ask for explicit user approval before invoking an action.
319+
320+
```python hl_lines="14" title="enabling_user_confirmation.py" title="Enabling user confirmation"
321+
--8<-- "examples/event_handler_bedrock_agents/src/enabling_user_confirmation.py"
322+
```
323+
324+
1. Add an openapi extension
325+
316326
## Testing your code
317327

318328
Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from time import time
2+
3+
from aws_lambda_powertools import Logger
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
5+
from aws_lambda_powertools.utilities.typing import LambdaContext
6+
7+
logger = Logger()
8+
app = BedrockAgentResolver()
9+
10+
11+
@app.get(
12+
"/current_time",
13+
description="Gets the current time in seconds",
14+
openapi_extensions={"x-requireConfirmation": "ENABLED"}, # (1)!
15+
)
16+
def current_time() -> int:
17+
return int(time())
18+
19+
20+
@logger.inject_lambda_context
21+
def lambda_handler(event: dict, context: LambdaContext):
22+
return app.resolve(event, context)
23+
24+
25+
if __name__ == "__main__":
26+
print(app.get_openapi_json_schema())

tests/functional/event_handler/_pydantic/test_bedrock_agent.py

+16
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,19 @@ def handler() -> Optional[Dict]:
200200
# THEN the schema must be a valid 3.0.3 version
201201
assert openapi30_schema(schema)
202202
assert schema.get("openapi") == "3.0.3"
203+
204+
205+
def test_bedrock_resolver_with_openapi_extensions():
206+
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
207+
app = BedrockAgentResolver(enable_validation=True)
208+
209+
# WHEN we have a simple handler with openapi extension
210+
@app.get("/", description="Testing", openapi_extensions={"x-requireConfirmation": "ENABLED"})
211+
def handler() -> Optional[Dict]:
212+
pass
213+
214+
# WHEN we get the schema
215+
schema = json.loads(app.get_openapi_json_schema())
216+
217+
# THEN the OpenAPI schema must contain the "x-requireConfirmation" extension at the operation level
218+
assert schema["paths"]["/"]["get"]["x-requireConfirmation"] == "ENABLED"

0 commit comments

Comments
 (0)