Skip to content

fix(idempotency): make idempotent_function decorator thread safe #1899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,7 +836,6 @@ def route(
# Override _compile_regex to exclude trailing slashes for route resolution
@staticmethod
def _compile_regex(rule: str, base_regex: str = _ROUTE_REGEX):

return super(APIGatewayRestResolver, APIGatewayRestResolver)._compile_regex(rule, "^{}/*$")


Expand Down
1 change: 0 additions & 1 deletion aws_lambda_powertools/logging/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def copy_config_to_registered_loggers(
exclude: Optional[Set[str]] = None,
include: Optional[Set[str]] = None,
) -> None:

"""Copies source Logger level and handler to all registered loggers for consistent formatting.

Parameters
Expand Down
1 change: 0 additions & 1 deletion aws_lambda_powertools/utilities/feature_flags/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,6 @@ def __init__(self, rule: Dict[str, Any], rule_name: str, logger: Optional[Union[
self.logger = logger or logging.getLogger(__name__)

def validate(self):

if not self.conditions or not isinstance(self.conditions, list):
self.logger.debug(f"Condition is empty or invalid for rule={self.rule_name}")
raise SchemaValidationError(f"Invalid condition, rule={self.rule_name}")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class DataRecord:

def __init__(
self,
idempotency_key,
idempotency_key: str,
status: str = "",
expiry_timestamp: Optional[int] = None,
in_progress_expiry_timestamp: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Any, Dict, Optional

import boto3
from boto3.dynamodb.types import TypeDeserializer
from botocore.config import Config
from botocore.exceptions import ClientError

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.utilities.idempotency import BasePersistenceLayer
Expand Down Expand Up @@ -79,13 +81,14 @@ def __init__(

self._boto_config = boto_config or Config()
self._boto3_session = boto3_session or boto3.session.Session()
self._client = self._boto3_session.client("dynamodb", config=self._boto_config)

if sort_key_attr == key_attr:
raise ValueError(f"key_attr [{key_attr}] and sort_key_attr [{sort_key_attr}] cannot be the same!")

if static_pk_value is None:
static_pk_value = f"idempotency#{os.getenv(constants.LAMBDA_FUNCTION_NAME_ENV, '')}"

self._table = None
self.table_name = table_name
self.key_attr = key_attr
self.static_pk_value = static_pk_value
Expand All @@ -95,31 +98,15 @@ def __init__(
self.status_attr = status_attr
self.data_attr = data_attr
self.validation_key_attr = validation_key_attr
super(DynamoDBPersistenceLayer, self).__init__()

@property
def table(self):
"""
Caching property to store boto3 dynamodb Table resource
self._deserializer = TypeDeserializer()

"""
if self._table:
return self._table
ddb_resource = self._boto3_session.resource("dynamodb", config=self._boto_config)
self._table = ddb_resource.Table(self.table_name)
return self._table

@table.setter
def table(self, table):
"""
Allow table instance variable to be set directly, primarily for use in tests
"""
self._table = table
super(DynamoDBPersistenceLayer, self).__init__()

def _get_key(self, idempotency_key: str) -> dict:
if self.sort_key_attr:
return {self.key_attr: self.static_pk_value, self.sort_key_attr: idempotency_key}
return {self.key_attr: idempotency_key}
return {self.key_attr: {"S": self.static_pk_value}, self.sort_key_attr: {"S": idempotency_key}}
return {self.key_attr: {"S": idempotency_key}}

def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
"""
Expand All @@ -136,36 +123,39 @@ def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
representation of item

