diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 8ac911d4cac..35054f86137 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -387,16 +387,26 @@ def structure_logs(self, append: bool = False, **keys): formatter = self.logger_formatter or LambdaPowertoolsFormatter(**log_keys) # type: ignore self.registered_handler.setFormatter(formatter) - def set_correlation_id(self, value: str): + def set_correlation_id(self, value: Optional[str]): """Sets the correlation_id in the logging json Parameters ---------- - value : str - Value for the correlation id + value : str, optional + Value for the correlation id. None will remove the correlation_id """ self.append_keys(correlation_id=value) + def get_correlation_id(self) -> Optional[str]: + """Gets the correlation_id in the logging json + + Returns + ------- + str, optional + Value for the correlation id + """ + return self.registered_formatter.log_format.get("correlation_id") + @staticmethod def _get_log_level(level: Union[str, int, None]) -> Union[str, int]: """Returns preferred log level set by the customer in upper case""" diff --git a/tests/functional/test_logger.py b/tests/functional/test_logger.py index 44249af6250..a8d92c05257 100644 --- a/tests/functional/test_logger.py +++ b/tests/functional/test_logger.py @@ -460,6 +460,18 @@ def handler(event, _): assert request_id == log["correlation_id"] +def test_logger_get_correlation_id(lambda_context, stdout, service_name): + # GIVEN a logger with a correlation_id set + logger = Logger(service=service_name, stream=stdout) + logger.set_correlation_id("foo") + + # WHEN calling get_correlation_id + correlation_id = logger.get_correlation_id() + + # THEN it should return the correlation_id + assert "foo" == correlation_id + + def test_logger_set_correlation_id_path(lambda_context, stdout, service_name): # GIVEN logger = Logger(service=service_name, stream=stdout)