Skip to content

fix: correct behaviour to avoid caching "INPROGRESS" records #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 20, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 39 additions & 27 deletions aws_lambda_powertools/utilities/idempotency/persistence/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import json
import logging
from abc import ABC, abstractmethod
from types import MappingProxyType
from typing import Any, Dict

import jmespath
Expand All @@ -21,7 +22,7 @@

logger = logging.getLogger(__name__)

STATUS_CONSTANTS = {"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"}
STATUS_CONSTANTS = MappingProxyType({"INPROGRESS": "INPROGRESS", "COMPLETED": "COMPLETED", "EXPIRED": "EXPIRED"})


class DataRecord:
Expand Down Expand Up @@ -81,8 +82,7 @@ def status(self) -> str:
"""
if self.is_expired:
return STATUS_CONSTANTS["EXPIRED"]

if self._status in STATUS_CONSTANTS.values():
elif self._status in STATUS_CONSTANTS.values():
return self._status
else:
raise IdempotencyInvalidStatusError(self._status)
Expand Down Expand Up @@ -214,14 +214,14 @@ def _validate_payload(self, lambda_event: Dict[str, Any], data_record: DataRecor
DataRecord instance

Raises
______
----------
IdempotencyValidationError
Event payload doesn't match the stored record for the given idempotency key

"""
if self.payload_validation_enabled:
lambda_payload_hash = self._get_hashed_payload(lambda_event)
if not data_record.payload_hash == lambda_payload_hash:
if data_record.payload_hash != lambda_payload_hash:
raise IdempotencyValidationError("Payload does not match stored record for this event key")

def _get_expiry_timestamp(self) -> int:
Expand All @@ -238,9 +238,30 @@ def _get_expiry_timestamp(self) -> int:
return int((now + period).timestamp())

def _save_to_cache(self, data_record: DataRecord):
    """
    Store a DataRecord in the local cache, skipping records whose status is "INPROGRESS".

    NOTE: "INPROGRESS" records are deliberately never cached — their status can be
    updated outside this execution environment, and a local copy would have no way
    to reflect that change.

    Parameters
    ----------
    data_record: DataRecord
        DataRecord instance to store in the local cache
    """
    # No-op unless local caching is enabled and the record is safe to cache.
    if self.use_local_cache and data_record.status != STATUS_CONSTANTS["INPROGRESS"]:
        self._cache[data_record.idempotency_key] = data_record

def _retrieve_from_cache(self, idempotency_key: str):
if not self.use_local_cache:
return
cached_record = self._cache.get(idempotency_key)
if cached_record:
if not cached_record.is_expired:
Expand All @@ -249,11 +270,13 @@ def _retrieve_from_cache(self, idempotency_key: str):
self._delete_from_cache(idempotency_key)

def _delete_from_cache(self, idempotency_key: str):
    """
    Remove a record from the local cache, if local caching is enabled.

    Parameters
    ----------
    idempotency_key: str
        Hashed idempotency key of the record to evict
    """
    if not self.use_local_cache:
        return
    # Use pop() with a default instead of `del`: since "INPROGRESS" records are
    # never cached, delete_record() can legitimately be called for a key that was
    # never stored locally (e.g. the handler raised while the record was still
    # "INPROGRESS"), and a bare `del` would raise KeyError in that path.
    self._cache.pop(idempotency_key, None)

def save_success(self, event: Dict[str, Any], result: dict) -> None:
"""
Save record of function's execution completing succesfully
Save record of function's execution completing successfully

Parameters
----------
Expand All @@ -277,8 +300,7 @@ def save_success(self, event: Dict[str, Any], result: dict) -> None:
)
self._update_record(data_record=data_record)

if self.use_local_cache:
self._save_to_cache(data_record)
self._save_to_cache(data_record)

def save_inprogress(self, event: Dict[str, Any]) -> None:
"""
Expand All @@ -298,18 +320,11 @@ def save_inprogress(self, event: Dict[str, Any]) -> None:

logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}")

if self.use_local_cache:
cached_record = self._retrieve_from_cache(idempotency_key=data_record.idempotency_key)
if cached_record:
raise IdempotencyItemAlreadyExistsError
if self._retrieve_from_cache(idempotency_key=data_record.idempotency_key):
raise IdempotencyItemAlreadyExistsError

self._put_record(data_record)

# This has to come after _put_record. If _put_record call raises ItemAlreadyExists we shouldn't populate the
# cache with an "INPROGRESS" record as we don't know the status in the data store at this point.
if self.use_local_cache:
self._save_to_cache(data_record)

def delete_record(self, event: Dict[str, Any], exception: Exception):
"""
Delete record from the persistence store
Expand All @@ -329,8 +344,7 @@ def delete_record(self, event: Dict[str, Any], exception: Exception):
)
self._delete_record(data_record)

if self.use_local_cache:
self._delete_from_cache(data_record.idempotency_key)
self._delete_from_cache(data_record.idempotency_key)

def get_record(self, event: Dict[str, Any]) -> DataRecord:
"""
Expand All @@ -356,17 +370,15 @@ def get_record(self, event: Dict[str, Any]) -> DataRecord:

idempotency_key = self._get_hashed_idempotency_key(event)

if self.use_local_cache:
cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}")
self._validate_payload(event, cached_record)
return cached_record
cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key)
if cached_record:
logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}")
self._validate_payload(event, cached_record)
return cached_record

record = self._get_record(idempotency_key)

if self.use_local_cache:
self._save_to_cache(data_record=record)
self._save_to_cache(data_record=record)

self._validate_payload(event, record)
return record
Expand Down
57 changes: 30 additions & 27 deletions docs/utilities/idempotency.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@ title: Idempotency
description: Utility
---

This utility provides a simple solution to convert your Lambda functions into idempotent operations which are safe to
retry.
!!! attention
**This utility is currently in beta**. Please open an [issue in GitHub](https://github.com/awslabs/aws-lambda-powertools-python/issues/new/choose) for any bugs or feature requests.

The idempotency utility provides a simple solution to convert your Lambda functions into idempotent operations which
are safe to retry.

## Terminology

Expand All @@ -31,31 +34,31 @@ storage layer, so you'll need to create a table first.
> Example using AWS Serverless Application Model (SAM)

=== "template.yml"
```yaml
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Properties:
Runtime: python3.8
...
Policies:
- DynamoDBCrudPolicy:
TableName: !Ref IdempotencyTable
IdempotencyTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: id
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: id
KeyType: HASH
TableName: "IdempotencyTable"
TimeToLiveSpecification:
AttributeName: expiration
Enabled: true
```
```yaml
Resources:
HelloWorldFunction:
Type: AWS::Serverless::Function
Properties:
Runtime: python3.8
...
Policies:
- DynamoDBCrudPolicy:
TableName: !Ref IdempotencyTable
IdempotencyTable:
Type: AWS::DynamoDB::Table
Properties:
AttributeDefinitions:
- AttributeName: id
AttributeType: S
BillingMode: PAY_PER_REQUEST
KeySchema:
- AttributeName: id
KeyType: HASH
TableName: "IdempotencyTable"
TimeToLiveSpecification:
AttributeName: expiration
Enabled: true
```

!!! warning
When using this utility with DynamoDB, your lambda responses must always be smaller than 400kb. Larger items cannot
Expand Down
55 changes: 44 additions & 11 deletions tests/functional/idempotency/test_idempotency.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ def test_idempotent_lambda_in_progress_with_cache(

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", ddb_response, expected_params)

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params))

stubber.add_client_error("put_item", "ConditionalCheckFailedException")
stubber.add_response("get_item", copy.deepcopy(ddb_response), copy.deepcopy(expected_params))
stubber.activate()