"""
data = self._deserializer.deserialize({"M": item})
return DataRecord(
idempotency_key=item[self.key_attr],
status=item[self.status_attr],
expiry_timestamp=item[self.expiry_attr],
in_progress_expiry_timestamp=item.get(self.in_progress_expiry_attr),
response_data=item.get(self.data_attr),
payload_hash=item.get(self.validation_key_attr),
idempotency_key=data[self.key_attr],
status=data[self.status_attr],
expiry_timestamp=data[self.expiry_attr],
in_progress_expiry_timestamp=data.get(self.in_progress_expiry_attr),
response_data=data.get(self.data_attr),
payload_hash=data.get(self.validation_key_attr),
)

def _get_record(self, idempotency_key) -> DataRecord:
response = self.table.get_item(Key=self._get_key(idempotency_key), ConsistentRead=True)

response = self._client.get_item(
TableName=self.table_name, Key=self._get_key(idempotency_key), ConsistentRead=True
)
try:
item = response["Item"]
except KeyError:
raise IdempotencyItemNotFoundError
except KeyError as exc:
raise IdempotencyItemNotFoundError from exc
return self._item_to_data_record(item)

def _put_record(self, data_record: DataRecord) -> None:
item = {
**self._get_key(data_record.idempotency_key),
self.expiry_attr: data_record.expiry_timestamp,
self.status_attr: data_record.status,
self.key_attr: {"S": data_record.idempotency_key},
self.expiry_attr: {"N": str(data_record.expiry_timestamp)},
self.status_attr: {"S": data_record.status},
}

if data_record.in_progress_expiry_timestamp is not None:
item[self.in_progress_expiry_attr] = data_record.in_progress_expiry_timestamp
item[self.in_progress_expiry_attr] = {"N": str(data_record.in_progress_expiry_timestamp)}

if self.payload_validation_enabled:
item[self.validation_key_attr] = data_record.payload_hash
if self.payload_validation_enabled and data_record.payload_hash:
item[self.validation_key_attr] = {"S": data_record.payload_hash}

now = datetime.datetime.now()
try:
Expand Down Expand Up @@ -199,8 +189,8 @@ def _put_record(self, data_record: DataRecord) -> None:
condition_expression = (
f"{idempotency_key_not_exist} OR {idempotency_expiry_expired} OR ({inprogress_expiry_expired})"
)

self.table.put_item(
self._client.put_item(
TableName=self.table_name,
Item=item,
ConditionExpression=condition_expression,
ExpressionAttributeNames={
Expand All @@ -210,22 +200,28 @@ def _put_record(self, data_record: DataRecord) -> None:
"#status": self.status_attr,
},
ExpressionAttributeValues={
":now": int(now.timestamp()),
":now_in_millis": int(now.timestamp() * 1000),
":inprogress": STATUS_CONSTANTS["INPROGRESS"],
":now": {"N": str(int(now.timestamp()))},
":now_in_millis": {"N": str(int(now.timestamp() * 1000))},
":inprogress": {"S": STATUS_CONSTANTS["INPROGRESS"]},
},
)
except self.table.meta.client.exceptions.ConditionalCheckFailedException:
logger.debug(f"Failed to put record for already existing idempotency key: {data_record.idempotency_key}")
raise IdempotencyItemAlreadyExistsError
except ClientError as exc:
error_code = exc.response.get("Error", {}).get("Code")
if error_code == "ConditionalCheckFailedException":
logger.debug(
f"Failed to put record for already existing idempotency key: {data_record.idempotency_key}"
)
raise IdempotencyItemAlreadyExistsError from exc
else:
raise

def _update_record(self, data_record: DataRecord):
logger.debug(f"Updating record for idempotency key: {data_record.idempotency_key}")
update_expression = "SET #response_data = :response_data, #expiry = :expiry, " "#status = :status"
expression_attr_values = {
":expiry": data_record.expiry_timestamp,
":response_data": data_record.response_data,
":status": data_record.status,
":expiry": {"N": str(data_record.expiry_timestamp)},
":response_data": {"S": data_record.response_data},
":status": {"S": data_record.status},
}
expression_attr_names = {
"#expiry": self.expiry_attr,
Expand All @@ -235,7 +231,7 @@ def _update_record(self, data_record: DataRecord):

if self.payload_validation_enabled:
update_expression += ", #validation_key = :validation_key"
expression_attr_values[":validation_key"] = data_record.payload_hash
expression_attr_values[":validation_key"] = {"S": data_record.payload_hash}
expression_attr_names["#validation_key"] = self.validation_key_attr

kwargs = {
Expand All @@ -245,8 +241,8 @@ def _update_record(self, data_record: DataRecord):
"ExpressionAttributeNames": expression_attr_names,
}

self.table.update_item(**kwargs)
self._client.update_item(TableName=self.table_name, **kwargs)

def _delete_record(self, data_record: DataRecord) -> None:
logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}")
self.table.delete_item(Key=self._get_key(data_record.idempotency_key))
self._client.delete_item(TableName=self.table_name, Key={**self._get_key(data_record.idempotency_key)})
5 changes: 5 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@ ignore_missing_imports = True
[mypy-boto3.dynamodb.conditions]
ignore_missing_imports = True

[mypy-boto3.dynamodb.types]
ignore_missing_imports = True

[mypy-botocore.config]
ignore_missing_imports = True

Expand Down Expand Up @@ -58,3 +61,5 @@ ignore_missing_imports = True

[mypy-ijson]
ignore_missing_imports = True


Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@

@idempotent(persistence_store=persistence_layer)
def lambda_handler(event, context):

time.sleep(5)

return event
29 changes: 29 additions & 0 deletions tests/e2e/idempotency/handlers/parallel_functions_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import current_thread

from aws_lambda_powertools.utilities.idempotency import (
DynamoDBPersistenceLayer,
idempotent_function,
)

TABLE_NAME = os.getenv("IdempotencyTable", "")
persistence_layer = DynamoDBPersistenceLayer(table_name=TABLE_NAME)
threads_count = 2


@idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record")
def record_handler(record):
time_now = time.time()
return {"thread_name": current_thread().name, "time": str(time_now)}


def lambda_handler(event, context):
with ThreadPoolExecutor(max_workers=threads_count) as executor:
futures = [executor.submit(record_handler, **{"record": event}) for _ in range(threads_count)]

return [
{"state": future._state, "exception": future.exception(), "output": future.result()}
for future in as_completed(futures)
]
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

@idempotent(config=config, persistence_store=persistence_layer)
def lambda_handler(event, context):

time_now = time.time()

return {"time": str(time_now)}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

@idempotent(config=config, persistence_store=persistence_layer)
def lambda_handler(event, context):

sleep_time: int = event.get("sleep") or 0
time.sleep(sleep_time)

Expand Down
1 change: 1 addition & 0 deletions tests/e2e/idempotency/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def create_resources(self):
table.grant_read_write_data(functions["TtlCacheExpirationHandler"])
table.grant_read_write_data(functions["TtlCacheTimeoutHandler"])
table.grant_read_write_data(functions["ParallelExecutionHandler"])
table.grant_read_write_data(functions["ParallelFunctionsHandler"])

def _create_dynamodb_table(self) -> Table:
table = dynamodb.Table(
Expand Down
32 changes: 32 additions & 0 deletions tests/e2e/idempotency/test_idempotency_dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ def parallel_execution_handler_fn_arn(infrastructure: dict) -> str:
return infrastructure.get("ParallelExecutionHandlerArn", "")


@pytest.fixture
def parallel_functions_handler_fn_arn(infrastructure: dict) -> str:
return infrastructure.get("ParallelFunctionsHandlerArn", "")


@pytest.fixture
def idempotency_table_name(infrastructure: dict) -> str:
return infrastructure.get("DynamoDBTable", "")
Expand Down Expand Up @@ -97,3 +102,30 @@ def test_parallel_execution_idempotency(parallel_execution_handler_fn_arn: str):
# THEN
assert "Execution already in progress with idempotency key" in error_idempotency_execution_response
assert "Task timed out after" in timeout_execution_response


@pytest.mark.xdist_group(name="idempotency")
def test_parallel_functions_execution_idempotency(parallel_functions_handler_fn_arn: str):
# GIVEN
payload = json.dumps({"message": "Lambda Powertools - Parallel functions execution"})

# WHEN
# first execution
first_execution, _ = data_fetcher.get_lambda_response(lambda_arn=parallel_functions_handler_fn_arn, payload=payload)
first_execution_response = first_execution["Payload"].read().decode("utf-8")

# the second execution should return the same response as the first execution
second_execution, _ = data_fetcher.get_lambda_response(
lambda_arn=parallel_functions_handler_fn_arn, payload=payload
)
second_execution_response = second_execution["Payload"].read().decode("utf-8")

# THEN
# Function threads finished without exception AND
# first and second execution is the same
for function_thread in json.loads(first_execution_response):
assert function_thread["state"] == "FINISHED"
assert function_thread["exception"] is None
assert function_thread["output"] is not None

assert first_execution_response == second_execution_response
2 changes: 0 additions & 2 deletions tests/e2e/parameters/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ def create_resources(self):
)

def _create_app_config(self, function: Function):

service_name = build_service_name()

cfn_application = appconfig.CfnApplication(
Expand Down Expand Up @@ -82,7 +81,6 @@ def _create_app_config_freeform(
function: Function,
service_name: str,
):

cfn_configuration_profile = appconfig.CfnConfigurationProfile(
self.stack,
"appconfig-profile",
Expand Down
2 changes: 2 additions & 0 deletions tests/functional/feature_flags/test_feature_flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,7 @@ def test_flags_conditions_rule_match_multiple_actions_multiple_rules_multiple_co

# check a case where the feature exists but the rule doesn't match so we revert to the default value of the feature


# Check IN/NOT_IN/KEY_IN_VALUE/KEY_NOT_IN_VALUE/VALUE_IN_KEY/VALUE_NOT_IN_KEY conditions
def test_flags_match_rule_with_in_action(mocker, config):
expected_value = True
Expand Down Expand Up @@ -775,6 +776,7 @@ def test_get_configuration_with_envelope_and_raw(mocker, config):
## Inequality test cases
##


# Test not equals
def test_flags_not_equal_no_match(mocker, config):
expected_value = False
Expand Down
Loading