Skip to content

Commit 5e668a5

Browse files
Enable Ruff format
2 parents d31ee43 + e9cb5e5 commit 5e668a5

File tree

6 files changed

+112
-10
lines changed

6 files changed

+112
-10
lines changed

aws_lambda_powertools/event_handler/bedrock_agent.py

+6-7
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,11 @@ def get( # type: ignore[override]
110110
tags: list[str] | None = None,
111111
operation_id: str | None = None,
112112
include_in_schema: bool = True,
113+
openapi_extensions: dict[str, Any] | None = None,
113114
deprecated: bool = False,
114115
custom_response_validation_http_code: int | HTTPStatus | None = None,
115116
middlewares: list[Callable[..., Any]] | None = None,
116117
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
117-
openapi_extensions = None
118118
security = None
119119

120120
return super().get(
@@ -151,11 +151,11 @@ def post( # type: ignore[override]
151151
tags: list[str] | None = None,
152152
operation_id: str | None = None,
153153
include_in_schema: bool = True,
154+
openapi_extensions: dict[str, Any] | None = None,
154155
deprecated: bool = False,
155156
custom_response_validation_http_code: int | HTTPStatus | None = None,
156157
middlewares: list[Callable[..., Any]] | None = None,
157158
):
158-
openapi_extensions = None
159159
security = None
160160

161161
return super().post(
@@ -192,11 +192,11 @@ def put( # type: ignore[override]
192192
tags: list[str] | None = None,
193193
operation_id: str | None = None,
194194
include_in_schema: bool = True,
195+
openapi_extensions: dict[str, Any] | None = None,
195196
deprecated: bool = False,
196197
custom_response_validation_http_code: int | HTTPStatus | None = None,
197198
middlewares: list[Callable[..., Any]] | None = None,
198199
):
199-
openapi_extensions = None
200200
security = None
201201

202202
return super().put(
@@ -233,11 +233,11 @@ def patch( # type: ignore[override]
233233
tags: list[str] | None = None,
234234
operation_id: str | None = None,
235235
include_in_schema: bool = True,
236+
openapi_extensions: dict[str, Any] | None = None,
236237
deprecated: bool = False,
237238
custom_response_validation_http_code: int | HTTPStatus | None = None,
238239
middlewares: list[Callable] | None = None,
239240
):
240-
openapi_extensions = None
241241
security = None
242242

243243
return super().patch(
@@ -274,11 +274,11 @@ def delete( # type: ignore[override]
274274
tags: list[str] | None = None,
275275
operation_id: str | None = None,
276276
include_in_schema: bool = True,
277+
openapi_extensions: dict[str, Any] | None = None,
277278
deprecated: bool = False,
278279
custom_response_validation_http_code: int | HTTPStatus | None = None,
279280
middlewares: list[Callable[..., Any]] | None = None,
280281
):
281-
openapi_extensions = None
282282
security = None
283283

284284
return super().delete(
@@ -325,6 +325,7 @@ def get_openapi_json_schema( # type: ignore[override]
325325
license_info: License | None = None,
326326
security_schemes: dict[str, SecurityScheme] | None = None,
327327
security: list[dict[str, list[str]]] | None = None,
328+
openapi_extensions: dict[str, Any] | None = None,
328329
) -> str:
329330
"""
330331
Returns the OpenAPI schema as a JSON serializable dict.
@@ -365,8 +366,6 @@ def get_openapi_json_schema( # type: ignore[override]
365366
"""
366367
from aws_lambda_powertools.event_handler.openapi.compat import model_json
367368

368-
openapi_extensions = None
369-
370369
schema = super().get_openapi_schema(
371370
title=title,
372371
version=version,

aws_lambda_powertools/logging/logger.py

+34-3
Original file line numberDiff line numberDiff line change
@@ -242,9 +242,6 @@ def __init__(
242242
buffer_config: LoggerBufferConfig | None = None,
243243
**kwargs,
244244
) -> None:
245-
# Used in case of sampling
246-
self.initial_log_level = self._determine_log_level(level)
247-
248245
self.service = resolve_env_var_choice(
249246
choice=service,
250247
env=os.getenv(constants.SERVICE_NAME_ENV, "service_undefined"),
@@ -284,6 +281,9 @@ def __init__(
284281
if self._buffer_config:
285282
self._buffer_cache = LoggerBufferCache(max_size_bytes=self._buffer_config.max_bytes)
286283

284+
# Used in case of sampling
285+
self.initial_log_level = self._determine_log_level(level)
286+
287287
self._init_logger(
288288
formatter_options=formatter_options,
289289
log_level=level,
@@ -1046,6 +1046,20 @@ def _determine_log_level(self, level: str | int | None) -> str | int:
10461046
stacklevel=2,
10471047
)
10481048

1049+
# Check if buffer level is less verbose than ALC
1050+
if (
1051+
hasattr(self, "_buffer_config")
1052+
and self._buffer_config
1053+
and logging.getLevelName(lambda_log_level)
1054+
> logging.getLevelName(self._buffer_config.buffer_at_verbosity)
1055+
):
1056+
warnings.warn(
1057+
"Advanced Logging Controls (ALC) Log Level is less verbose than Log Buffering Log Level. "
1058+
"Buffered logs will be filtered by ALC",
1059+
PowertoolsUserWarning,
1060+
stacklevel=2,
1061+
)
1062+
10491063
# AWS Lambda Advanced Logging Controls takes precedence over Powertools log level and we use this
10501064
if lambda_log_level:
10511065
return lambda_log_level
@@ -1132,6 +1146,7 @@ def _add_log_record_to_buffer(
11321146
Handles special first invocation buffering and migration of log records
11331147
between different tracer contexts.
11341148
"""
1149+
11351150
# Determine tracer ID, defaulting to first invoke marker
11361151
tracer_id = get_tracer_id()
11371152

@@ -1179,6 +1194,7 @@ def flush_buffer(self) -> None:
11791194
Any exceptions from underlying logging or buffer mechanisms
11801195
will be propagated to caller
11811196
"""
1197+
11821198
tracer_id = get_tracer_id()
11831199

11841200
# Flushing log without a tracer id? Return
@@ -1190,6 +1206,21 @@ def flush_buffer(self) -> None:
11901206
if not buffer:
11911207
return
11921208

1209+
if not self._buffer_config:
1210+
return
1211+
1212+
# Check ALC level against buffer level
1213+
lambda_log_level = self._get_aws_lambda_log_level()
1214+
if lambda_log_level:
1215+
# Check if buffer level is less verbose than ALC
1216+
if logging.getLevelName(lambda_log_level) > logging.getLevelName(self._buffer_config.buffer_at_verbosity):
1217+
warnings.warn(
1218+
"Advanced Logging Controls (ALC) Log Level is less verbose than Log Buffering Log Level. "
1219+
"Some logs might be missing",
1220+
PowertoolsUserWarning,
1221+
stacklevel=2,
1222+
)
1223+
11931224
# Process log records
11941225
for log_line in buffer:
11951226
self._create_and_flush_log_record(log_line)

docs/core/event_handler/bedrock_agents.md

+10
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,16 @@ To implement these customizations, include extra parameters when defining your r
313313
--8<-- "examples/event_handler_bedrock_agents/src/customizing_bedrock_api_operations.py"
314314
```
315315

316+
#### Enabling user confirmation
317+
318+
You can enable user confirmation with Bedrock Agents to have your application ask for explicit user approval before invoking an action.
319+
320+
```python hl_lines="14" title="enabling_user_confirmation.py" title="Enabling user confirmation"
321+
--8<-- "examples/event_handler_bedrock_agents/src/enabling_user_confirmation.py"
322+
```
323+
324+
1. Add an openapi extension
325+
316326
## Testing your code
317327

318328
Test your routes by passing an [Agent for Amazon Bedrock proxy event](https://docs.aws.amazon.com/bedrock/latest/userguide/agents-lambda.html#agents-lambda-input) request:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from time import time
2+
3+
from aws_lambda_powertools import Logger
4+
from aws_lambda_powertools.event_handler import BedrockAgentResolver
5+
from aws_lambda_powertools.utilities.typing import LambdaContext
6+
7+
logger = Logger()
8+
app = BedrockAgentResolver()
9+
10+
11+
@app.get(
12+
"/current_time",
13+
description="Gets the current time in seconds",
14+
openapi_extensions={"x-requireConfirmation": "ENABLED"}, # (1)!
15+
)
16+
def current_time() -> int:
17+
return int(time())
18+
19+
20+
@logger.inject_lambda_context
21+
def lambda_handler(event: dict, context: LambdaContext):
22+
return app.resolve(event, context)
23+
24+
25+
if __name__ == "__main__":
26+
print(app.get_openapi_json_schema())

tests/functional/event_handler/_pydantic/test_bedrock_agent.py

+16
Original file line numberDiff line numberDiff line change
@@ -200,3 +200,19 @@ def handler() -> Optional[Dict]:
200200
# THEN the schema must be a valid 3.0.3 version
201201
assert openapi30_schema(schema)
202202
assert schema.get("openapi") == "3.0.3"
203+
204+
205+
def test_bedrock_resolver_with_openapi_extensions():
206+
# GIVEN BedrockAgentResolver is initialized with enable_validation=True
207+
app = BedrockAgentResolver(enable_validation=True)
208+
209+
# WHEN we have a simple handler with openapi extension
210+
@app.get("/", description="Testing", openapi_extensions={"x-requireConfirmation": "ENABLED"})
211+
def handler() -> Optional[Dict]:
212+
pass
213+
214+
# WHEN we get the schema
215+
schema = json.loads(app.get_openapi_json_schema())
216+
217+
# THEN the OpenAPI schema must contain the "x-requireConfirmation" extension at the operation level
218+
assert schema["paths"]["/"]["get"]["x-requireConfirmation"] == "ENABLED"

tests/functional/logger/required_dependencies/test_powertools_logger_buffer.py

+20
Original file line numberDiff line numberDiff line change
@@ -524,3 +524,23 @@ def handler(event, context):
524524

525525
# THEN Verify buffer for the original trace ID is cleared
526526
assert not logger._buffer_cache.get("1-67c39786-5908a82a246fb67f3089263f")
527+
528+
529+
def test_warning_when_alc_less_verbose_than_buffer(stdout, monkeypatch):
530+
# GIVEN Lambda ALC set to INFO
531+
monkeypatch.setenv("AWS_LAMBDA_LOG_LEVEL", "INFO")
532+
# Set initial trace ID for first Lambda invocation
533+
monkeypatch.setenv(constants.XRAY_TRACE_ID_ENV, "1-67c39786-5908a82a246fb67f3089263f")
534+
535+
# WHEN creating a logger with DEBUG buffer level
536+
# THEN a warning should be emitted
537+
with pytest.warns(PowertoolsUserWarning, match="Advanced Logging Controls*"):
538+
logger = Logger(service="test", level="DEBUG", buffer_config=LoggerBufferConfig(buffer_at_verbosity="DEBUG"))
539+
540+
# AND logging a debug message
541+
logger.debug("This is a debug")
542+
543+
# AND flushing buffer
544+
# THEN another warning should be emitted about ALC and buffer level mismatch
545+
with pytest.warns(PowertoolsUserWarning, match="Advanced Logging Controls*"):
546+
logger.flush_buffer()

0 commit comments

Comments
 (0)