diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py b/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py index 0ce307ab503..8a470c0f910 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py @@ -1,10 +1,12 @@ import datetime import logging +import os from typing import Any, Dict, Optional import boto3 from botocore.config import Config +from aws_lambda_powertools.shared import constants from aws_lambda_powertools.utilities.idempotency import BasePersistenceLayer from aws_lambda_powertools.utilities.idempotency.exceptions import ( IdempotencyItemAlreadyExistsError, @@ -20,6 +22,8 @@ def __init__( self, table_name: str, key_attr: str = "id", + static_pk_value: str = f"idempotency#{os.getenv(constants.LAMBDA_FUNCTION_NAME_ENV, '')}", + sort_key_attr: Optional[str] = None, expiry_attr: str = "expiration", status_attr: str = "status", data_attr: str = "data", @@ -35,7 +39,12 @@ def __init__( table_name: str Name of the table to use for storing execution records key_attr: str, optional - DynamoDB attribute name for key, by default "id" + DynamoDB attribute name for partition key, by default "id" + static_pk_value: str, optional + DynamoDB attribute value for partition key, by default "idempotency#". + This will be used if the sort_key_attr is set. + sort_key_attr: str, optional + DynamoDB attribute name for the sort key expiry_attr: str, optional DynamoDB attribute name for expiry timestamp, by default "expiration" status_attr: str, optional @@ -64,10 +73,14 @@ def __init__( self._boto_config = boto_config or Config() self._boto3_session = boto3_session or boto3.session.Session() + if sort_key_attr == key_attr: + raise ValueError(f"key_attr [{key_attr}] and sort_key_attr [{sort_key_attr}] cannot be the same!") self._table = None self.table_name = table_name self.key_attr = key_attr + self.static_pk_value = static_pk_value + self.sort_key_attr = sort_key_attr self.expiry_attr = expiry_attr self.status_attr = status_attr self.data_attr = data_attr @@ -93,6 +106,11 @@ def table(self, table): """ self._table = table + def _get_key(self, idempotency_key: str) -> dict: + if self.sort_key_attr: + return {self.key_attr: self.static_pk_value, self.sort_key_attr: idempotency_key} + return {self.key_attr: idempotency_key} + def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord: """ Translate raw item records from DynamoDB to DataRecord @@ -117,7 +135,7 @@ def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord: ) def _get_record(self, idempotency_key) -> DataRecord: - response = self.table.get_item(Key={self.key_attr: idempotency_key}, ConsistentRead=True) + response = self.table.get_item(Key=self._get_key(idempotency_key), ConsistentRead=True) try: item = response["Item"] @@ -127,7 +145,7 @@ def _get_record(self, idempotency_key) -> DataRecord: def _put_record(self, data_record: DataRecord) -> None: item = { - self.key_attr: data_record.idempotency_key, + **self._get_key(data_record.idempotency_key), self.expiry_attr: data_record.expiry_timestamp, self.status_attr: data_record.status, } @@ -168,7 +186,7 @@ def _update_record(self, data_record: DataRecord): expression_attr_names["#validation_key"] = self.validation_key_attr kwargs = { - "Key": {self.key_attr: data_record.idempotency_key}, + "Key": self._get_key(data_record.idempotency_key), "UpdateExpression": update_expression, "ExpressionAttributeValues": expression_attr_values, "ExpressionAttributeNames": expression_attr_names, @@ -178,4 +196,4 @@ def _update_record(self, data_record: DataRecord): def _delete_record(self, data_record: DataRecord) -> None: logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}") - self.table.delete_item(Key={self.key_attr: data_record.idempotency_key}) + self.table.delete_item(Key=self._get_key(data_record.idempotency_key)) diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index b1d0914d181..043fb06a04a 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -783,11 +783,11 @@ def test_jmespath_with_powertools_json( # GIVEN an event_key_jmespath with powertools_json custom function persistence_store.configure(idempotency_config) sub_attr_value = "cognito_user" - key_attr_value = "some_key" - expected_value = [sub_attr_value, key_attr_value] + static_pk_value = "some_key" + expected_value = [sub_attr_value, static_pk_value] api_gateway_proxy_event = { "requestContext": {"authorizer": {"claims": {"sub": sub_attr_value}}}, - "body": serialize({"id": key_attr_value}), + "body": serialize({"id": static_pk_value}), } # WHEN calling _get_hashed_idempotency_key