Skip to content

Commit 6936dbf

Browse files
author
Michael Brewer
committed
feat(parameters): Add force_update option
Changes: - Add new force_update to always load even if there is a cached value
1 parent 5daa45a commit 6936dbf

File tree

3 files changed

+62
-4
lines changed

3 files changed

+62
-4
lines changed

aws_lambda_powertools/utilities/parameters/base.py

+13-3
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,12 @@ def _has_not_expired(self, key: Tuple[str, Optional[str]]) -> bool:
3838
return key in self.store and self.store[key].ttl >= datetime.now()
3939

4040
def get(
41-
self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options
41+
self,
42+
name: str,
43+
max_age: int = DEFAULT_MAX_AGE_SECS,
44+
transform: Optional[str] = None,
45+
force_update: bool = False,
46+
**sdk_options,
4247
) -> Union[str, list, dict, bytes]:
4348
"""
4449
Retrieve a parameter value or return the cached value
@@ -53,6 +58,8 @@ def get(
5358
Optional transformation of the parameter value. Supported values
5459
are "json" for JSON strings and "binary" for base 64 encoded
5560
values.
61+
force_update: bool, optional
62+
Force update even before a cached item has expired
5663
sdk_options: dict, optional
5764
Arguments that will be passed directly to the underlying API call
5865
@@ -76,7 +83,7 @@ def get(
7683
# an acceptable tradeoff.
7784
key = (name, transform)
7885

79-
if self._has_not_expired(key):
86+
if not force_update and self._has_not_expired(key):
8087
return self.store[key].value
8188

8289
try:
@@ -105,6 +112,7 @@ def get_multiple(
105112
max_age: int = DEFAULT_MAX_AGE_SECS,
106113
transform: Optional[str] = None,
107114
raise_on_transform_error: bool = False,
115+
force_update: bool = False,
108116
**sdk_options,
109117
) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]:
110118
"""
@@ -123,6 +131,8 @@ def get_multiple(
123131
raise_on_transform_error: bool, optional
124132
Raises an exception if any transform fails, otherwise this will
125133
return a None value for each transform that failed
134+
force_update: bool, optional
135+
Force update even before a cached item has expired
126136
sdk_options: dict, optional
127137
Arguments that will be passed directly to the underlying API call
128138
@@ -137,7 +147,7 @@ def get_multiple(
137147

138148
key = (path, transform)
139149

140-
if self._has_not_expired(key):
150+
if not force_update and self._has_not_expired(key):
141151
return self.store[key].value
142152

143153
try:

aws_lambda_powertools/utilities/parameters/ssm.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def get(
9292
max_age: int = DEFAULT_MAX_AGE_SECS,
9393
transform: Optional[str] = None,
9494
decrypt: bool = False,
95+
force_update: bool = False,
9596
**sdk_options
9697
) -> Union[str, list, dict, bytes]:
9798
"""
@@ -109,6 +110,8 @@ def get(
109110
values.
110111
decrypt: bool, optional
111112
If the parameter value should be decrypted
113+
force_update: bool, optional
114+
Force update even before a cached item has expired
112115
sdk_options: dict, optional
113116
Arguments that will be passed directly to the underlying API call
114117
@@ -124,7 +127,7 @@ def get(
124127
# Add to `decrypt` sdk_options to we can have an explicit option for this
125128
sdk_options["decrypt"] = decrypt
126129

127-
return super().get(name, max_age, transform, **sdk_options)
130+
return super().get(name, max_age, transform, force_update, **sdk_options)
128131

129132
def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str:
130133
"""

tests/functional/test_utilities_parameters.py

+45
Original file line numberDiff line numberDiff line change
@@ -1663,3 +1663,48 @@ def test_get_transform_method_preserve_auto_unhandled(key):
16631663
transform = parameters.base.get_transform_method(key, "auto")
16641664

16651665
assert transform is None
1666+
1667+
1668+
def test_base_provider_get_multiple_force_update(mock_name, mock_value):
1669+
"""
1670+
Test BaseProvider.get_multiple() with cached values and force_update is True
1671+
"""
1672+
1673+
class TestProvider(BaseProvider):
1674+
def _get(self, name: str, **kwargs) -> str:
1675+
raise NotImplementedError()
1676+
1677+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1678+
assert path == mock_name
1679+
return {"A": mock_value}
1680+
1681+
provider = TestProvider()
1682+
1683+
provider.store[(mock_name, None)] = ExpirableValue({"B": mock_value}, datetime.now() + timedelta(seconds=60))
1684+
1685+
value = provider.get_multiple(mock_name, force_update=True)
1686+
1687+
assert isinstance(value, dict)
1688+
assert value["A"] == mock_value
1689+
1690+
1691+
def test_base_provider_get_force_update(mock_name, mock_value):
1692+
"""
1693+
Test BaseProvider.get() with cached values and force_update is True
1694+
"""
1695+
1696+
class TestProvider(BaseProvider):
1697+
def _get(self, name: str, **kwargs) -> str:
1698+
return mock_value
1699+
1700+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
1701+
raise NotImplementedError()
1702+
1703+
provider = TestProvider()
1704+
1705+
provider.store[(mock_name, None)] = ExpirableValue("not-value", datetime.now() + timedelta(seconds=60))
1706+
1707+
value = provider.get(mock_name, force_update=True)
1708+
1709+
assert isinstance(value, str)
1710+
assert value == mock_value

0 commit comments

Comments
 (0)