Skip to content

Commit b820b8f

Browse files
committed
add default client generation for KMSMasterKey is none is provided: this both makes KMSMasterKey simpler to use directly and makes it more likely that the client will use our user agent
1 parent eb1b641 commit b820b8f

File tree

2 files changed

+49
-12
lines changed

2 files changed

+49
-12
lines changed

src/aws_encryption_sdk/key_providers/kms.py

+38-11
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,28 @@
3333
_PROVIDER_ID = 'aws-kms'
3434

3535

36+
def _region_from_key_id(key_id, default_region=None):
37+
"""Determine the target region from a key ID, falling back to a default region if provided.
38+
39+
:param str key_id: AWS KMS key ID
40+
:param str default_region: Region to use if no region found in key_id
41+
:returns: region name
42+
:rtype: str
43+
:raises UnknownRegionError: if no region found in key_id and no default_region provided
44+
"""
45+
try:
46+
region_name = key_id.split(':', 4)[3]
47+
if default_region is None:
48+
default_region = region_name
49+
except IndexError:
50+
if default_region is None:
51+
raise UnknownRegionError(
52+
'No default region found and no region determinable from key id: {}'.format(key_id)
53+
)
54+
region_name = default_region
55+
return region_name
56+
57+
3658
@attr.s(hash=True)
3759
class KMSMasterKeyProviderConfig(MasterKeyProviderConfig):
3860
"""Configuration object for KMSMasterKeyProvider objects.
@@ -136,16 +158,7 @@ def _client(self, key_id):
136158
137159
:param str key_id: KMS CMK ID
138160
"""
139-
try:
140-
region_name = key_id.split(':', 4)[3]
141-
if self.default_region is None:
142-
self.default_region = region_name
143-
except IndexError:
144-
if self.default_region is None:
145-
raise UnknownRegionError(
146-
'No default region found and no region determinable from key id: {}'.format(key_id)
147-
)
148-
region_name = self.default_region
161+
region_name = _region_from_key_id(key_id, self.default_region)
149162
self.add_regional_client(region_name)
150163
return self._regional_clients[region_name]
151164

@@ -175,14 +188,28 @@ class KMSMasterKeyConfig(MasterKeyConfig):
175188
"""
176189

177190
provider_id = _PROVIDER_ID
178-
client = attr.ib(hash=True, validator=attr.validators.instance_of(botocore.client.BaseClient))
191+
client = attr.ib(
192+
hash=True,
193+
validator=attr.validators.instance_of(botocore.client.BaseClient)
194+
)
179195
grant_tokens = attr.ib(
180196
hash=True,
181197
default=attr.Factory(tuple),
182198
validator=attr.validators.instance_of(tuple),
183199
converter=tuple
184200
)
185201

202+
@client.default
203+
def client_default(self):
204+
"""Create a client if one was not provided."""
205+
try:
206+
region_name = _region_from_key_id(to_str(self.key_id))
207+
kwargs = dict(region_name=region_name)
208+
except UnknownRegionError:
209+
kwargs = {}
210+
botocore_config = botocore.config.Config(user_agent_extra=USER_AGENT_SUFFIX)
211+
return boto3.session.Session(**kwargs).client('kms', config=botocore_config)
212+
186213

187214
class KMSMasterKey(MasterKey):
188215
"""Master Key class for KMS CMKs.

test/integration/test_i_aws_encrytion_sdk_client.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import aws_encryption_sdk
2121
from aws_encryption_sdk.identifiers import Algorithm, USER_AGENT_SUFFIX
22+
from aws_encryption_sdk.key_providers.kms import KMSMasterKey
2223
from .integration_test_utils import setup_kms_master_key_provider, get_cmk_arn
2324

2425
pytestmark = [pytest.mark.integ]
@@ -41,7 +42,7 @@
4142
}
4243

4344

44-
def test_encrypt_verify_user_agent(caplog):
45+
def test_encrypt_verify_user_agent_kms_master_key_provider(caplog):
4546
caplog.set_level(level=logging.DEBUG)
4647
mkp = setup_kms_master_key_provider()
4748
mk = mkp.master_key(get_cmk_arn())
@@ -51,6 +52,15 @@ def test_encrypt_verify_user_agent(caplog):
5152
assert USER_AGENT_SUFFIX in caplog.text
5253

5354

55+
def test_encrypt_verify_user_agent_kms_master_key(caplog):
56+
caplog.set_level(level=logging.DEBUG)
57+
mk = KMSMasterKey(key_id=get_cmk_arn())
58+
59+
mk.generate_data_key(algorithm=Algorithm.AES_256_GCM_IV12_TAG16_HKDF_SHA384_ECDSA_P384, encryption_context={})
60+
61+
assert USER_AGENT_SUFFIX in caplog.text
62+
63+
5464
class TestKMSThickClientIntegration(unittest.TestCase):
5565

5666
def setUp(self):

0 commit comments

Comments
 (0)