Skip to content

Commit 903653f

Browse files
committed
feat: bring your own formatter
1 parent 2f828cd commit 903653f

File tree

3 files changed

+74
-10
lines changed

3 files changed

+74
-10
lines changed

aws_lambda_powertools/logging/formatter.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import os
44
import time
5+
from abc import ABCMeta, abstractmethod
56
from functools import partial
67
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
78

@@ -35,7 +36,17 @@
3536
)
3637

3738

38-
class LambdaPowertoolsFormatter(logging.Formatter):
39+
class BasePowertoolsFormatter(logging.Formatter, metaclass=ABCMeta):
40+
@abstractmethod
41+
def append_keys(self, **additional_keys):
42+
raise NotImplementedError()
43+
44+
@abstractmethod
45+
def remove_keys(self, keys: Iterable[str]):
46+
raise NotImplementedError()
47+
48+
49+
class LambdaPowertoolsFormatter(BasePowertoolsFormatter):
3950
"""AWS Lambda Powertools Logging formatter.
4051
4152
Formats the log message as a JSON encoded string. If the message is a

aws_lambda_powertools/logging/logger.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
1313
from .exceptions import InvalidLoggerSamplingRateError
1414
from .filters import SuppressFilter
15-
from .formatter import LambdaPowertoolsFormatter
15+
from .formatter import BasePowertoolsFormatter, LambdaPowertoolsFormatter
1616
from .lambda_context import build_lambda_context_model
1717

1818
logger = logging.getLogger(__name__)
@@ -124,6 +124,7 @@ def __init__(
124124
child: bool = False,
125125
sampling_rate: float = None,
126126
stream: sys.stdout = None,
127+
logger_formatter: Optional[BasePowertoolsFormatter] = None,
127128
**kwargs,
128129
):
129130
self.service = resolve_env_var_choice(
@@ -132,11 +133,12 @@ def __init__(
132133
self.sampling_rate = resolve_env_var_choice(
133134
choice=sampling_rate, env=os.getenv(constants.LOGGER_LOG_SAMPLING_RATE)
134135
)
136+
self.child = child
137+
self.logger_formatter = logger_formatter
138+
self.log_level = self._get_log_level(level)
135139
self._is_deduplication_disabled = resolve_truthy_env_var_choice(
136140
env=os.getenv(constants.LOGGER_LOG_DEDUPLICATION_ENV, "false")
137141
)
138-
self.log_level = self._get_log_level(level)
139-
self.child = child
140142
self._handler = logging.StreamHandler(stream) or logging.StreamHandler(sys.stdout)
141143
self._default_log_keys = {"service": self.service, "sampling_rate": self.sampling_rate}
142144
self._logger = self._get_logger()
@@ -292,13 +294,13 @@ def remove_keys(self, keys: Iterable[str]):
292294

293295
@property
294296
def registered_handler(self) -> logging.Handler:
295-
"""Registered Logger handler"""
297+
"""Convenience property to access logger handler"""
296298
handlers = self._logger.parent.handlers if self.child else self._logger.handlers
297299
return handlers[0]
298300

299301
@property
300-
def registered_formatter(self) -> Optional[LambdaPowertoolsFormatter]:
301-
"""Registered Logger formatter"""
302+
def registered_formatter(self) -> Optional[BasePowertoolsFormatter]:
303+
"""Convenience property to access logger formatter"""
302304
return self.registered_handler.formatter
303305

304306
def structure_logs(self, append: bool = False, **keys):
@@ -312,15 +314,16 @@ def structure_logs(self, append: bool = False, **keys):
312314
Parameters
313315
----------
314316
append : bool, optional
315-
[description], by default False
317+
append keys provided to logger formatter, by default False
316318
"""
317319

318320
if append:
319321
# Maintenance: Add deprecation warning for major version. Refer to append_keys() when docs are updated
320322
self.append_keys(**keys)
321323
else:
322-
# Set a new formatter for a logger handler
323-
self.registered_handler.setFormatter(LambdaPowertoolsFormatter(**self._default_log_keys, **keys))
324+
log_keys = {**self._default_log_keys, **keys}
325+
formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys)
326+
self.registered_handler.setFormatter(formatter)
324327

325328
def set_correlation_id(self, value: str):
326329
"""Sets the correlation_id in the logging json

tests/functional/test_logger.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
import random
66
import string
77
from collections import namedtuple
8+
from typing import Iterable
89

910
import pytest
1011

1112
from aws_lambda_powertools import Logger, Tracer
1213
from aws_lambda_powertools.logging import correlation_paths
1314
from aws_lambda_powertools.logging.exceptions import InvalidLoggerSamplingRateError
15+
from aws_lambda_powertools.logging.formatter import BasePowertoolsFormatter
1416
from aws_lambda_powertools.logging.logger import set_package_logger
1517
from aws_lambda_powertools.shared import constants
1618

@@ -494,3 +496,51 @@ def test_logger_append_remove_keys(stdout, service_name):
494496

495497
assert extra_keys.items() <= extra_keys_log.items()
496498
assert (extra_keys.items() <= keys_removed_log.items()) is False
499+
500+
501+
def test_logger_custom_formatter(stdout, service_name, lambda_context):
502+
class CustomFormatter(BasePowertoolsFormatter):
503+
custom_format = {}
504+
505+
def append_keys(self, **additional_keys):
506+
self.custom_format.update(additional_keys)
507+
508+
def remove_keys(self, keys: Iterable[str]):
509+
for key in keys:
510+
self.custom_format.pop(key, None)
511+
512+
def format(self, record: logging.LogRecord) -> str: # noqa: A003
513+
return json.dumps(
514+
{
515+
"message": super().format(record),
516+
"timestamp": self.formatTime(record),
517+
"my_default_key": "test",
518+
**self.custom_format,
519+
}
520+
)
521+
522+
custom_formatter = CustomFormatter()
523+
524+
# GIVEN a Logger is initialized with a custom formatter
525+
logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter)
526+
527+
# WHEN a lambda function is decorated with logger
528+
@logger.inject_lambda_context
529+
def handler(event, context):
530+
logger.info("Hello")
531+
532+
handler({}, lambda_context)
533+
534+
lambda_context_keys = (
535+
"function_name",
536+
"function_memory_size",
537+
"function_arn",
538+
"function_request_id",
539+
)
540+
541+
log = capture_logging_output(stdout)
542+
543+
# THEN custom key should always be present
544+
# and lambda contextual info should also be in the logs
545+
assert "my_default_key" in log
546+
assert all(k in log for k in lambda_context_keys)

0 commit comments

Comments
 (0)