Skip to content

Commit 50ac176

Browse files
authored
feat(idempotency): allow custom sdk clients in DynamoDBPersistenceLayer (#2087)
1 parent 99bcf80 commit 50ac176

File tree

5 files changed

+77
-46
lines changed

5 files changed

+77
-46
lines changed

aws_lambda_powertools/utilities/idempotency/persistence/base.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ def __init__(
4141
status: str = "",
4242
expiry_timestamp: Optional[int] = None,
4343
in_progress_expiry_timestamp: Optional[int] = None,
44-
response_data: Optional[str] = "",
45-
payload_hash: Optional[str] = None,
44+
response_data: str = "",
45+
payload_hash: str = "",
4646
) -> None:
4747
"""
4848
@@ -117,15 +117,15 @@ def __init__(self):
117117
"""Initialize the defaults"""
118118
self.function_name = ""
119119
self.configured = False
120-
self.event_key_jmespath: Optional[str] = None
120+
self.event_key_jmespath: str = ""
121121
self.event_key_compiled_jmespath = None
122122
self.jmespath_options: Optional[dict] = None
123123
self.payload_validation_enabled = False
124124
self.validation_key_jmespath = None
125125
self.raise_on_no_idempotency_key = False
126126
self.expires_after_seconds: int = 60 * 60 # 1 hour default
127127
self.use_local_cache = False
128-
self.hash_function = None
128+
self.hash_function = hashlib.md5
129129

130130
def configure(self, config: IdempotencyConfig, function_name: Optional[str] = None) -> None:
131131
"""

aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py

+28-18
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
from __future__ import annotations
2+
13
import datetime
24
import logging
35
import os
4-
from typing import Any, Dict, Optional
6+
from typing import TYPE_CHECKING, Any, Dict, Optional
57

68
import boto3
79
from boto3.dynamodb.types import TypeDeserializer
@@ -19,6 +21,10 @@
1921
DataRecord,
2022
)
2123

24+
if TYPE_CHECKING:
25+
from mypy_boto3_dynamodb import DynamoDBClient
26+
from mypy_boto3_dynamodb.type_defs import AttributeValueTypeDef
27+
2228
logger = logging.getLogger(__name__)
2329

2430

@@ -36,6 +42,7 @@ def __init__(
3642
validation_key_attr: str = "validation",
3743
boto_config: Optional[Config] = None,
3844
boto3_session: Optional[boto3.session.Session] = None,
45+
boto3_client: "DynamoDBClient" | None = None,
3946
):
4047
"""
4148
Initialize the DynamoDB client
@@ -61,8 +68,10 @@ def __init__(
6168
DynamoDB attribute name for response data, by default "data"
6269
boto_config: botocore.config.Config, optional
6370
Botocore configuration to pass during client initialization
64-
boto3_session : boto3.session.Session, optional
71+
boto3_session : boto3.Session, optional
6572
Boto3 session to use for AWS API communication
73+
boto3_client : DynamoDBClient, optional
74+
Boto3 DynamoDB Client to use, boto3_session and boto_config will be ignored if both are provided
6675
6776
Examples
6877
--------
@@ -78,10 +87,12 @@ def __init__(
7887
>>> def handler(event, context):
7988
>>> return {"StatusCode": 200}
8089
"""
81-
82-
self._boto_config = boto_config or Config()
83-
self._boto3_session = boto3_session or boto3.session.Session()
84-
self._client = self._boto3_session.client("dynamodb", config=self._boto_config)
90+
if boto3_client is None:
91+
self._boto_config = boto_config or Config()
92+
self._boto3_session: boto3.Session = boto3_session or boto3.session.Session()
93+
self.client: "DynamoDBClient" = self._boto3_session.client("dynamodb", config=self._boto_config)
94+
else:
95+
self.client = boto3_client
8596

8697
if sort_key_attr == key_attr:
8798
raise ValueError(f"key_attr [{key_attr}] and sort_key_attr [{sort_key_attr}] cannot be the same!")
@@ -149,7 +160,7 @@ def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
149160
)
150161

