Skip to content

Commit d3698d2

Browse files
SimonBFranksthulbleandrodamascena
authored
feat(logger): add thread safe logging keys (#5141)
* Getting baseline thread safe keys working * Functional tests for thread safe keys * Updating documentation * Cleanup * Small fixes for linting * Clearing thread local keys with clear_state=True * Cleaning up PR * Small fixes to docs * Fixing type annotations for THREAD_LOCAL_KEYS * Replacing '|' with {**dict1, **dict2} due to support of Python < 3.9 * fix types from v2 to v3 * Changing documentation and method names --------- Co-authored-by: Simon Thulbourn <[email protected]> Co-authored-by: Leandro Damascena <[email protected]>
1 parent fd609bc commit d3698d2

13 files changed

+459
-11
lines changed

Diff for: aws_lambda_powertools/logging/formatter.py

+81-2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import time
88
import traceback
99
from abc import ABCMeta, abstractmethod
10+
from contextvars import ContextVar
1011
from datetime import datetime, timezone
1112
from functools import partial
1213
from typing import TYPE_CHECKING, Any, Callable, Iterable
@@ -61,6 +62,21 @@ def clear_state(self) -> None:
6162
"""Removes any previously added logging keys"""
6263
raise NotImplementedError()
6364

65+
# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
66+
# They prevent race conditions and ensure data consistency across multiple threads.
67+
def thread_safe_append_keys(self, **additional_keys) -> None:
68+
raise NotImplementedError()
69+
70+
def thread_safe_get_current_keys(self) -> dict[str, Any]:
71+
return {}
72+
73+
def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
74+
raise NotImplementedError()
75+
76+
def thread_safe_clear_keys(self) -> None:
77+
"""Removes any previously added logging keys in a specific thread"""
78+
raise NotImplementedError()
79+
6480

6581
class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
6682
"""Powertools for AWS Lambda (Python) Logging formatter.
@@ -247,6 +263,24 @@ def clear_state(self) -> None:
247263
self.log_format = dict.fromkeys(self.log_record_order)
248264
self.log_format.update(**self.keys_combined)
249265

266+
# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
267+
# They prevent race conditions and ensure data consistency across multiple threads.
268+
def thread_safe_append_keys(self, **additional_keys) -> None:
269+
# Append additional key-value pairs to the context safely in a thread-safe manner.
270+
set_context_keys(**additional_keys)
271+
272+
def thread_safe_get_current_keys(self) -> dict[str, Any]:
273+
# Retrieve the current context keys safely in a thread-safe manner.
274+
return _get_context().get()
275+
276+
def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
277+
# Remove specified keys from the context safely in a thread-safe manner.
278+
remove_context_keys(keys)
279+
280+
def thread_safe_clear_keys(self) -> None:
281+
# Clear all keys from the context safely in a thread-safe manner.
282+
clear_context_keys()
283+
250284
@staticmethod
251285
def _build_default_keys() -> dict[str, str]:
252286
return {
@@ -345,14 +379,33 @@ def _extract_log_keys(self, log_record: logging.LogRecord) -> dict[str, Any]:
345379
record_dict["asctime"] = self.formatTime(record=log_record)
346380
extras = {k: v for k, v in record_dict.items() if k not in RESERVED_LOG_ATTRS}
347381

348-
formatted_log = {}
382+
formatted_log: dict[str, Any] = {}
349383

350384
# Iterate over a default or existing log structure
351385
# then replace any std log attribute e.g. '%(level)s' to 'INFO', '%(process)d to '4773'
386+
# check if the value is a str if the key is a reserved attribute, the modulo operator only supports string
352387
# lastly add or replace incoming keys (those added within the constructor or .structure_logs method)
353388
for key, value in self.log_format.items():
354389
if value and key in RESERVED_LOG_ATTRS:
355-
formatted_log[key] = value % record_dict
390+
if isinstance(value, str):
391+
formatted_log[key] = value % record_dict
392+
else:
393+
raise ValueError(
394+
"Logging keys that override reserved log attributes need to be type 'str', "
395+
f"instead got '{type(value).__name__}'",
396+
)
397+
else:
398+
formatted_log[key] = value
399+
400+
for key, value in _get_context().get().items():
401+
if value and key in RESERVED_LOG_ATTRS:
402+
if isinstance(value, str):
403+
formatted_log[key] = value % record_dict
404+
else:
405+
raise ValueError(
406+
"Logging keys that override reserved log attributes need to be type 'str', "
407+
f"instead got '{type(value).__name__}'",
408+
)
356409
else:
357410
formatted_log[key] = value
358411

@@ -370,3 +423,29 @@ def _strip_none_records(records: dict[str, Any]) -> dict[str, Any]:
370423

371424
# Fetch current and future parameters from PowertoolsFormatter that should be reserved
372425
RESERVED_FORMATTER_CUSTOM_KEYS: list[str] = inspect.getfullargspec(LambdaPowertoolsFormatter).args[1:]
426+
427+
# ContextVar for thread local keys
428+
THREAD_LOCAL_KEYS: ContextVar[dict[str, Any]] = ContextVar("THREAD_LOCAL_KEYS", default={})
429+
430+
431+
def _get_context() -> ContextVar[dict[str, Any]]:
432+
return THREAD_LOCAL_KEYS
433+
434+
435+
def clear_context_keys() -> None:
436+
_get_context().set({})
437+
438+
439+
def set_context_keys(**kwargs: dict[str, Any]) -> None:
440+
context = _get_context()
441+
context.set({**context.get(), **kwargs})
442+
443+
444+
def remove_context_keys(keys: Iterable[str]) -> None:
445+
context = _get_context()
446+
context_values = context.get()
447+
448+
for k in keys:
449+
context_values.pop(k, None)
450+
451+
context.set(context_values)

Diff for: aws_lambda_powertools/logging/logger.py

+19
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,24 @@ def get_current_keys(self) -> dict[str, Any]:
589589
def remove_keys(self, keys: Iterable[str]) -> None:
590590
self.registered_formatter.remove_keys(keys)
591591

592+
# These specific thread-safe methods are necessary to manage shared context in concurrent environments.
593+
# They prevent race conditions and ensure data consistency across multiple threads.
594+
def thread_safe_append_keys(self, **additional_keys: object) -> None:
595+
# Append additional key-value pairs to the context safely in a thread-safe manner.
596+
self.registered_formatter.thread_safe_append_keys(**additional_keys)
597+
598+
def thread_safe_get_current_keys(self) -> dict[str, Any]:
599+
# Retrieve the current context keys safely in a thread-safe manner.
600+
return self.registered_formatter.thread_safe_get_current_keys()
601+
602+
def thread_safe_remove_keys(self, keys: Iterable[str]) -> None:
603+
# Remove specified keys from the context safely in a thread-safe manner.
604+
self.registered_formatter.thread_safe_remove_keys(keys)
605+
606+
def thread_safe_clear_keys(self) -> None:
607+
# Clear all keys from the context safely in a thread-safe manner.
608+
self.registered_formatter.thread_safe_clear_keys()
609+
592610
def structure_logs(self, append: bool = False, formatter_options: dict | None = None, **keys) -> None:
593611
"""Sets logging formatting to JSON.
594612
@@ -633,6 +651,7 @@ def structure_logs(self, append: bool = False, formatter_options: dict | None =
633651

634652
# Mode 3
635653
self.registered_formatter.clear_state()
654+
self.registered_formatter.thread_safe_clear_keys()
636655
self.registered_formatter.append_keys(**log_keys)
637656

638657
def set_correlation_id(self, value: str | None) -> None:

Diff for: docs/core/event_handler/api_gateway.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,12 @@ Here's an example on how we can handle the `/todos` path.
128128

129129
When using Amazon API Gateway HTTP API to front your Lambda functions, you can use `APIGatewayHttpResolver`.
130130

131+
<!-- markdownlint-disable MD013 -->
131132
???+ note
132133
Using HTTP API v1 payload? Use `APIGatewayRestResolver` instead. `APIGatewayHttpResolver` defaults to v2 payload.
133134

134-
<!-- markdownlint-disable-next-line MD013 -->
135135
If you're using Terraform to deploy a HTTP API, note that it defaults the [payload_format_version](https://registry.terraform.io/providers/hashicorp/aws/latest/docs/resources/apigatewayv2_integration#payload_format_version){target="_blank" rel="nofollow"} value to 1.0 if not specified.
136+
<!-- markdownlint-enable MD013 -->
136137

137138
```python hl_lines="5 11" title="Using HTTP API resolver"
138139
--8<-- "examples/event_handler_rest/src/getting_started_http_api_resolver.py"

Diff for: docs/core/logger.md

+78-2
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,14 @@ To ease routine tasks like extracting correlation ID from popular event sources,
159159

160160
You can append additional keys using either mechanism:
161161

162-
* Persist new keys across all future log messages via `append_keys` method
162+
* New keys persist across all future log messages via `append_keys` method
163163
* Add additional keys on a per log message basis as a keyword=value, or via `extra` parameter
164+
* New keys persist across all future logs in a specific thread via `thread_safe_append_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section.
164165

165166
#### append_keys method
166167

167168
???+ warning
168-
`append_keys` is not thread-safe, please see [RFC](https://github.com/aws-powertools/powertools-lambda-python/issues/991){target="_blank"}.
169+
`append_keys` is not thread-safe, use [thread_safe_append_keys](#appending-thread-safe-additional-keys) instead
169170

170171
You can append your own keys to your existing Logger via `append_keys(**additional_key_values)` method.
171172

@@ -228,6 +229,16 @@ It accepts any dictionary, and all keyword arguments will be added as part of th
228229

229230
### Removing additional keys
230231

232+
You can remove additional keys using either mechanism:
233+
234+
* Remove new keys across all future log messages via `remove_keys` method
235+
* Remove keys persist across all future logs in a specific thread via `thread_safe_remove_keys` method. Check [Working with thread-safe keys](#working-with-thread-safe-keys) section.
236+
237+
???+ danger
238+
Keys added by `append_keys` can only be removed by `remove_keys` and thread-local keys added by `thread_safe_append_keys` can only be removed by `thread_safe_remove_keys` or `thread_safe_clear_keys`. Thread-local and normal logger keys are distinct values and can't be manipulated interchangeably.
239+
240+
#### remove_keys method
241+
231242
You can remove any additional key from Logger state using `remove_keys`.
232243

233244
=== "remove_keys.py"
@@ -284,6 +295,9 @@ You can view all currently configured keys from the Logger state using the `get_
284295
--8<-- "examples/logger/src/get_current_keys.py"
285296
```
286297

298+
???+ info
299+
For thread-local additional logging keys, use `get_current_thread_keys` instead
300+
287301
### Log levels
288302

289303
The default log level is `INFO`. It can be set using the `level` constructor option, `setLevel()` method or by using the `POWERTOOLS_LOG_LEVEL` environment variable.
@@ -473,6 +487,68 @@ You can use any of the following built-in JMESPath expressions as part of [injec
473487
| **APPLICATION_LOAD_BALANCER** | `'headers."x-amzn-trace-id"'` | ALB X-Ray Trace ID |
474488
| **EVENT_BRIDGE** | `"id"` | EventBridge Event ID |
475489

490+
### Working with thread-safe keys
491+
492+
#### Appending thread-safe additional keys
493+
494+
You can append your own thread-local keys in your existing Logger via the `thread_safe_append_keys` method
495+
496+
=== "thread_safe_append_keys.py"
497+
498+
```python hl_lines="11"
499+
--8<-- "examples/logger/src/thread_safe_append_keys.py"
500+
```
501+
502+
=== "thread_safe_append_keys_output.json"
503+
504+
```json hl_lines="8 9 17 18"
505+
--8<-- "examples/logger/src/thread_safe_append_keys_output.json"
506+
```
507+
508+
#### Removing thread-safe additional keys
509+
510+
You can remove any additional thread-local keys from Logger using either `thread_safe_remove_keys` or `thread_safe_clear_keys`.
511+
512+
Use the `thread_safe_remove_keys` method to remove a list of thread-local keys that were previously added using the `thread_safe_append_keys` method.
513+
514+
=== "thread_safe_remove_keys.py"
515+
516+
```python hl_lines="13"
517+
--8<-- "examples/logger/src/thread_safe_remove_keys.py"
518+
```
519+
520+
=== "thread_safe_remove_keys_output.json"
521+
522+
```json hl_lines="8 9 17 18 26 34"
523+
--8<-- "examples/logger/src/thread_safe_remove_keys_output.json"
524+
```
525+
526+
#### Clearing thread-safe additional keys
527+
528+
Use the `thread_safe_clear_keys` method to remove all thread-local keys that were previously added using the `thread_safe_append_keys` method.
529+
530+
=== "thread_safe_clear_keys.py"
531+
532+
```python hl_lines="13"
533+
--8<-- "examples/logger/src/thread_safe_clear_keys.py"
534+
```
535+
536+
=== "thread_safe_clear_keys_output.json"
537+
538+
```json hl_lines="8 9 17 18"
539+
--8<-- "examples/logger/src/thread_safe_clear_keys_output.json"
540+
```
541+
542+
#### Accessing thread-safe currently keys
543+
544+
You can view all currently thread-local keys from the Logger state using the `thread_safe_get_current_keys()` method. This method is useful when you need to avoid overwriting keys that are already configured.
545+
546+
=== "thread_safe_get_current_keys.py"
547+
548+
```python hl_lines="13"
549+
--8<-- "examples/logger/src/thread_safe_get_current_keys.py"
550+
```
551+
476552
### Reusing Logger across your code
477553

478554
Similar to [Tracer](./tracer.md#reusing-tracer-across-your-code){target="_blank"}, a new instance that uses the same `service` name will reuse a previous Logger instance.

Diff for: examples/logger/src/thread_safe_append_keys.py

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import threading
2+
from typing import List
3+
4+
from aws_lambda_powertools import Logger
5+
from aws_lambda_powertools.utilities.typing import LambdaContext
6+
7+
logger = Logger()
8+
9+
10+
def threaded_func(order_id: str):
11+
logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident())
12+
logger.info("Collecting payment")
13+
14+
15+
def lambda_handler(event: dict, context: LambdaContext) -> str:
16+
order_ids: List[str] = event["order_ids"]
17+
18+
threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
19+
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()
20+
21+
return "hello world"
+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
[
2+
{
3+
"level": "INFO",
4+
"location": "threaded_func:11",
5+
"message": "Collecting payment",
6+
"timestamp": "2024-09-08 03:04:11,316-0400",
7+
"service": "payment",
8+
"order_id": "order_id_value_1",
9+
"thread_id": "3507187776085958"
10+
},
11+
{
12+
"level": "INFO",
13+
"location": "threaded_func:11",
14+
"message": "Collecting payment",
15+
"timestamp": "2024-09-08 03:04:11,316-0400",
16+
"service": "payment",
17+
"order_id": "order_id_value_2",
18+
"thread_id": "140718447808512"
19+
}
20+
]

Diff for: examples/logger/src/thread_safe_clear_keys.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import threading
2+
from typing import List
3+
4+
from aws_lambda_powertools import Logger
5+
from aws_lambda_powertools.utilities.typing import LambdaContext
6+
7+
logger = Logger()
8+
9+
10+
def threaded_func(order_id: str):
11+
logger.thread_safe_append_keys(order_id=order_id, thread_id=threading.get_ident())
12+
logger.info("Collecting payment")
13+
logger.thread_safe_clear_keys()
14+
logger.info("Exiting thread")
15+
16+
17+
def lambda_handler(event: dict, context: LambdaContext) -> str:
18+
order_ids: List[str] = event["order_ids"]
19+
20+
threading.Thread(target=threaded_func, args=(order_ids[0],)).start()
21+
threading.Thread(target=threaded_func, args=(order_ids[1],)).start()
22+
23+
return "hello world"
+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
[
2+
{
3+
"level": "INFO",
4+
"location": "threaded_func:11",
5+
"message": "Collecting payment",
6+
"timestamp": "2024-09-08 12:26:10,648-0400",
7+
"service": "payment",
8+
"order_id": "order_id_value_1",
9+
"thread_id": 140077070292544
10+
},
11+
{
12+
"level": "INFO",
13+
"location": "threaded_func:11",
14+
"message": "Collecting payment",
15+
"timestamp": "2024-09-08 12:26:10,649-0400",
16+
"service": "payment",
17+
"order_id": "order_id_value_2",
18+
"thread_id": 140077061899840
19+
},
20+
{
21+
"level": "INFO",
22+
"location": "threaded_func:13",
23+
"message": "Exiting thread",
24+
"timestamp": "2024-09-08 12:26:10,649-0400",
25+
"service": "payment"
26+
},
27+
{
28+
"level": "INFO",
29+
"location": "threaded_func:13",
30+
"message": "Exiting thread",
31+
"timestamp": "2024-09-08 12:26:10,649-0400",
32+
"service": "payment"
33+
}
34+
]

Diff for: examples/logger/src/thread_safe_get_current_keys.py

+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from aws_lambda_powertools import Logger
2+
from aws_lambda_powertools.utilities.typing import LambdaContext
3+
4+
logger = Logger()
5+
6+
7+
@logger.inject_lambda_context
8+
def lambda_handler(event: dict, context: LambdaContext) -> str:
9+
logger.info("Collecting payment")
10+
11+
if "order" not in logger.thread_safe_get_current_keys():
12+
logger.thread_safe_append_keys(order=event.get("order"))
13+
14+
return "hello world"

0 commit comments

Comments
 (0)