diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index c9751b0ca12..c3183e0df84 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -7,6 +7,7 @@ import json import logging from abc import ABC, abstractmethod +from types import MappingProxyType from typing import Any, Dict import jmespath @@ -21,7 +22,7 @@ logger = logging.getLogger(__name__) -STATUS_CONSTANTS = {"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"} +STATUS_CONSTANTS = MappingProxyType({"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"}) class DataRecord: @@ -81,8 +82,7 @@ def status(self) -> str: """ if self.is_expired: return STATUS_CONSTANTS["EXPIRED"] - - if self._status in STATUS_CONSTANTS.values(): + elif self._status in STATUS_CONSTANTS.values(): return self._status else: raise IdempotencyInvalidStatusError(self._status) @@ -214,14 +214,14 @@ def _validate_payload(self, lambda_event: Dict[str, Any], data_record: DataRecor DataRecord instance Raises - ______ + ---------- IdempotencyValidationError Event payload doesn't match the stored record for the given idempotency key """ if self.payload_validation_enabled: lambda_payload_hash = self._get_hashed_payload(lambda_event) - if not data_record.payload_hash == lambda_payload_hash: + if data_record.payload_hash != lambda_payload_hash: raise IdempotencyValidationError("Payload does not match stored record for this event key") def _get_expiry_timestamp(self) -> int: @@ -238,9 +238,30 @@ def _get_expiry_timestamp(self) -> int: return int((now + period).timestamp()) def _save_to_cache(self, data_record: DataRecord): + """ + Save data_record to local cache except when status is "INPROGRESS" + + NOTE: We can't cache "INPROGRESS" records as we have no way to reflect updates that can happen outside of the + execution 
environment + + Parameters + ---------- + data_record: DataRecord + DataRecord instance + + Returns + ------- + + """ + if not self.use_local_cache: + return + if data_record.status == STATUS_CONSTANTS["INPROGRESS"]: + return self._cache[data_record.idempotency_key] = data_record def _retrieve_from_cache(self, idempotency_key: str): + if not self.use_local_cache: + return cached_record = self._cache.get(idempotency_key) if cached_record: if not cached_record.is_expired: @@ -249,11 +270,13 @@ def _retrieve_from_cache(self, idempotency_key: str): self._delete_from_cache(idempotency_key) def _delete_from_cache(self, idempotency_key: str): + if not self.use_local_cache: + return del self._cache[idempotency_key] def save_success(self, event: Dict[str, Any], result: dict) -> None: """ - Save record of function's execution completing succesfully + Save record of function's execution completing successfully Parameters ---------- @@ -277,8 +300,7 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None: ) self._update_record(data_record=data_record) - if self.use_local_cache: - self._save_to_cache(data_record) + self._save_to_cache(data_record) def save_inprogress(self, event: Dict[str, Any]) -> None: """ @@ -298,18 +320,11 @@ def save_inprogress(self, event: Dict[str, Any]) -> None: logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}") - if self.use_local_cache: - cached_record = self._retrieve_from_cache(idempotency_key=data_record.idempotency_key) - if cached_record: - raise IdempotencyItemAlreadyExistsError + if self._retrieve_from_cache(idempotency_key=data_record.idempotency_key): + raise IdempotencyItemAlreadyExistsError self._put_record(data_record) - # This has to come after _put_record. If _put_record call raises ItemAlreadyExists we shouldn't populate the - # cache with an "INPROGRESS" record as we don't know the status in the data store at this point. 
- if self.use_local_cache: - self._save_to_cache(data_record) - def delete_record(self, event: Dict[str, Any], exception: Exception): """ Delete record from the persistence store @@ -329,8 +344,7 @@ def delete_record(self, event: Dict[str, Any], exception: Exception): ) self._delete_record(data_record) - if self.use_local_cache: - self._delete_from_cache(data_record.idempotency_key) + self._delete_from_cache(data_record.idempotency_key) def get_record(self, event: Dict[str, Any]) -> DataRecord: """ @@ -356,17 +370,15 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord: idempotency_key = self._get_hashed_idempotency_key(event) - if self.use_local_cache: - cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key) - if cached_record: - logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}") - self._validate_payload(event, cached_record) - return cached_record + cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key) + if cached_record: + logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}") + self._validate_payload(event, cached_record) + return cached_record record = self._get_record(idempotency_key) - if self.use_local_cache: - self._save_to_cache(data_record=record) + self._save_to_cache(data_record=record) self._validate_payload(event, record) return record diff --git a/docs/utilities/idempotency.md b/docs/utilities/idempotency.md index 1c6555088d9..6bc7457d603 100644 --- a/docs/utilities/idempotency.md +++ b/docs/utilities/idempotency.md @@ -3,8 +3,11 @@ title: Idempotency description: Utility --- -This utility provides a simple solution to convert your Lambda functions into idempotent operations which are safe to -retry. +!!! attention + **This utility is currently in beta**. Please open an [issue in GitHub](https://github.com/awslabs/aws-lambda-powertools-python/issues/new/choose) for any bugs or feature requests. 
+ +The idempotency utility provides a simple solution to convert your Lambda functions into idempotent operations which +are safe to retry. ## Terminology @@ -31,31 +34,31 @@ storage layer, so you'll need to create a table first. > Example using AWS Serverless Application Model (SAM) === "template.yml" -```yaml -Resources: - HelloWorldFunction: - Type: AWS::Serverless::Function - Properties: - Runtime: python3.8 - ... - Policies: - - DynamoDBCrudPolicy: - TableName: !Ref IdempotencyTable - IdempotencyTable: - Type: AWS::DynamoDB::Table - Properties: - AttributeDefinitions: - - AttributeName: id - AttributeType: S - BillingMode: PAY_PER_REQUEST - KeySchema: - - AttributeName: id - KeyType: HASH - TableName: "IdempotencyTable" - TimeToLiveSpecification: - AttributeName: expiration - Enabled: true -``` + ```yaml + Resources: + HelloWorldFunction: + Type: AWS::Serverless::Function + Properties: + Runtime: python3.8 + ... + Policies: + - DynamoDBCrudPolicy: + TableName: !Ref IdempotencyTable + IdempotencyTable: + Type: AWS::DynamoDB::Table + Properties: + AttributeDefinitions: + - AttributeName: id + AttributeType: S + BillingMode: PAY_PER_REQUEST + KeySchema: + - AttributeName: id + KeyType: HASH + TableName: "IdempotencyTable" + TimeToLiveSpecification: + AttributeName: expiration + Enabled: true + ``` !!! warning When using this utility with DynamoDB, your lambda responses must always be smaller than 400kb. 
Larger items cannot diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index e6e64e3b38b..269ab6f9b33 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -133,6 +133,12 @@ def test_idempotent_lambda_in_progress_with_cache( stubber.add_client_error("put_item", "ConditionalCheckFailedException") stubber.add_response("get_item", ddb_response, expected_params) + + stubber.add_client_error("put_item", "ConditionalCheckFailedException") + stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params)) + + stubber.add_client_error("put_item", "ConditionalCheckFailedException") + stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params)) stubber.activate() @idempotent(persistence_store=persistence_store) @@ -151,11 +157,8 @@ def lambda_handler(event, context): assert retrieve_from_cache_spy.call_count == 2 * loops retrieve_from_cache_spy.assert_called_with(idempotency_key=hashed_idempotency_key) - assert save_to_cache_spy.call_count == 1 - first_call_args_data_record = save_to_cache_spy.call_args_list[0].kwargs["data_record"] - assert first_call_args_data_record.idempotency_key == hashed_idempotency_key - assert first_call_args_data_record.status == "INPROGRESS" - assert persistence_store._cache.get(hashed_idempotency_key) + save_to_cache_spy.assert_called() + assert persistence_store._cache.get(hashed_idempotency_key) is None stubber.assert_no_pending_responses() stubber.deactivate() @@ -223,12 +226,10 @@ def lambda_handler(event, context): lambda_handler(lambda_apigw_event, {}) - assert retrieve_from_cache_spy.call_count == 1 - assert save_to_cache_spy.call_count == 2 - first_call_args, second_call_args = save_to_cache_spy.call_args_list - assert first_call_args.args[0].status == "INPROGRESS" - assert second_call_args.args[0].status == "COMPLETED" - assert 
persistence_store._cache.get(hashed_idempotency_key) + retrieve_from_cache_spy.assert_called_once() + save_to_cache_spy.assert_called_once() + assert save_to_cache_spy.call_args[0][0].status == "COMPLETED" + assert persistence_store._cache.get(hashed_idempotency_key).status == "COMPLETED" # This lambda call should not call AWS API lambda_handler(lambda_apigw_event, {}) @@ -594,3 +595,35 @@ def test_data_record_invalid_status_value(): _ = data_record.status assert e.value.args[0] == "UNSUPPORTED_STATUS" + + +@pytest.mark.parametrize("persistence_store", [{"use_local_cache": True}], indirect=True) +def test_in_progress_never_saved_to_cache(persistence_store): + # GIVEN a data record with status "INPROGRESS" + # and persistence_store has use_local_cache = True + data_record = DataRecord("key", status="INPROGRESS") + + # WHEN saving to local cache + persistence_store._save_to_cache(data_record) + + # THEN don't save to local cache + assert persistence_store._cache.get("key") is None + + +@pytest.mark.parametrize("persistence_store", [{"use_local_cache": False}], indirect=True) +def test_user_local_disabled(persistence_store): + # GIVEN a persistence_store with use_local_cache = False + + # WHEN calling any local cache options + data_record = DataRecord("key", status="COMPLETED") + try: + persistence_store._save_to_cache(data_record) + cache_value = persistence_store._retrieve_from_cache("key") + assert cache_value is None + persistence_store._delete_from_cache("key") + except AttributeError as e: + pytest.fail(f"AttributeError should not be raised: {e}") + + # THEN no AttributeError is raised + # AND the store has no _cache attribute + assert not hasattr("persistence_store", "_cache") diff --git a/tests/unit/test_json_encoder.py b/tests/unit/test_json_encoder.py index 8d6a9f3944c..af8de4257a8 100644 --- a/tests/unit/test_json_encoder.py +++ b/tests/unit/test_json_encoder.py @@ -1,6 +1,8 @@ import decimal import json +import pytest + from aws_lambda_powertools.shared.json_encoder
import Encoder @@ -12,3 +14,11 @@ def test_jsonencode_decimal(): def test_jsonencode_decimal_nan(): result = json.dumps({"val": decimal.Decimal("NaN")}, cls=Encoder) assert result == '{"val": NaN}' + + +def test_jsonencode_calls_default(): + class CustomClass: + pass + + with pytest.raises(TypeError): + json.dumps({"val": CustomClass()}, cls=Encoder)