Skip to content

Commit 3126885

Browse files
feat(redis) - refactoring connection and fixing mypy errors
1 parent fce91e5 commit 3126885

File tree

7 files changed

+225
-71
lines changed

7 files changed

+225
-71
lines changed

aws_lambda_powertools/utilities/connections/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@
33
RedisStandalone,
44
)
55

6-
__all__ = (RedisStandalone, RedisCluster)
6+
__all__ = ["RedisStandalone", "RedisCluster"]
Lines changed: 50 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import logging
2-
from typing import Optional
2+
from typing import Optional, Type, Union
33

44
import redis
55

@@ -9,44 +9,29 @@
99
logger = logging.getLogger(__name__)
1010

1111

12-
class RedisStandalone(BaseConnectionSync):
12+
class RedisConnection(BaseConnectionSync):
1313
def __init__(
1414
self,
15+
client: Type[Union[redis.Redis, redis.RedisCluster]],
1516
host: Optional[str] = None,
1617
port: Optional[int] = None,
1718
username: Optional[str] = None,
1819
password: Optional[str] = None,
1920
db_index: Optional[int] = None,
2021
url: Optional[str] = None,
22+
**extra_options,
2123
) -> None:
22-
"""
23-
Initialize the Redis standalone client
24-
Parameters
25-
----------
26-
host: str
27-
Name of the host to connect to Redis instance/cluster
28-
port: int
29-
Number of the port to connect to Redis instance/cluster
30-
username: str
31-
Name of the username to connect to Redis instance/cluster in case of using ACL
32-
See: https://redis.io/docs/management/security/acl/
33-
password: str
34-
Passwod to connect to Redis instance/cluster
35-
db_index: int
36-
Index of Redis database
37-
See: https://redis.io/commands/select/
38-
url: str
39-
Redis client object configured from the given URL
40-
See: https://redis.readthedocs.io/en/latest/connections.html#redis.Redis.from_url
41-
"""
24+
self.extra_options: dict = {}
4225

4326
self.url = url
4427
self.host = host
4528
self.port = port
4629
self.username = username
4730
self.password = password
4831
self.db_index = db_index
32+
self.extra_options.update(**extra_options)
4933
self._connection = None
34+
self._client = client
5035