151162
def _get_record(self, idempotency_key) -> DataRecord:
152-
response = self._client.get_item(
163+
response = self.client.get_item(
153164
TableName=self.table_name, Key=self._get_key(idempotency_key), ConsistentRead=True
154165
)
155166
try:
@@ -204,7 +215,7 @@ def _put_record(self, data_record: DataRecord) -> None:
204215
condition_expression = (
205216
f"{idempotency_key_not_exist} OR {idempotency_expiry_expired} OR ({inprogress_expiry_expired})"
206217
)
207-
self._client.put_item(
218+
self.client.put_item(
208219
TableName=self.table_name,
209220
Item=item,
210221
ConditionExpression=condition_expression,
@@ -233,7 +244,7 @@ def _put_record(self, data_record: DataRecord) -> None:
233244
def _update_record(self, data_record: DataRecord):
234245
logger.debug(f"Updating record for idempotency key: {data_record.idempotency_key}")
235246
update_expression = "SET #response_data = :response_data, #expiry = :expiry, " "#status = :status"
236-
expression_attr_values = {
247+
expression_attr_values: Dict[str, "AttributeValueTypeDef"] = {
237248
":expiry": {"N": str(data_record.expiry_timestamp)},
238249
":response_data": {"S": data_record.response_data},
239250
":status": {"S": data_record.status},
@@ -249,15 +260,14 @@ def _update_record(self, data_record: DataRecord):
249260
expression_attr_values[":validation_key"] = {"S": data_record.payload_hash}
250261
expression_attr_names["#validation_key"] = self.validation_key_attr
251262

252-
kwargs = {
253-
"Key": self._get_key(data_record.idempotency_key),
254-
"UpdateExpression": update_expression,
255-
"ExpressionAttributeValues": expression_attr_values,
256-
"ExpressionAttributeNames": expression_attr_names,
257-
}
258-
259-
self._client.update_item(TableName=self.table_name, **kwargs)
263+
self.client.update_item(
264+
TableName=self.table_name,
265+
Key=self._get_key(data_record.idempotency_key),
266+
UpdateExpression=update_expression,
267+
ExpressionAttributeNames=expression_attr_names,
268+
ExpressionAttributeValues=expression_attr_values,
269+
)
260270

261271
def _delete_record(self, data_record: DataRecord) -> None:
262272
logger.debug(f"Deleting record for idempotency key: {data_record.idempotency_key}")
263-
self._client.delete_item(TableName=self.table_name, Key={**self._get_key(data_record.idempotency_key)})
273+
self.client.delete_item(TableName=self.table_name, Key={**self._get_key(data_record.idempotency_key)})

tests/functional/idempotency/test_idempotency.py

+24-24
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def test_idempotent_lambda_already_completed(
7676
Test idempotent decorator where event with matching event key has already been successfully processed
7777
"""
7878

79-
stubber = stub.Stubber(persistence_store._client)
79+
stubber = stub.Stubber(persistence_store.client)
8080
ddb_response = {
8181
"Item": {
8282
"id": {"S": hashed_idempotency_key},
@@ -120,7 +120,7 @@ def test_idempotent_lambda_in_progress(
120120
Test idempotent decorator where lambda_handler is already processing an event with matching event key
121121
"""
122122

123-
stubber = stub.Stubber(persistence_store._client)
123+
stubber = stub.Stubber(persistence_store.client)
124124

125125
expected_params = {
126126
"TableName": TABLE_NAME,
@@ -172,7 +172,7 @@ def test_idempotent_lambda_in_progress_with_cache(
172172
"""
173173
save_to_cache_spy = mocker.spy(persistence_store, "_save_to_cache")
174174
retrieve_from_cache_spy = mocker.spy(persistence_store, "_retrieve_from_cache")
175-
stubber = stub.Stubber(persistence_store._client)
175+
stubber = stub.Stubber(persistence_store.client)
176176

177177
expected_params = {
178178
"TableName": TABLE_NAME,
@@ -234,7 +234,7 @@ def test_idempotent_lambda_first_execution(
234234
Test idempotent decorator when lambda is executed with an event with a previously unknown event key
235235
"""
236236

237-
stubber = stub.Stubber(persistence_store._client)
237+
stubber = stub.Stubber(persistence_store.client)
238238
ddb_response = {}
239239

240240
stubber.add_response("put_item", ddb_response, expected_params_put_item)
@@ -269,7 +269,7 @@ def test_idempotent_lambda_first_execution_cached(
269269
"""
270270
save_to_cache_spy = mocker.spy(persistence_store, "_save_to_cache")
271271
retrieve_from_cache_spy = mocker.spy(persistence_store, "_retrieve_from_cache")
272-
stubber = stub.Stubber(persistence_store._client)
272+
stubber = stub.Stubber(persistence_store.client)
273273
ddb_response = {}
274274

275275
stubber.add_response("put_item", ddb_response, expected_params_put_item)
@@ -310,7 +310,7 @@ def test_idempotent_lambda_first_execution_event_mutation(
310310
Ensures we're passing data by value, not reference.
311311
"""
312312
event = copy.deepcopy(lambda_apigw_event)
313-
stubber = stub.Stubber(persistence_store._client)
313+
stubber = stub.Stubber(persistence_store.client)
314314
ddb_response = {}
315315
stubber.add_response(
316316
"put_item",
@@ -350,7 +350,7 @@ def test_idempotent_lambda_expired(
350350
expiry window
351351
"""
352352

353-
stubber = stub.Stubber(persistence_store._client)
353+
stubber = stub.Stubber(persistence_store.client)
354354

355355
ddb_response = {}
356356

@@ -385,7 +385,7 @@ def test_idempotent_lambda_exception(
385385
# Create a new provider
386386

387387
# Stub the boto3 client
388-
stubber = stub.Stubber(persistence_store._client)
388+
stubber = stub.Stubber(persistence_store.client)
389389

390390
ddb_response = {}
391391
expected_params_delete_item = {"TableName": TABLE_NAME, "Key": {"id": {"S": hashed_idempotency_key}}}
@@ -427,7 +427,7 @@ def test_idempotent_lambda_already_completed_with_validation_bad_payload(
427427
Test idempotent decorator where event with matching event key has already been successfully processed
428428
"""
429429

430-
stubber = stub.Stubber(persistence_store._client)
430+
stubber = stub.Stubber(persistence_store.client)
431431
ddb_response = {
432432
"Item": {
433433
"id": {"S": hashed_idempotency_key},
@@ -471,7 +471,7 @@ def test_idempotent_lambda_expired_during_request(
471471
returns inconsistent/rapidly changing result between put_item and get_item calls.
472472
"""
473473

474-
stubber = stub.Stubber(persistence_store._client)
474+
stubber = stub.Stubber(persistence_store.client)
475475

476476
ddb_response_get_item = {
477477
"Item": {
@@ -524,7 +524,7 @@ def test_idempotent_persistence_exception_deleting(
524524
Test idempotent decorator when lambda is executed with an event with a previously unknown event key, but
525525
lambda_handler raises an exception which is retryable.
526526
"""
527-
stubber = stub.Stubber(persistence_store._client)
527+
stubber = stub.Stubber(persistence_store.client)
528528

529529
ddb_response = {}
530530

@@ -556,7 +556,7 @@ def test_idempotent_persistence_exception_updating(
556556
Test idempotent decorator when lambda is executed with an event with a previously unknown event key, but
557557
lambda_handler raises an exception which is retryable.
558558
"""
559-
stubber = stub.Stubber(persistence_store._client)
559+
stubber = stub.Stubber(persistence_store.client)
560560

561561
ddb_response = {}
562562

@@ -587,7 +587,7 @@ def test_idempotent_persistence_exception_getting(
587587
Test idempotent decorator when lambda is executed with an event with a previously unknown event key, but
588588
lambda_handler raises an exception which is retryable.
589589
"""
590-
stubber = stub.Stubber(persistence_store._client)
590+
stubber = stub.Stubber(persistence_store.client)
591591

592592
stubber.add_client_error("put_item", "ConditionalCheckFailedException")
593593
stubber.add_client_error("get_item", "UnexpectedException")
@@ -625,7 +625,7 @@ def test_idempotent_lambda_first_execution_with_validation(
625625
"""
626626
Test idempotent decorator when lambda is executed with an event with a previously unknown event key
627627
"""
628-
stubber = stub.Stubber(persistence_store._client)
628+
stubber = stub.Stubber(persistence_store.client)
629629
ddb_response = {}
630630

631631
stubber.add_response("put_item", ddb_response, expected_params_put_item_with_validation)
@@ -661,7 +661,7 @@ def test_idempotent_lambda_with_validator_util(
661661
validator utility to unwrap the event
662662
"""
663663

664-
stubber = stub.Stubber(persistence_store._client)
664+
stubber = stub.Stubber(persistence_store.client)
665665
ddb_response = {
666666
"Item": {
667667
"id": {"S": hashed_idempotency_key_with_envelope},
@@ -704,7 +704,7 @@ def test_idempotent_lambda_expires_in_progress_before_expire(
704704
hashed_idempotency_key,
705705
lambda_context,
706706
):
707-
stubber = stub.Stubber(persistence_store._client)
707+
stubber = stub.Stubber(persistence_store.client)
708708

709709
stubber.add_client_error("put_item", "ConditionalCheckFailedException")
710710

@@ -751,7 +751,7 @@ def test_idempotent_lambda_expires_in_progress_after_expire(
751751
hashed_idempotency_key,
752752
lambda_context,
753753
):
754-
stubber = stub.Stubber(persistence_store._client)
754+
stubber = stub.Stubber(persistence_store.client)
755755

756756
for _ in range(MAX_RETRIES + 1):
757757
stubber.add_client_error("put_item", "ConditionalCheckFailedException")
@@ -1070,7 +1070,7 @@ def test_custom_jmespath_function_overrides_builtin_functions(
10701070
def test_idempotent_lambda_save_inprogress_error(persistence_store: DynamoDBPersistenceLayer, lambda_context):
10711071
# GIVEN a miss configured persistence layer
10721072
# like no table was created for the idempotency persistence layer
1073-
stubber = stub.Stubber(persistence_store._client)
1073+
stubber = stub.Stubber(persistence_store.client)
10741074
service_error_code = "ResourceNotFoundException"
10751075
service_message = "Custom message"
10761076

@@ -1327,7 +1327,7 @@ def test_idempotency_disabled_envvar(monkeypatch, lambda_context, persistence_st
13271327
# Scenario to validate no requests sent to dynamodb table when 'POWERTOOLS_IDEMPOTENCY_DISABLED' is set
13281328
mock_event = {"data": "value"}
13291329

1330-
persistence_store._client = MagicMock()
1330+
persistence_store.client = MagicMock()
13311331

13321332
monkeypatch.setenv("POWERTOOLS_IDEMPOTENCY_DISABLED", "1")
13331333

@@ -1342,7 +1342,7 @@ def dummy_handler(event, context):
13421342
dummy(data=mock_event)
13431343
dummy_handler(mock_event, lambda_context)
13441344

1345-
assert len(persistence_store._client.method_calls) == 0
1345+
assert len(persistence_store.client.method_calls) == 0
13461346

13471347

13481348
@pytest.mark.parametrize("idempotency_config", [{"use_local_cache": True}], indirect=True)
@@ -1351,7 +1351,7 @@ def test_idempotent_function_duplicates(
13511351
):
13521352
# Scenario to validate the both methods are called
13531353
mock_event = {"data": "value"}
1354-
persistence_store._client = MagicMock()
1354+
persistence_store.client = MagicMock()
13551355

13561356
@idempotent_function(data_keyword_argument="data", persistence_store=persistence_store, config=idempotency_config)
13571357
def one(data):
@@ -1363,7 +1363,7 @@ def two(data):
13631363

13641364
assert one(data=mock_event) == "one"
13651365
assert two(data=mock_event) == "two"
1366-
assert len(persistence_store._client.method_calls) == 4
1366+
assert len(persistence_store.client.method_calls) == 4
13671367

13681368

13691369
def test_invalid_dynamodb_persistence_layer():
@@ -1475,7 +1475,7 @@ def test_idempotent_lambda_compound_already_completed(
14751475
Test idempotent decorator having a DynamoDBPersistenceLayer with a compound key
14761476
"""
14771477

1478-
stubber = stub.Stubber(persistence_store_compound._client)
1478+
stubber = stub.Stubber(persistence_store_compound.client)
14791479
stubber.add_client_error("put_item", "ConditionalCheckFailedException")
14801480
ddb_response = {
14811481
"Item": {
@@ -1520,7 +1520,7 @@ def test_idempotent_lambda_compound_static_pk_value_has_correct_pk(
15201520
Test idempotent decorator having a DynamoDBPersistenceLayer with a compound key and a static PK value
15211521
"""
15221522

1523-
stubber = stub.Stubber(persistence_store_compound_static_pk_value._client)
1523+
stubber = stub.Stubber(persistence_store_compound_static_pk_value.client)
15241524
ddb_response = {}
15251525

15261526
stubber.add_response("put_item", ddb_response, expected_params_put_item_compound_key_static_pk_value)

tests/unit/idempotency/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
from dataclasses import dataclass
2+
3+
from aws_lambda_powertools.utilities.idempotency import DynamoDBPersistenceLayer
4+
from tests.e2e.utils.data_builder.common import build_random_value
5+
6+
7+
def test_custom_sdk_client_injection():
8+
# GIVEN
9+
@dataclass
10+
class DummyClient:
11+
table_name: str
12+
13+
table_name = build_random_value()
14+
fake_client = DummyClient(table_name)
15+
16+
# WHEN
17+
persistence_layer = DynamoDBPersistenceLayer(table_name, boto3_client=fake_client)
18+
19+
# THEN
20+
assert persistence_layer.table_name == table_name
21+
assert persistence_layer.client == fake_client

0 commit comments

Comments
 (0)