Skip to content

Commit 3bf45df

Browse files
Refactoring secrets
1 parent c016d24 commit 3bf45df

File tree

5 files changed

+54
-23
lines changed

5 files changed

+54
-23
lines changed

aws_lambda_powertools/utilities/parameters/secrets.py

+16-11
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,23 @@ def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]:
120120
"""
121121
raise NotImplementedError()
122122

123-
def _set(
123+
def _create_secret(self, name: str, **sdk_options):
124+
try:
125+
sdk_options["Name"] = name
126+
return self.client.create_secret(**sdk_options)
127+
except Exception as exc:
128+
raise SetSecretError(f"Error setting secret - {str(exc)}") from exc
129+
130+
def _update_secret(self, name: str, **sdk_options):
131+
sdk_options["SecretId"] = name
132+
return self.client.put_secret_value(**sdk_options)
133+
134+
def set(
124135
self,
125136
name: str,
126137
value: Union[str, dict, bytes],
127138
*, # force keyword arguments
128139
client_request_token: Optional[str] = None,
129-
version_stages: Optional[list[str]] = None,
130140
**sdk_options,
131141
) -> SetSecretResponse:
132142
"""
@@ -143,8 +153,6 @@ def _set(
143153
a UUID-type value to ensure uniqueness within the specified secret.
144154
This value becomes the VersionId of the new version. This field is
145155
autopopulated if not provided.
146-
version_stages: list[str], optional
147-
Specifies a list of staging labels that are attached to this version of the secret.
148156
sdk_options: dict, optional
149157
Dictionary of options that will be passed to the Secrets Manager update_secret API call
150158
@@ -157,8 +165,6 @@ def _set(
157165
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager/client/put_secret_value.html
158166
"""
159167

160-
sdk_options["SecretId"] = name
161-
162168
if isinstance(value, dict):
163169
value = json.dumps(value)
164170

@@ -167,13 +173,13 @@ def _set(
167173
else:
168174
sdk_options["SecretString"] = value
169175

170-
if version_stages:
171-
sdk_options["VersionStages"] = version_stages
172176
if client_request_token:
173177
sdk_options["ClientRequestToken"] = client_request_token
174178

175179
try:
176-
return self.client.put_secret_value(**sdk_options)
180+
return self._update_secret(name=name, **sdk_options)
181+
except self.client.exceptions.ResourceNotFoundException:
182+
return self._create_secret(name=name, **sdk_options)
177183
except Exception as exc:
178184
raise SetSecretError(f"Error setting secret - {str(exc)}") from exc
179185

@@ -350,10 +356,9 @@ def set_secret(
350356
if "secrets" not in DEFAULT_PROVIDERS:
351357
DEFAULT_PROVIDERS["secrets"] = SecretsProvider()
352358

353-
return DEFAULT_PROVIDERS["secrets"]._set(
359+
return DEFAULT_PROVIDERS["secrets"].set(
354360
name=name,
355361
value=value,
356362
client_request_token=client_request_token,
357-
version_stages=version_stages,
358363
**sdk_options,
359364
)

aws_lambda_powertools/utilities/parameters/types.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from aws_lambda_powertools.shared.types import List, Literal, TypedDict
1+
from typing import Any, Optional
2+
3+
from aws_lambda_powertools.shared.types import Dict, List, Literal, TypedDict
24

35
TransformOptions = Literal["json", "binary", "auto", None]
46

@@ -13,5 +15,6 @@ class SetSecretResponse(TypedDict):
1315
ARN: str
1416
Name: str
1517
VersionId: str
16-
VersionStages: List[str]
18+
VersionStages: Optional[List[str]]
19+
ReplicationStatus: Optional[List[Dict[str, Any]]]
1720
ResponseMetadata: dict

docs/utilities/parameters.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,10 @@ This utility requires additional permissions to work as expected.
3232
| SSM | **`get_parameter`**, **`SSMProvider.get`** | **`ssm:GetParameter`** |
3333
| SSM | **`get_parameters`**, **`SSMProvider.get_multiple`** | **`ssm:GetParametersByPath`** |
3434
| SSM | **`get_parameters_by_name`**, **`SSMProvider.get_parameters_by_name`** | **`ssm:GetParameter`** and **`ssm:GetParameters`** |
35-
| SSM | **`set_parameter`** | **`ssm:PutParameter`** |
35+
| SSM | **`set_parameter`**, **`SSMProvider.set_parameter`** | **`ssm:PutParameter`** |
3636
| SSM | If using **`decrypt=True`** | You must add an additional permission **`kms:Decrypt`** |
3737
| Secrets | **`get_secret`**, **`SecretsProvider.get`** | **`secretsmanager:GetSecretValue`** |
38-
| Secrets | **`set_secret`**, **`SecretsProvider.get`** | **`secretsmanager:PutSecretValue`** and or **`secretsmanager:CreateSecret`** |
38+
| Secrets | **`set_secret`**, **`SecretsProvider.set`** | **`secretsmanager:PutSecretValue`** and or **`secretsmanager:CreateSecret`** |
3939
| DynamoDB | **`DynamoDBProvider.get`** | **`dynamodb:GetItem`** |
4040
| DynamoDB | **`DynamoDBProvider.get_multiple`** | **`dynamodb:Query`** |
4141
| AppConfig | **`get_app_config`**, **`AppConfigProvider.get_app_config`** | **`appconfig:GetLatestConfiguration`** and **`appconfig:StartConfigurationSession`** |

examples/parameters/src/getting_started_setting_secret.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
def access_token(client_id: str, client_secret: str, audience: str) -> str:
1111
# example function that returns a JWT Access Token
12-
...
12+
# add your own logic here
1313
return f"{client_id}.{client_secret}.{audience}"
1414

1515

@@ -25,6 +25,6 @@ def lambda_handler(event: dict, context: LambdaContext):
2525
update_secret_version_id = parameters.set_secret(name="/aws-powertools/jwt_token", value=jwt_token)
2626

2727
return {"access_token": "updated", "statusCode": 200, "update_secret_version_id": update_secret_version_id}
28-
except parameters.exceptions.SetParameterError as error:
28+
except parameters.exceptions.SetSecretError as error:
2929
logger.exception(error)
3030
return {"access_token": "updated", "statusCode": 400}

tests/functional/test_utilities_parameters.py

+29-6
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ def test_ssm_provider_get(mock_name, mock_value, mock_version, config):
513513

514514
def test_set_parameter(monkeypatch, mock_name, mock_value):
515515
"""
516-
Test get_parameter()
516+
Test set_parameter()
517517
"""
518518

519519
class TestProvider(BaseProvider):
@@ -534,7 +534,7 @@ def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
534534
assert value == mock_value
535535

536536

537-
def test_ssm_provider_set(mock_name, mock_value, mock_version, config):
537+
def test_ssm_provider_set_parameter(mock_name, mock_value, mock_version, config):
538538
"""
539539
Test SSMProvider.set_parameter() with a non-cached value
540540
"""
@@ -564,7 +564,7 @@ def test_ssm_provider_set(mock_name, mock_value, mock_version, config):
564564
stubber.deactivate()
565565

566566

567-
def test_ssm_provider_set_default_config(monkeypatch, mock_name, mock_value, mock_version):
567+
def test_ssm_provider_set_parameter_default_config(monkeypatch, mock_name, mock_value, mock_version):
568568
"""
569569
Test SSMProvider._set() without specifying the config
570570
"""
@@ -596,9 +596,9 @@ def test_ssm_provider_set_default_config(monkeypatch, mock_name, mock_value, moc
596596
stubber.deactivate()
597597

598598

599-
def test_ssm_provider_set_with_custom_options(monkeypatch, mock_name, mock_value, mock_version):
599+
def test_ssm_provider_set_parameter_with_custom_options(monkeypatch, mock_name, mock_value, mock_version):
600600
"""
601-
Test SSMProvider._set() without specifying the config
601+
Test SSMProvider._set() with custom options
602602
"""
603603

604604
monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-2")
@@ -638,7 +638,7 @@ def test_ssm_provider_set_with_custom_options(monkeypatch, mock_name, mock_value
638638
stubber.deactivate()
639639

640640

641-
def test_ssm_provider_set_raise_on_failure(mock_name, mock_value, mock_version, config):
641+
def test_ssm_provider_set_parameter_raise_on_failure(mock_name, mock_value, mock_version, config):
642642
"""
643643
Test SSMProvider.set_parameter() with failure
644644
"""
@@ -669,6 +669,29 @@ def test_ssm_provider_set_raise_on_failure(mock_name, mock_value, mock_version,
669669
stubber.deactivate()
670670

671671

672+
def test_set_secret(monkeypatch, mock_name, mock_value):
673+
"""
674+
Test set_secret()
675+
"""
676+
677+
class TestProvider(BaseProvider):
678+
def set(self, name: str, value: Any, *, overwrite: bool = False, **kwargs) -> str:
679+
assert name == mock_name
680+
return mock_value
681+
682+
def _get(self, name: str, **kwargs) -> str:
683+
raise NotImplementedError()
684+
685+
def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]:
686+
raise NotImplementedError()
687+
688+
monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "secrets", TestProvider())
689+
690+
value = parameters.set_secret(name=mock_name, value=mock_value)
691+
692+
assert value == mock_value
693+
694+
672695
def test_ssm_provider_get_with_custom_client(mock_name, mock_value, mock_version, config):
673696
"""
674697
Test SSMProvider.get() with a non-cached value

0 commit comments

Comments
 (0)