diff --git a/src/aws_encryption_sdk/key_providers/kms.py b/src/aws_encryption_sdk/key_providers/kms.py index 599b84f40..7249fe2df 100644 --- a/src/aws_encryption_sdk/key_providers/kms.py +++ b/src/aws_encryption_sdk/key_providers/kms.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Master Key Providers for use with AWS KMS""" +import functools import logging import attr @@ -128,15 +129,44 @@ def _process_config(self): if self.config.key_ids: self.add_master_keys_from_list(self.config.key_ids) + def _wrap_client(self, region_name, method, *args, **kwargs): + """Proxies all calls to a kms clients methods and removes misbehaving clients + + :param str region_name: AWS Region ID (ex: us-east-1) + :param callable method: a method on the KMS client to proxy + :param tuple args: list of arguments to pass to the provided ``method`` + :param dict kwargs: dictonary of keyword arguments to pass to the provided ``method`` + """ + try: + return method(*args, **kwargs) + except botocore.exceptions.BotoCoreError: + self._regional_clients.pop(region_name) + _LOGGER.error( + 'Removing regional client "%s" from cache due to BotoCoreError on %s call', region_name, method.__name__ + ) + raise + + def _register_client(self, client, region_name): + """Uses functools.partial to wrap all methods on a client with the self._wrap_client method + + :param botocore.client.BaseClient client: the client to proxy + :param str region_name: AWS Region ID (ex: us-east-1) + """ + for item in client.meta.method_to_api_mapping: + method = getattr(client, item) + wrapped_method = functools.partial(self._wrap_client, region_name, method) + setattr(client, item, wrapped_method) + def add_regional_client(self, region_name): """Adds a regional client for the specified region if it does not already exist. :param str region_name: AWS Region ID (ex: us-east-1) """ if region_name not in self._regional_clients: - self._regional_clients[region_name] = boto3.session.Session( - region_name=region_name, botocore_session=self.config.botocore_session - ).client("kms", config=self._user_agent_adding_config) + session = boto3.session.Session(region_name=region_name, botocore_session=self.config.botocore_session) + client = session.client("kms", config=self._user_agent_adding_config) + self._register_client(client, region_name) + self._regional_clients[region_name] = client def add_regional_clients_from_list(self, region_names): """Adds multiple regional clients for the specified regions if they do not already exist. diff --git a/test/integration/test_i_aws_encrytion_sdk_client.py b/test/integration/test_i_aws_encrytion_sdk_client.py index e2115ccfe..639e21e95 100644 --- a/test/integration/test_i_aws_encrytion_sdk_client.py +++ b/test/integration/test_i_aws_encrytion_sdk_client.py @@ -16,10 +16,11 @@ import unittest import pytest +from botocore.exceptions import BotoCoreError import aws_encryption_sdk from aws_encryption_sdk.identifiers import USER_AGENT_SUFFIX, Algorithm -from aws_encryption_sdk.key_providers.kms import KMSMasterKey +from aws_encryption_sdk.key_providers.kms import KMSMasterKey, KMSMasterKeyProvider from .integration_test_utils import get_cmk_arn, setup_kms_master_key_provider @@ -58,6 +59,16 @@ def test_encrypt_verify_user_agent_kms_master_key(caplog): assert USER_AGENT_SUFFIX in caplog.text +def test_remove_bad_client(): + test = KMSMasterKeyProvider() + test.add_regional_client("us-fakey-12") + + with pytest.raises(BotoCoreError): + test._regional_clients["us-fakey-12"].list_keys() + + assert not test._regional_clients + + class TestKMSThickClientIntegration(unittest.TestCase): def setUp(self): self.kms_master_key_provider = setup_kms_master_key_provider()