Skip to content

Commit 64c60a0

Browse files
authored
fix(parameters): distinct cache key for single vs path with same name (#2839)
1 parent 9c5b6b5 commit 64c60a0

File tree

2 files changed

+67
-14
lines changed

2 files changed

+67
-14
lines changed

aws_lambda_powertools/utilities/parameters/base.py

+37-10
Original file line numberDiff line numberDiff line change
@@ -66,16 +66,16 @@ class BaseProvider(ABC):
6666
Abstract Base Class for Parameter providers
6767
"""
6868

69-
store: Dict[Tuple[str, TransformOptions], ExpirableValue]
69+
store: Dict[Tuple, ExpirableValue]
7070

7171
def __init__(self):
7272
"""
7373
Initialize the base provider
7474
"""
7575

76-
self.store: Dict[Tuple[str, TransformOptions], ExpirableValue] = {}
76+
self.store: Dict[Tuple, ExpirableValue] = {}
7777

78-
def has_not_expired_in_cache(self, key: Tuple[str, TransformOptions]) -> bool:
78+
def has_not_expired_in_cache(self, key: Tuple) -> bool:
7979
return key in self.store and self.store[key].ttl >= datetime.now()
8080

8181
def get(
@@ -123,13 +123,13 @@ def get(
123123
# parameter will always be used in a specific transform, this should be
124124
# an acceptable tradeoff.
125125
value: Optional[Union[str, bytes, dict]] = None
126-
key = (name, transform)
126+
key = self._build_cache_key(name=name, transform=transform)
127127

128128
# If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS
129129
max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age)
130130

131131
if not force_fetch and self.has_not_expired_in_cache(key):
132-
return self.store[key].value
132+
return self.fetch_from_cache(key)
133133

134134
try:
135135
value = self._get(name, **sdk_options)
@@ -142,7 +142,7 @@ def get(
142142

143143
# NOTE: don't cache None, as they might've been failed transforms and may be corrected
144144
if value is not None:
145-
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))
145+
self.add_to_cache(key=key, value=value, max_age=max_age)
146146

147147
return value
148148

@@ -191,13 +191,13 @@ def get_multiple(
191191
TransformParameterError
192192
When the parameter provider fails to transform a parameter value.
193193
"""
194-
key = (path, transform)
194+
key = self._build_cache_key(name=path, transform=transform, is_nested=True)
195195

196196
# If max_age is not set, resolve it from the environment variable, defaulting to DEFAULT_MAX_AGE_SECS
197197
max_age = resolve_max_age(env=os.getenv(constants.PARAMETERS_MAX_AGE_ENV, DEFAULT_MAX_AGE_SECS), choice=max_age)
198198

199199
if not force_fetch and self.has_not_expired_in_cache(key):
200-
return self.store[key].value # type: ignore # need to revisit entire typing here
200+
return self.fetch_from_cache(key)
201201

202202
try:
203203
values = self._get_multiple(path, **sdk_options)
@@ -208,7 +208,7 @@ def get_multiple(
208208
if transform:
209209
values.update(transform_value(values, transform, raise_on_transform_error))
210210

211-
self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age))
211+
self.add_to_cache(key=key, value=values, max_age=max_age)
212212

213213
return values
214214

@@ -222,12 +222,39 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
222222
def clear_cache(self):
223223
self.store.clear()
224224

225-
def add_to_cache(self, key: Tuple[str, TransformOptions], value: Any, max_age: int):
225+
def fetch_from_cache(self, key: Tuple):
226+
return self.store[key].value if key in self.store else {}
227+
228+
def add_to_cache(self, key: Tuple, value: Any, max_age: int):
226229
if max_age <= 0:
227230
return
228231

229232
self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age))
230233

234+
def _build_cache_key(
235+
self,
236+
name: str,
237+
transform: TransformOptions = None,
238+
is_nested: bool = False,
239+
):
240+
"""Creates cache key for parameters
241+
242+
Parameters
243+
----------
244+
name : str
245+
Name of parameter, secret or config
246+
transform : TransformOptions, optional
247+
Transform method used, by default None
248+
is_nested : bool, optional
249+
Whether it's a single parameter or multiple nested parameters, by default False
250+
251+
Returns
252+
-------
253+
Tuple[str, TransformOptions, bool]
254+
Cache key
255+
"""
256+
return (name, transform, is_nested)
257+
231258
@staticmethod
232259
def _build_boto3_client(
233260
service_name: str,

tests/functional/test_utilities_parameters.py

+30-4
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,8 @@ def test_dynamodb_provider_get_cached(mock_name, mock_value, config):
139139
provider = parameters.DynamoDBProvider(table_name, config=config)
140140

141141
# Inject value in the internal store
142-
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
142+
cache_key = provider._build_cache_key(name=mock_name)
143+
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)
143144

144145
# Stub the boto3 client
145146
stubber = stub.Stubber(provider.table.meta.client)
@@ -631,7 +632,8 @@ def test_ssm_provider_get_cached(mock_name, mock_value, config):
631632
provider = parameters.SSMProvider(config=config)
632633

633634
# Inject value in the internal store
634-
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
635+
cache_key = provider._build_cache_key(name=mock_name)
636+
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)
635637

636638
# Stub the boto3 client
637639
stubber = stub.Stubber(provider.client)
@@ -1332,7 +1334,8 @@ def test_secrets_provider_get_cached(mock_name, mock_value, config):
13321334
provider = parameters.SecretsProvider(config=config)
13331335

13341336
# Inject value in the internal store
1335-
provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60))
1337+
cache_key = provider._build_cache_key(name=mock_name)
1338+
provider.add_to_cache(key=cache_key, value=mock_value, max_age=60)
13361339

13371340
# Stub the boto3 client
13381341
stubber = stub.Stubber(provider.client)
@@ -1734,7 +1737,8 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
17341737

17351738
provider = TestProvider()
17361739

1737-
provider.store[(mock_name, None)] = ExpirableValue({"A": mock_value}, datetime.now() + timedelta(seconds=60))
1740+
cache_key = provider._build_cache_key(name=mock_name, is_nested=True)
1741+
provider.add_to_cache(key=cache_key, value={"A": mock_value}, max_age=60)
17381742

17391743
value = provider.get_multiple(mock_name)
17401744

@@ -2500,3 +2504,25 @@ def test_cache_ignores_max_age_zero_or_negative(mock_value, config):
25002504
# THEN they should not be added to the cache
25012505
assert len(provider.store) == 0
25022506
assert provider.has_not_expired_in_cache(cache_key) is False
2507+
2508+
2509+
def test_base_provider_single_and_nested_parameters_cached(mock_name, mock_value):
2510+
# GIVEN a custom provider
2511+
class TestProvider(BaseProvider):
2512+
def _get(self, name: str, **kwargs) -> str:
2513+
raise ValueError("This parameter doesn't exist")
2514+
2515+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
2516+
return {"A": mock_value}
2517+
2518+
provider = TestProvider()
2519+
2520+
# WHEN get_multiple is followed by get with the same name
2521+
# (path vs single parameter name)
2522+
provider.get_multiple(mock_name)
2523+
2524+
# THEN get should raise GetParameterError
2525+
# since a path will likely not be a valid parameter
2526+
# see #2438
2527+
with pytest.raises(parameters.exceptions.GetParameterError):
2528+
provider.get(mock_name)

0 commit comments

Comments
 (0)