5136
def init_connection(self):
5237
"""
@@ -55,36 +40,40 @@ def init_connection(self):
5540
if self._connection:
5641
return self._connection
5742

58-
logger.info(f"Trying to connect to Redis Host/Cluster: {self.host}")
43+
logger.info(f"Trying to connect to Redis: {self.host}")
5944

6045
try:
6146
if self.url:
6247
logger.debug(f"Using URL format to connect to Redis: {self.host}")
63-
self._connection = redis.Redis.from_url(url=self.url)
48+
self._connection = self._client.from_url(url=self.url)
6449
else:
6550
logger.debug(f"Using other parameters to connect to Redis: {self.host}")
66-
self._connection = redis.Redis(
51+
self._connection = self._client(
6752
host=self.host,
6853
port=self.port,
6954
username=self.username,
7055
password=self.password,
7156
db=self.db_index,
7257
decode_responses=True,
58+
**self.extra_options,
7359
)
7460
except redis.exceptions.ConnectionError as exc:
75-
logger.debug(f"Cannot connect in Redis Host: {self.host}")
76-
raise RedisConnectionError("Could not to connect to Redis Standalone", exc) from exc
61+
logger.debug(f"Cannot connect in Redis: {self.host}")
62+
raise RedisConnectionError("Could not to connect to Redis", exc) from exc
7763

7864
return self._connection
7965

8066

81-
class RedisCluster(BaseConnectionSync):
67+
class RedisStandalone(RedisConnection):
8268
def __init__(
8369
self,
8470
host: Optional[str] = None,
8571
port: Optional[int] = None,
86-
read_from_replicas: Optional[bool] = False,
72+
username: Optional[str] = None,
73+
password: Optional[str] = None,
74+
db_index: Optional[int] = None,
8775
url: Optional[str] = None,
76+
**extra_options,
8877
) -> None:
8978
"""
9079
Initialize the Redis standalone client
@@ -106,36 +95,40 @@ def __init__(
10695
Redis client object configured from the given URL
10796
See: https://redis.readthedocs.io/en/latest/connections.html#redis.Redis.from_url
10897
"""
98+
print(extra_options)
99+
super().__init__(redis.Redis, host, port, username, password, db_index, url, **extra_options)
109100

110-
self.url = url
111-
self.host = host
112-
self.port = port
113-
self.read_from_replicas = read_from_replicas
114-
self._connection = None
115101

116-
def init_connection(self):
102+
class RedisCluster(RedisConnection):
103+
def __init__(
104+
self,
105+
host: Optional[str] = None,
106+
port: Optional[int] = None,
107+
username: Optional[str] = None,
108+
password: Optional[str] = None,
109+
db_index: Optional[int] = None,
110+
url: Optional[str] = None,
111+
**extra_options,
112+
) -> None:
117113
"""
118-
Connection is cached, so returning this
114+
Initialize the Redis standalone client
115+
Parameters
116+
----------
117+
host: str
118+
Name of the host to connect to Redis instance/cluster
119+
port: int
120+
Number of the port to connect to Redis instance/cluster
121+
username: str
122+
Name of the username to connect to Redis instance/cluster in case of using ACL
123+
See: https://redis.io/docs/management/security/acl/
124+
password: str
125+
Passwod to connect to Redis instance/cluster
126+
db_index: int
127+
Index of Redis database
128+
See: https://redis.io/commands/select/
129+
url: str
130+
Redis client object configured from the given URL
131+
See: https://redis.readthedocs.io/en/latest/connections.html#redis.Redis.from_url
119132
"""
120-
if self._connection:
121-
return self._connection
122133

123-
logger.info(f"Trying to connect to Redis Cluster: {self.host}")
124-
125-
try:
126-
if self.url:
127-
logger.debug(f"Using URL format to connect to Redis Cluster: {self.host}")
128-
self._connection = redis.Redis.from_url(url=self.url)
129-
else:
130-
logger.debug(f"Using other parameters to connect to Redis Cluster: {self.host}")
131-
self._connection = redis.cluster.RedisCluster(
132-
host=self.host,
133-
port=self.port,
134-
server_type=None,
135-
read_from_replicas=self.read_from_replicas,
136-
)
137-
except redis.exceptions.ConnectionError as exc:
138-
logger.debug(f"Cannot connect in Redis Cluster: {self.host}")
139-
raise RedisConnectionError("Could not to connect to Redis Cluster", exc) from exc
140-
141-
return self._connection
134+
super().__init__(redis.cluster.RedisCluster, host, port, username, password, db_index, url, **extra_options)

aws_lambda_powertools/utilities/idempotency/persistence/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class DataRecord:
3737

3838
def __init__(
3939
self,
40-
idempotency_key: Optional[str] = "",
40+
idempotency_key: str = "",
4141
status: str = "",
4242
expiry_timestamp: Optional[int] = None,
4343
in_progress_expiry_timestamp: Optional[int] = None,

aws_lambda_powertools/utilities/idempotency/persistence/dynamodb.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class DynamoDBPersistenceLayer(BasePersistenceLayer):
2626
def __init__(
2727
self,
2828
table_name: str,
29-
key_attr: str = "id",
29+
key_attr: Optional[str] = "id",
3030
static_pk_value: Optional[str] = None,
3131
sort_key_attr: Optional[str] = None,
3232
expiry_attr: str = "expiration",

aws_lambda_powertools/utilities/idempotency/persistence/redis.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ def __init__(
1919
connection,
2020
static_pk_value: Optional[str] = None,
2121
expiry_attr: str = "expiration",
22-
in_progress_expiry_attr: str = "in_progress_expiration",
22+
in_progress_expiry_attr="in_progress_expiration",
2323
status_attr: str = "status",
2424
data_attr: str = "data",
25-
validation_key_attr: str = "validation",
25+
validation_key_attr="validation",
2626
):
2727
"""
2828
Initialize the Redis Persistence Layer
@@ -54,12 +54,6 @@ def __init__(
5454
self.validation_key_attr = validation_key_attr
5555
super(RedisCachePersistenceLayer, self).__init__()
5656

57-
def _get_key(self, idempotency_key: str) -> dict:
58-
# Need to review this after adding GETKEY logic
59-
if self.sort_key_attr:
60-
return {self.key_attr: self.static_pk_value, self.sort_key_attr: idempotency_key}
61-
return {self.key_attr: idempotency_key}
62-
6357
def _item_to_data_record(self, item: Dict[str, Any]) -> DataRecord:
6458
# Need to review this after adding GETKEY logic
6559
return DataRecord(
@@ -93,10 +87,10 @@ def _put_record(self, data_record: DataRecord) -> None:
9387
}
9488

9589
if data_record.in_progress_expiry_timestamp is not None:
96-
item["mapping"][self.in_progress_expiry_attr] = data_record.in_progress_expiry_timestamp
90+
item.update({"mapping": {self.in_progress_expiry_attr: data_record.in_progress_expiry_timestamp}})
9791

9892
if self.payload_validation_enabled:
99-
item["mapping"][self.validation_key_attr] = data_record.payload_hash
93+
item.update({"mapping": {self.validation_key_attr: data_record.payload_hash}})
10094

10195
try:
10296
# | LOCKED | RETRY if status = "INPROGRESS" | RETRY

0 commit comments

Comments
 (0)