From 96cbdc1910aa4bdcdb47efef587c04a17f75436e Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Sun, 22 Aug 2021 08:28:22 +0200 Subject: [PATCH 1/2] fix(idempotency): sorting keys before hashing --- .../utilities/idempotency/persistence/base.py | 4 +-- tests/functional/idempotency/conftest.py | 14 +++++--- .../idempotency/test_idempotency.py | 34 ++++++++++++++----- 3 files changed, 37 insertions(+), 15 deletions(-) diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index 2f5dd512ac6..4901e9f9f75 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -223,7 +223,7 @@ def _generate_hash(self, data: Any) -> str: """ data = getattr(data, "raw_event", data) # could be a data class depending on decorator order - hashed_data = self.hash_function(json.dumps(data, cls=Encoder).encode()) + hashed_data = self.hash_function(json.dumps(data, cls=Encoder, sort_keys=True).encode()) return hashed_data.hexdigest() def _validate_payload(self, data: Dict[str, Any], data_record: DataRecord) -> None: @@ -310,7 +310,7 @@ def save_success(self, data: Dict[str, Any], result: dict) -> None: result: dict The response from function """ - response_data = json.dumps(result, cls=Encoder) + response_data = json.dumps(result, cls=Encoder, sort_keys=True) data_record = DataRecord( idempotency_key=self._get_hashed_idempotency_key(data=data), diff --git a/tests/functional/idempotency/conftest.py b/tests/functional/idempotency/conftest.py index e613bb85e60..2c528cafc50 100644 --- a/tests/functional/idempotency/conftest.py +++ b/tests/functional/idempotency/conftest.py @@ -21,6 +21,10 @@ TABLE_NAME = "TEST_TABLE" +def serialize(data): + return json.dumps(data, sort_keys=True, cls=Encoder) + + @pytest.fixture(scope="module") def config() -> Config: return Config(region_name="us-east-1") @@ -62,12 +66,12 @@ def lambda_response(): @pytest.fixture(scope="module") def serialized_lambda_response(lambda_response): - return json.dumps(lambda_response, cls=Encoder) + return serialize(lambda_response) @pytest.fixture(scope="module") def deserialized_lambda_response(lambda_response): - return json.loads(json.dumps(lambda_response, cls=Encoder)) + return json.loads(serialize(lambda_response)) @pytest.fixture @@ -144,7 +148,7 @@ def expected_params_put_item_with_validation(hashed_idempotency_key, hashed_vali def hashed_idempotency_key(lambda_apigw_event, default_jmespath, lambda_context): compiled_jmespath = jmespath.compile(default_jmespath) data = compiled_jmespath.search(lambda_apigw_event) - return "test-func#" + hashlib.md5(json.dumps(data).encode()).hexdigest() + return "test-func#" + hashlib.md5(serialize(data).encode()).hexdigest() @pytest.fixture @@ -152,12 +156,12 @@ def hashed_idempotency_key_with_envelope(lambda_apigw_event): event = extract_data_from_envelope( data=lambda_apigw_event, envelope=envelopes.API_GATEWAY_HTTP, jmespath_options={} ) - return "test-func#" + hashlib.md5(json.dumps(event).encode()).hexdigest() + return "test-func#" + hashlib.md5(serialize(event).encode()).hexdigest() @pytest.fixture def hashed_validation_key(lambda_apigw_event): - return hashlib.md5(json.dumps(lambda_apigw_event["requestContext"]).encode()).hexdigest() + return hashlib.md5(serialize(lambda_apigw_event["requestContext"]).encode()).hexdigest() @pytest.fixture diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index 5505a7dc5c9..cb0d43ae6fa 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -21,6 +21,7 @@ from aws_lambda_powertools.utilities.idempotency.idempotency import idempotent, idempotent_function from aws_lambda_powertools.utilities.idempotency.persistence.base import BasePersistenceLayer, DataRecord from aws_lambda_powertools.utilities.validation import envelopes, validator +from tests.functional.idempotency.conftest import serialize from tests.functional.utils import load_event TABLE_NAME = "TEST_TABLE" @@ -741,7 +742,7 @@ def test_default_no_raise_on_missing_idempotency_key( hashed_key = persistence_store._get_hashed_idempotency_key({}) # THEN return the hash of None - expected_value = "test-func#" + md5(json.dumps(None).encode()).hexdigest() + expected_value = "test-func#" + md5(serialize(None).encode()).hexdigest() assert expected_value == hashed_key @@ -785,7 +786,7 @@ def test_jmespath_with_powertools_json( expected_value = [sub_attr_value, key_attr_value] api_gateway_proxy_event = { "requestContext": {"authorizer": {"claims": {"sub": sub_attr_value}}}, - "body": json.dumps({"id": key_attr_value}), + "body": serialize({"id": key_attr_value}), } # WHEN calling _get_hashed_idempotency_key @@ -869,7 +870,7 @@ def _delete_record(self, data_record: DataRecord) -> None: def test_idempotent_lambda_event_source(lambda_context): # Scenario to validate that we can use the event_source decorator before or after the idempotent decorator mock_event = load_event("apiGatewayProxyV2Event.json") - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} # GIVEN an event_source decorator @@ -889,7 +890,7 @@ def lambda_handler(event, _): def test_idempotent_function(): # Scenario to validate we can use idempotent_function with any function mock_event = {"data": "value"} - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") @@ -906,7 +907,7 @@ def test_idempotent_function_arbitrary_args_kwargs(): # Scenario to validate we can use idempotent_function with a function # with an arbitrary number of args and kwargs mock_event = {"data": "value"} - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") @@ -921,7 +922,7 @@ def record_handler(arg_one, arg_two, record, is_record): def test_idempotent_function_invalid_data_kwarg(): mock_event = {"data": "value"} - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} keyword_argument = "payload" @@ -938,7 +939,7 @@ def record_handler(record): def test_idempotent_function_arg_instead_of_kwarg(): mock_event = {"data": "value"} - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} keyword_argument = "record" @@ -956,7 +957,7 @@ def record_handler(record): def test_idempotent_function_and_lambda_handler(lambda_context): # Scenario to validate we can use both idempotent_function and idempotent decorators mock_event = {"data": "value"} - persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(serialize(mock_event).encode()).hexdigest()) expected_result = {"message": "Foo"} @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") @@ -976,3 +977,20 @@ def lambda_handler(event, _): # THEN we expect the function and lambda handler to execute successfully assert fn_result == expected_result assert handler_result == expected_result + + +def test_idempotent_data_sorting(): + # Scenario to validate same data in different order hashes to the same idempotency key + data_one = {"data": "test message 1", "more_data": "more data 1"} + data_two = {"more_data": "more data 1", "data": "test message 1"} + + # Assertion will happen in MockPersistenceLayer + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(data_one).encode()).hexdigest()) + + # GIVEN + @idempotent_function(data_keyword_argument="payload", persistence_store=persistence_layer) + def dummy(payload): + return {"message": "hello"} + + # WHEN + dummy(payload=data_two) From a7936214afd769b68c66aae95d4afdfb2e786a0b Mon Sep 17 00:00:00 2001 From: heitorlessa Date: Fri, 1 Oct 2021 15:12:22 +0200 Subject: [PATCH 2/2] chore: add return types, ignore some, correct signatures --- aws_lambda_powertools/logging/formatter.py | 8 ++--- aws_lambda_powertools/logging/logger.py | 6 ++-- aws_lambda_powertools/metrics/base.py | 6 ++-- aws_lambda_powertools/metrics/metric.py | 2 +- aws_lambda_powertools/metrics/metrics.py | 30 +++++++++++-------- .../middleware_factory/factory.py | 2 +- .../shared/jmespath_utils.py | 9 +++--- aws_lambda_powertools/tracing/tracer.py | 2 +- .../utilities/data_classes/sqs_event.py | 4 +-- .../idempotency/persistence/dynamodb.py | 2 +- .../utilities/validation/base.py | 4 +-- mypy.ini | 6 ++++ 12 files changed, 47 insertions(+), 34 deletions(-) diff --git a/aws_lambda_powertools/logging/formatter.py b/aws_lambda_powertools/logging/formatter.py index de9254a3371..e35c9a7a327 100644 --- a/aws_lambda_powertools/logging/formatter.py +++ b/aws_lambda_powertools/logging/formatter.py @@ -58,7 +58,7 @@ class LambdaPowertoolsFormatter(BasePowertoolsFormatter): def __init__( self, json_serializer: Optional[Callable[[Dict], str]] = None, - json_deserializer: Optional[Callable[[Dict], str]] = None, + json_deserializer: Optional[Callable[[Union[Dict, str, bool, int, float]], str]] = None, json_default: Optional[Callable[[Any], Any]] = None, datefmt: Optional[str] = None, log_record_order: Optional[List[str]] = None, @@ -106,7 +106,7 @@ def __init__( self.update_formatter = self.append_keys # alias to old method if self.utc: - self.converter = time.gmtime + self.converter = time.gmtime # type: ignore super(LambdaPowertoolsFormatter, self).__init__(datefmt=self.datefmt) @@ -128,7 +128,7 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003 return self.serialize(log=formatted_log) def formatTime(self, record: logging.LogRecord, datefmt: Optional[str] = None) -> str: - record_ts = self.converter(record.created) + record_ts = self.converter(record.created) # type: ignore if datefmt: return time.strftime(datefmt, record_ts) @@ -201,7 +201,7 @@ def _extract_log_exception(self, log_record: logging.LogRecord) -> Union[Tuple[s Log record with constant traceback info and exception name """ if log_record.exc_info: - return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__ + return self.formatException(log_record.exc_info), log_record.exc_info[0].__name__ # type: ignore return None, None diff --git a/aws_lambda_powertools/logging/logger.py b/aws_lambda_powertools/logging/logger.py index 35054f86137..e8b67a2ca7e 100644 --- a/aws_lambda_powertools/logging/logger.py +++ b/aws_lambda_powertools/logging/logger.py @@ -361,7 +361,7 @@ def registered_handler(self) -> logging.Handler: return handlers[0] @property - def registered_formatter(self) -> Optional[PowertoolsFormatter]: + def registered_formatter(self) -> PowertoolsFormatter: """Convenience property to access logger formatter""" return self.registered_handler.formatter # type: ignore @@ -405,7 +405,9 @@ def get_correlation_id(self) -> Optional[str]: str, optional Value for the correlation id """ - return self.registered_formatter.log_format.get("correlation_id") + if isinstance(self.registered_formatter, LambdaPowertoolsFormatter): + return self.registered_formatter.log_format.get("correlation_id") + return None @staticmethod def _get_log_level(level: Union[str, int, None]) -> Union[str, int]: diff --git a/aws_lambda_powertools/metrics/base.py b/aws_lambda_powertools/metrics/base.py index 853f06f210b..25e502d0887 100644 --- a/aws_lambda_powertools/metrics/base.py +++ b/aws_lambda_powertools/metrics/base.py @@ -90,7 +90,7 @@ def __init__( self._metric_unit_options = list(MetricUnit.__members__) self.metadata_set = metadata_set if metadata_set is not None else {} - def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float): + def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None: """Adds given metric Example @@ -215,7 +215,7 @@ def serialize_metric_set( **metric_names_and_values, # "single_metric": 1.0 } - def add_dimension(self, name: str, value: str): + def add_dimension(self, name: str, value: str) -> None: """Adds given dimension to all metrics Example @@ -241,7 +241,7 @@ def add_dimension(self, name: str, value: str): # checking before casting improves performance in most cases self.dimension_set[name] = value if isinstance(value, str) else str(value) - def add_metadata(self, key: str, value: Any): + def add_metadata(self, key: str, value: Any) -> None: """Adds high cardinal metadata for metrics object This will not be available during metrics visualization. diff --git a/aws_lambda_powertools/metrics/metric.py b/aws_lambda_powertools/metrics/metric.py index 1ac2bd9450e..a30f428e38e 100644 --- a/aws_lambda_powertools/metrics/metric.py +++ b/aws_lambda_powertools/metrics/metric.py @@ -42,7 +42,7 @@ class SingleMetric(MetricManager): Inherits from `aws_lambda_powertools.metrics.base.MetricManager` """ - def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float): + def add_metric(self, name: str, unit: Union[MetricUnit, str], value: float) -> None: """Method to prevent more than one metric being created Parameters diff --git a/aws_lambda_powertools/metrics/metrics.py b/aws_lambda_powertools/metrics/metrics.py index fafc604b505..23e9f542eea 100644 --- a/aws_lambda_powertools/metrics/metrics.py +++ b/aws_lambda_powertools/metrics/metrics.py @@ -2,8 +2,9 @@ import json import logging import warnings -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, Union, cast +from ..shared.types import AnyCallableT from .base import MetricManager, MetricUnit from .metric import single_metric @@ -87,7 +88,7 @@ def __init__(self, service: Optional[str] = None, namespace: Optional[str] = Non service=self.service, ) - def set_default_dimensions(self, **dimensions): + def set_default_dimensions(self, **dimensions) -> None: """Persist dimensions across Lambda invocations Parameters @@ -113,10 +114,10 @@ def lambda_handler(): self.default_dimensions.update(**dimensions) - def clear_default_dimensions(self): + def clear_default_dimensions(self) -> None: self.default_dimensions.clear() - def clear_metrics(self): + def clear_metrics(self) -> None: logger.debug("Clearing out existing metric set from memory") self.metric_set.clear() self.dimension_set.clear() @@ -125,11 +126,11 @@ def clear_metrics(self): def log_metrics( self, - lambda_handler: Optional[Callable[[Any, Any], Any]] = None, + lambda_handler: Union[Callable[[Dict, Any], Any], Optional[Callable[[Dict, Any, Optional[Dict]], Any]]] = None, capture_cold_start_metric: bool = False, raise_on_empty_metrics: bool = False, default_dimensions: Optional[Dict[str, str]] = None, - ): + ) -> AnyCallableT: """Decorator to serialize and publish metrics at the end of a function execution. Be aware that the log_metrics **does call* the decorated function (e.g. lambda_handler). @@ -169,11 +170,14 @@ def handler(event, context): # Return a partial function with args filled if lambda_handler is None: logger.debug("Decorator called with parameters") - return functools.partial( - self.log_metrics, - capture_cold_start_metric=capture_cold_start_metric, - raise_on_empty_metrics=raise_on_empty_metrics, - default_dimensions=default_dimensions, + return cast( + AnyCallableT, + functools.partial( + self.log_metrics, + capture_cold_start_metric=capture_cold_start_metric, + raise_on_empty_metrics=raise_on_empty_metrics, + default_dimensions=default_dimensions, + ), ) @functools.wraps(lambda_handler) @@ -194,9 +198,9 @@ def decorate(event, context): return response - return decorate + return cast(AnyCallableT, decorate) - def __add_cold_start_metric(self, context: Any): + def __add_cold_start_metric(self, context: Any) -> None: """Add cold start metric and function_name dimension Parameters diff --git a/aws_lambda_powertools/middleware_factory/factory.py b/aws_lambda_powertools/middleware_factory/factory.py index 74858bf6709..8ab16c5e8b7 100644 --- a/aws_lambda_powertools/middleware_factory/factory.py +++ b/aws_lambda_powertools/middleware_factory/factory.py @@ -118,7 +118,7 @@ def final_decorator(func: Optional[Callable] = None, **kwargs): if not inspect.isfunction(func): # @custom_middleware(True) vs @custom_middleware(log_event=True) raise MiddlewareInvalidArgumentError( - f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" + f"Only keyword arguments is supported for middlewares: {decorator.__qualname__} received {func}" # type: ignore # noqa: E501 ) @functools.wraps(func) diff --git a/aws_lambda_powertools/shared/jmespath_utils.py b/aws_lambda_powertools/shared/jmespath_utils.py index 9cc736aedfb..bbb3b38fe04 100644 --- a/aws_lambda_powertools/shared/jmespath_utils.py +++ b/aws_lambda_powertools/shared/jmespath_utils.py @@ -6,22 +6,23 @@ import jmespath from jmespath.exceptions import LexerError +from jmespath.functions import Functions, signature from aws_lambda_powertools.exceptions import InvalidEnvelopeExpressionError logger = logging.getLogger(__name__) -class PowertoolsFunctions(jmespath.functions.Functions): - @jmespath.functions.signature({"types": ["string"]}) +class PowertoolsFunctions(Functions): + @signature({"types": ["string"]}) def _func_powertools_json(self, value): return json.loads(value) - @jmespath.functions.signature({"types": ["string"]}) + @signature({"types": ["string"]}) def _func_powertools_base64(self, value): return base64.b64decode(value).decode() - @jmespath.functions.signature({"types": ["string"]}) + @signature({"types": ["string"]}) def _func_powertools_base64_gzip(self, value): encoded = base64.b64decode(value) uncompressed = gzip.decompress(encoded) diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index dc010a3712f..2626793304c 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -17,7 +17,7 @@ logger = logging.getLogger(__name__) aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) -aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) +aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) # type: ignore # noqa: E501 class Tracer: diff --git a/aws_lambda_powertools/utilities/data_classes/sqs_event.py b/aws_lambda_powertools/utilities/data_classes/sqs_event.py index 0e70684cc3f..57caeea4cc2 100644 --- a/aws_lambda_powertools/utilities/data_classes/sqs_event.py +++ b/aws_lambda_powertools/utilities/data_classes/sqs_event.py @@ -75,9 +75,9 @@ def data_type(self) -> str: class SQSMessageAttributes(Dict[str, SQSMessageAttribute]): - def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: + def __getitem__(self, key: str) -> Optional[SQSMessageAttribute]: # type: ignore item = super(SQSMessageAttributes, self).get(key) - return None if item is None else SQSMessageAttribute(item) + return None if item is None else SQSMessageAttribute(item) # type: ignore class SQSRecord(DictWrapper): diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py b/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py index 73f241bd613..c1ed29c6fd3 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py @@ -155,7 +155,7 @@ def _update_record(self, data_record: DataRecord): "ExpressionAttributeNames": expression_attr_names, } - self.table.update_item(**kwargs) # type: ignore + self.table.update_item(**kwargs) def _delete_record(self, data_record: DataRecord) -> None: logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}") diff --git a/aws_lambda_powertools/utilities/validation/base.py b/aws_lambda_powertools/utilities/validation/base.py index 2a337b85971..61d692d7f28 100644 --- a/aws_lambda_powertools/utilities/validation/base.py +++ b/aws_lambda_powertools/utilities/validation/base.py @@ -33,10 +33,10 @@ def validate_data_against_schema(data: Union[Dict, str], schema: Dict, formats: except (TypeError, AttributeError, fastjsonschema.JsonSchemaDefinitionException) as e: raise InvalidSchemaFormatError(f"Schema received: {schema}, Formats: {formats}. Error: {e}") except fastjsonschema.JsonSchemaValueException as e: - message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}" + message = f"Failed schema validation. Error: {e.message}, Path: {e.path}, Data: {e.value}" # noqa: B306 raise SchemaValidationError( message, - validation_message=e.message, + validation_message=e.message, # noqa: B306 name=e.name, path=e.path, value=e.value, diff --git a/mypy.ini b/mypy.ini index 2436d7074d2..faf6014a54d 100644 --- a/mypy.ini +++ b/mypy.ini @@ -11,6 +11,12 @@ show_error_context = True [mypy-jmespath] ignore_missing_imports=True +[mypy-jmespath.exceptions] +ignore_missing_imports=True + +[mypy-jmespath.functions] +ignore_missing_imports=True + [mypy-boto3] ignore_missing_imports = True