diff --git a/aws_lambda_powertools/logging/correlation_paths.py b/aws_lambda_powertools/logging/correlation_paths.py new file mode 100644 index 00000000000..73227754363 --- /dev/null +++ b/aws_lambda_powertools/logging/correlation_paths.py @@ -0,0 +1,6 @@ +"""Built-in correlation paths""" + +API_GATEWAY_REST = "requestContext.requestId" +API_GATEWAY_HTTP = API_GATEWAY_REST +APPLICATION_LOAD_BALANCER = "headers.x-amzn-trace-id" +EVENT_BRIDGE = "id" diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 06285e73b82..98ecfc4c449 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -6,6 +6,8 @@ import sys from typing import Any, Callable, Dict, Union +import jmespath + from ..shared import constants from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice from .exceptions import InvalidLoggerSamplingRateError @@ -204,7 +206,9 @@ def _configure_sampling(self): f"Please review POWERTOOLS_LOGGER_SAMPLE_RATE environment variable." ) - def inject_lambda_context(self, lambda_handler: Callable[[Dict, Any], Any] = None, log_event: bool = None): + def inject_lambda_context( + self, lambda_handler: Callable[[Dict, Any], Any] = None, log_event: bool = None, correlation_id_path: str = None + ): """Decorator to capture Lambda contextual info and inject into logger Parameters @@ -213,6 +217,8 @@ def inject_lambda_context(self, lambda_handler: Callable[[Dict, Any], Any] = Non Method to inject the lambda context log_event : bool, optional Instructs logger to log Lambda Event, by default False + correlation_id_path: str, optional + Optional JMESPath for the correlation_id Environment variables --------------------- @@ -251,7 +257,9 @@ def handler(event, context): # Return a partial function with args filled if lambda_handler is None: logger.debug("Decorator called with parameters") - return functools.partial(self.inject_lambda_context, log_event=log_event) + return functools.partial( + self.inject_lambda_context, log_event=log_event, correlation_id_path=correlation_id_path + ) log_event = resolve_truthy_env_var_choice( choice=log_event, env=os.getenv(constants.LOGGER_LOG_EVENT_ENV, "false") @@ -263,6 +271,9 @@ def decorate(event, context): cold_start = _is_cold_start() self.structure_logs(append=True, cold_start=cold_start, **lambda_context.__dict__) + if correlation_id_path: + self.set_correlation_id(jmespath.search(correlation_id_path, event)) + if log_event: logger.debug("Event received") self.info(event) @@ -296,6 +307,16 @@ def structure_logs(self, append: bool = False, **kwargs): # Set a new formatter for a logger handler handler.setFormatter(JsonFormatter(**self._default_log_keys, **kwargs)) + def set_correlation_id(self, value: str): + """Sets the correlation_id in the logging json + + Parameters + ---------- + value : str + Value for the correlation id + """ + self.structure_logs(append=True, correlation_id=value) + @staticmethod def _get_log_level(level: Union[str, int]) -> Union[str, int]: """ Returns preferred log level set by the customer in upper case """ diff --git a/docs/core/logger.md b/docs/core/logger.md index 2c5c347eadf..27cbd725f80 100644 --- a/docs/core/logger.md +++ b/docs/core/logger.md @@ -188,6 +188,47 @@ You can append your own keys to your existing Logger via `structure_logs(append= This example will add `order_id` if its value is not empty, and in subsequent invocations where `order_id` might not be present it'll remove it from the logger. +#### Setting correlation ID + +You can set a correlation_id to your existing Logger via `set_correlation_id(value)` method. + +=== "collect.py" + + ```python hl_lines="8" + from aws_lambda_powertools import Logger + from aws_lambda_powertools.utilities.data_classes import APIGatewayProxyEvent + + logger = Logger() + + def handler(event, context): + event = APIGatewayProxyEvent(event) + logger.set_correlation_id(event.request_context.request_id) + logger.info("Collecting payment") + ... + ``` +=== "Example Event" + + ```json hl_lines="3" + { + "requestContext": { + "requestId": "correlation_id_value" + } + } + ``` +=== "Example CloudWatch Logs excerpt" + + ```json hl_lines="7" + { + "timestamp": "2020-05-24 18:17:33,774", + "level": "INFO", + "location": "collect.handler:1", + "service": "payment", + "sampling_rate": 0.0, + "correlation_id": "correlation_id_value", + "message": "Collecting payment" + } + ``` + #### extra parameter Extra parameter is available for all log levels' methods, as implemented in the standard logging library - e.g. `logger.info, logger.warning`. diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py index 2b4c7cf187e..ddf5ee226f5 100644 --- a/tests/functional/test_logger.py +++ b/tests/functional/test_logger.py @@ -9,6 +9,7 @@ import pytest from aws_lambda_powertools import Logger, Tracer +from aws_lambda_powertools.logging import correlation_paths from aws_lambda_powertools.logging.exceptions import InvalidLoggerSamplingRateError from aws_lambda_powertools.logging.logger import set_package_logger from aws_lambda_powertools.shared import constants @@ -437,3 +438,39 @@ def test_logger_exception_extract_exception_name(stdout, service_name): # THEN we expect a "exception_name" to be "ValueError" log = capture_logging_output(stdout) assert "ValueError" == log["exception_name"] + + +def test_logger_set_correlation_id(lambda_context, stdout, service_name): + # GIVEN + logger = Logger(service=service_name, stream=stdout) + request_id = "xxx-111-222" + mock_event = {"requestContext": {"requestId": request_id}} + + def handler(event, _): + logger.set_correlation_id(event["requestContext"]["requestId"]) + logger.info("Foo") + + # WHEN + handler(mock_event, lambda_context) + + # THEN + log = capture_logging_output(stdout) + assert request_id == log["correlation_id"] + + +def test_logger_set_correlation_id_path(lambda_context, stdout, service_name): + # GIVEN + logger = Logger(service=service_name, stream=stdout) + request_id = "xxx-111-222" + mock_event = {"requestContext": {"requestId": request_id}} + + @logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_REST) + def handler(event, context): + logger.info("Foo") + + # WHEN + handler(mock_event, lambda_context) + + # THEN + log = capture_logging_output(stdout) + assert request_id == log["correlation_id"]