@idempotent(persistence_store=persistence_store)
Expand All @@ -151,11 +157,8 @@ def lambda_handler(event, context):
assert retrieve_from_cache_spy.call_count == 2 * loops
retrieve_from_cache_spy.assert_called_with(idempotency_key=hashed_idempotency_key)

assert save_to_cache_spy.call_count == 1
first_call_args_data_record = save_to_cache_spy.call_args_list[0].kwargs["data_record"]
assert first_call_args_data_record.idempotency_key == hashed_idempotency_key
assert first_call_args_data_record.status == "INPROGRESS"
assert persistence_store._cache.get(hashed_idempotency_key)
save_to_cache_spy.assert_called()
assert persistence_store._cache.get(hashed_idempotency_key) is None

stubber.assert_no_pending_responses()
stubber.deactivate()
Expand Down Expand Up @@ -223,12 +226,10 @@ def lambda_handler(event, context):

lambda_handler(lambda_apigw_event, {})

assert retrieve_from_cache_spy.call_count == 1
assert save_to_cache_spy.call_count == 2
first_call_args, second_call_args = save_to_cache_spy.call_args_list
assert first_call_args.args[0].status == "INPROGRESS"
assert second_call_args.args[0].status == "COMPLETED"
assert persistence_store._cache.get(hashed_idempotency_key)
retrieve_from_cache_spy.assert_called_once()
save_to_cache_spy.assert_called_once()
assert save_to_cache_spy.call_args[0][0].status == "COMPLETED"
assert persistence_store._cache.get(hashed_idempotency_key).status == "COMPLETED"

# This lambda call should not call AWS API
lambda_handler(lambda_apigw_event, {})
Expand Down Expand Up @@ -594,3 +595,35 @@ def test_data_record_invalid_status_value():
_ = data_record.status

assert e.value.args[0] == "UNSUPPORTED_STATUS"


@pytest.mark.parametrize("persistence_store", [{"use_local_cache": True}], indirect=True)
def test_in_progress_never_saved_to_cache(persistence_store):
    # GIVEN local caching is enabled and a record whose status is "INPROGRESS"
    in_progress_record = DataRecord("key", status="INPROGRESS")

    # WHEN attempting to save that record to the local cache
    persistence_store._save_to_cache(in_progress_record)

    # THEN the cache must not contain it
    assert persistence_store._cache.get("key") is None


@pytest.mark.parametrize("persistence_store", [{"use_local_cache": False}], indirect=True)
def test_user_local_disabled(persistence_store):
    # GIVEN a persistence_store with use_local_cache = False

    # WHEN calling any local cache operation
    data_record = DataRecord("key", status="COMPLETED")
    try:
        persistence_store._save_to_cache(data_record)
        cache_value = persistence_store._retrieve_from_cache("key")
        assert cache_value is None
        persistence_store._delete_from_cache("key")
    except AttributeError as e:
        pytest.fail(f"AttributeError should not be raised: {e}")

    # THEN no AttributeError is raised (every cache helper is a no-op)
    # AND the store does not even have a _cache attribute.
    # BUG FIX: the original asserted hasattr("persistence_store", "_cache") on the
    # *string literal*, which is always False regardless of the store's state,
    # making the assertion vacuous. Pass the fixture object itself.
    assert not hasattr(persistence_store, "_cache")
10 changes: 10 additions & 0 deletions tests/unit/test_json_encoder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import decimal
import json

import pytest

from aws_lambda_powertools.shared.json_encoder import Encoder


Expand All @@ -12,3 +14,11 @@ def test_jsonencode_decimal():
def test_jsonencode_decimal_nan():
result = json.dumps({"val": decimal.Decimal("NaN")}, cls=Encoder)
assert result == '{"val": NaN}'


def test_jsonencode_calls_default():
    # GIVEN an object type the base JSONEncoder cannot serialize
    class Unserializable:
        pass

    # WHEN encoding it with our Encoder
    # THEN default() is exercised and the standard TypeError propagates
    with pytest.raises(TypeError):
        json.dumps({"val": Unserializable()}, cls=Encoder)