diff --git a/src/dynamodb_encryption_sdk/internal/utils.py b/src/dynamodb_encryption_sdk/internal/utils.py index bb133766..93bdbca3 100644 --- a/src/dynamodb_encryption_sdk/internal/utils.py +++ b/src/dynamodb_encryption_sdk/internal/utils.py @@ -16,17 +16,20 @@ No guarantee is provided on the modules and APIs within this namespace staying consistent. Directly reference at your own risk. """ +import copy +from functools import partial + import attr import botocore.client from dynamodb_encryption_sdk.encrypted import CryptoConfig from dynamodb_encryption_sdk.encrypted.item import decrypt_python_item, encrypt_python_item from dynamodb_encryption_sdk.exceptions import InvalidArgumentError -from dynamodb_encryption_sdk.structures import EncryptionContext, TableInfo +from dynamodb_encryption_sdk.structures import CryptoAction, EncryptionContext, TableInfo from dynamodb_encryption_sdk.transform import dict_to_ddb try: # Python 3.5.0 and 3.5.1 have incompatible typing modules - from typing import Any, Callable, Dict, Text # noqa pylint: disable=unused-import + from typing import Any, Callable, Dict, Text # noqa pylint: disable=unused-import except ImportError: # pragma: no cover # We only actually need these imports when running the mypy checks pass @@ -271,19 +274,22 @@ def encrypt_batch_write_item(encrypt_method, crypto_config_method, write_method, """Transparently encrypt multiple items before putting them in a batch request. 
:param callable encrypt_method: Method to use to encrypt items - :param callable crypto_config_method: Method that accepts ``kwargs`` and provides a :class:`CryptoConfig` + :param callable crypto_config_method: Method that accepts a table name string and provides a :class:`CryptoConfig` :param callable write_method: Method that writes to the table :param **kwargs: Keyword arguments to pass to ``write_method`` :return: DynamoDB response :rtype: dict """ request_crypto_config = kwargs.pop("crypto_config", None) + table_crypto_configs = {} + plaintext_items = copy.deepcopy(kwargs["RequestItems"]) for table_name, items in kwargs["RequestItems"].items(): if request_crypto_config is not None: crypto_config = request_crypto_config else: crypto_config = crypto_config_method(table_name=table_name) + table_crypto_configs[table_name] = crypto_config for pos, value in enumerate(items): for request_type, item in value.items(): @@ -293,4 +299,91 @@ def encrypt_batch_write_item(encrypt_method, crypto_config_method, write_method, item=item["Item"], crypto_config=crypto_config.with_item(_item_transformer(encrypt_method)(item["Item"])), ) - return write_method(**kwargs) + + response = write_method(**kwargs) + return _process_batch_write_response(plaintext_items, response, table_crypto_configs) + + +def _process_batch_write_response(request, response, table_crypto_config): + # type: (Dict, Dict, Dict[Text, CryptoConfig]) -> Dict + """Handle unprocessed items in the response from a transparently encrypted write. 
+ + :param dict request: The DynamoDB plaintext request dictionary + :param dict response: The DynamoDB response from the batch operation + :param Dict[Text, CryptoConfig] table_crypto_config: table level CryptoConfig used in encrypting the request items + :return: DynamoDB response, with any unprocessed items reverted back to the original plaintext values + :rtype: dict + """ + try: + unprocessed_items = response["UnprocessedItems"] + except KeyError: + return response + + # Unprocessed items need to be returned in their original state + for table_name, unprocessed in unprocessed_items.items(): + original_items = request[table_name] + crypto_config = table_crypto_config[table_name] + + if crypto_config.encryption_context.partition_key_name: + items_match = partial(_item_keys_match, crypto_config) + else: + items_match = partial(_item_attributes_match, crypto_config) + + for pos, operation in enumerate(unprocessed): + for request_type, item in operation.items(): + if request_type != "PutRequest": + continue + + for plaintext_item in original_items: + if plaintext_item.get(request_type) and items_match( + plaintext_item[request_type]["Item"], item["Item"] + ): + unprocessed[pos] = plaintext_item.copy() + break + + return response + + +def _item_keys_match(crypto_config, item1, item2): + # type: (CryptoConfig, Dict, Dict) -> bool + """Determines whether the values in the partition and sort keys (if they exist) are the same + + :param CryptoConfig crypto_config: CryptoConfig used in encrypting the given items + :param dict item1: The first item to compare + :param dict item2: The second item to compare + :return: Bool response, True if the key attributes match + :rtype: bool + """ + partition_key_name = crypto_config.encryption_context.partition_key_name + sort_key_name = crypto_config.encryption_context.sort_key_name + + partition_keys_match = item1[partition_key_name] == item2[partition_key_name] + + if sort_key_name is None: + return partition_keys_match + + return 
partition_keys_match and item1[sort_key_name] == item2[sort_key_name] + + +def _item_attributes_match(crypto_config, plaintext_item, encrypted_item): + # type: (CryptoConfig, Dict, Dict) -> bool + """Determines whether the unencrypted values in the plaintext item's attributes are the same as those in the + encrypted item. Essentially this uses brute force to cover when we don't know the partition and sort + index attribute names, since they can't be encrypted. + + :param CryptoConfig crypto_config: CryptoConfig used in encrypting the given items + :param dict plaintext_item: The plaintext item + :param dict encrypted_item: The encrypted item + :return: Bool response, True if the unencrypted attributes in the plaintext item match those in + the encrypted item + :rtype: bool + """ + + for name, value in plaintext_item.items(): + if crypto_config.attribute_actions.action(name) == CryptoAction.ENCRYPT_AND_SIGN: + continue + + if encrypted_item.get(name) != value: + return False + + return True diff --git a/test/functional/encrypted/test_client.py b/test/functional/encrypted/test_client.py index c928a55e..4498f954 100644 --- a/test/functional/encrypted/test_client.py +++ b/test/functional/encrypted/test_client.py @@ -17,6 +17,8 @@ from ..functional_test_utils import example_table # noqa pylint: disable=unused-import from ..functional_test_utils import ( TEST_TABLE_NAME, + build_static_jce_cmp, + client_batch_items_unprocessed_check, client_cycle_batch_items_check, client_cycle_batch_items_check_paginators, client_cycle_single_item_check, @@ -53,6 +55,12 @@ def _client_cycle_batch_items_check_paginators(materials_provider, initial_actio ) + + +def _client_batch_items_unprocessed_check(materials_provider, initial_actions, initial_item): + client_batch_items_unprocessed_check( + materials_provider, initial_actions, initial_item, TEST_TABLE_NAME, "us-west-2" + ) + + def test_ephemeral_item_cycle(example_table, some_cmps, parametrized_actions, parametrized_item): """Test a small 
number of curated CMPs against a small number of curated items.""" _client_cycle_single_item_check(some_cmps, parametrized_actions, parametrized_item) @@ -68,6 +76,12 @@ def test_ephemeral_batch_item_cycle_paginators(example_table, some_cmps, paramet _client_cycle_batch_items_check_paginators(some_cmps, parametrized_actions, parametrized_item) +def test_batch_item_unprocessed(example_table, parametrized_actions, parametrized_item): + """Test Unprocessed Items handling with a single ephemeral static CMP against a small number of curated items.""" + cmp = build_static_jce_cmp("AES", 256, "HmacSHA256", 256) + _client_batch_items_unprocessed_check(cmp, parametrized_actions, parametrized_item) + + @pytest.mark.slow def test_ephemeral_item_cycle_slow(example_table, all_the_cmps, parametrized_actions, parametrized_item): """Test ALL THE CMPS against a small number of curated items.""" diff --git a/test/functional/encrypted/test_resource.py b/test/functional/encrypted/test_resource.py index 3400906f..3c6a01e2 100644 --- a/test/functional/encrypted/test_resource.py +++ b/test/functional/encrypted/test_resource.py @@ -16,6 +16,8 @@ from ..functional_test_utils import example_table # noqa pylint: disable=unused-import from ..functional_test_utils import ( TEST_TABLE_NAME, + build_static_jce_cmp, + resource_batch_items_unprocessed_check, resource_cycle_batch_items_check, set_parametrized_actions, set_parametrized_cmp, @@ -35,11 +37,24 @@ def _resource_cycle_batch_items_check(materials_provider, initial_actions, initi resource_cycle_batch_items_check(materials_provider, initial_actions, initial_item, TEST_TABLE_NAME, "us-west-2") +def _resource_batch_items_unprocessed_check(materials_provider, initial_actions, initial_item): + resource_batch_items_unprocessed_check( + materials_provider, initial_actions, initial_item, TEST_TABLE_NAME, "us-west-2" + ) + + def test_ephemeral_batch_item_cycle(example_table, some_cmps, parametrized_actions, parametrized_item): """Test a small 
number of curated CMPs against a small number of curated items.""" _resource_cycle_batch_items_check(some_cmps, parametrized_actions, parametrized_item) +def test_batch_item_unprocessed(example_table, parametrized_actions, parametrized_item): + """Test Unprocessed Items handling with a single ephemeral static CMP against a small number of curated items.""" + _resource_batch_items_unprocessed_check( + build_static_jce_cmp("AES", 256, "HmacSHA256", 256), parametrized_actions, parametrized_item + ) + + @pytest.mark.travis_isolation @pytest.mark.slow def test_ephemeral_batch_item_cycle_slow(example_table, all_the_cmps, parametrized_actions, parametrized_item): diff --git a/test/functional/encrypted/test_table.py b/test/functional/encrypted/test_table.py index 4759edfe..fe9339ad 100644 --- a/test/functional/encrypted/test_table.py +++ b/test/functional/encrypted/test_table.py @@ -17,9 +17,11 @@ from ..functional_test_utils import example_table # noqa pylint: disable=unused-import from ..functional_test_utils import ( TEST_TABLE_NAME, + build_static_jce_cmp, set_parametrized_actions, set_parametrized_cmp, set_parametrized_item, + table_batch_writer_unprocessed_items_check, table_cycle_batch_writer_check, table_cycle_check, ) @@ -48,6 +50,14 @@ def test_ephemeral_item_cycle_batch_writer(example_table, some_cmps, parametrize table_cycle_batch_writer_check(some_cmps, parametrized_actions, parametrized_item, TEST_TABLE_NAME, "us-west-2") +def test_batch_writer_unprocessed(example_table, parametrized_actions, parametrized_item): + """Test Unprocessed Items handling with a single ephemeral static CMP against a small number of curated items.""" + cmp = build_static_jce_cmp("AES", 256, "HmacSHA256", 256) + table_batch_writer_unprocessed_items_check( + cmp, parametrized_actions, parametrized_item, TEST_TABLE_NAME, "us-west-2" + ) + + @pytest.mark.slow def test_ephemeral_item_cycle_slow(example_table, all_the_cmps, parametrized_actions, parametrized_item): """Test ALL THE CMPS 
against a small number of curated items.""" diff --git a/test/functional/functional_test_utils.py b/test/functional/functional_test_utils.py index 9353408a..86ebf196 100644 --- a/test/functional/functional_test_utils.py +++ b/test/functional/functional_test_utils.py @@ -24,6 +24,7 @@ import pytest from boto3.dynamodb.types import Binary from botocore.exceptions import NoRegionError +from mock import patch from moto import mock_dynamodb2 from dynamodb_encryption_sdk.delegated_keys.jce import JceNameLocalDelegatedKey @@ -336,6 +337,10 @@ def diverse_item(): _reserved_attributes = set([attr.value for attr in ReservedAttributes]) +def return_requestitems_as_unprocessed(*args, **kwargs): + return {"UnprocessedItems": kwargs["RequestItems"]} + + def check_encrypted_item(plaintext_item, ciphertext_item, attribute_actions): # Verify that all expected attributes are present ciphertext_attributes = set(ciphertext_item.keys()) @@ -374,12 +379,20 @@ def _nop_transformer(item): return item +def assert_items_exist_in_list(source, expected, transformer): + for actual_item in source: + expected_item = _matching_key(actual_item, expected) + assert transformer(actual_item) == transformer(expected_item) + + def assert_equal_lists_of_items(actual, expected, transformer=_nop_transformer): assert len(actual) == len(expected) + assert_items_exist_in_list(actual, expected, transformer) - for actual_item in actual: - expected_item = _matching_key(actual_item, expected) - assert transformer(actual_item) == transformer(expected_item) + +def assert_list_of_items_contains(full, subset, transformer=_nop_transformer): + assert len(full) >= len(subset) + assert_items_exist_in_list(subset, full, transformer) def check_many_encrypted_items(actual, expected, attribute_actions, transformer=_nop_transformer): @@ -479,6 +492,25 @@ def cycle_batch_writer_check(raw_table, encrypted_table, initial_actions, initia del items +def batch_write_item_unprocessed_check( + encrypted, initial_item, 
write_transformer=_nop_transformer, table_name=TEST_TABLE_NAME +): + """Check that unprocessed items in a batch result are unencrypted.""" + items = _generate_items(initial_item, write_transformer) + + request_items = {table_name: [{"PutRequest": {"Item": _item}} for _item in items]} + _put_result = encrypted.batch_write_item(RequestItems=request_items) + + # we expect results to include Unprocessed items, or the test case is invalid! + unprocessed_items = _put_result["UnprocessedItems"] + assert unprocessed_items != {} + + unprocessed = [operation["PutRequest"]["Item"] for operation in unprocessed_items[table_name]] + assert_list_of_items_contains(items, unprocessed, transformer=_nop_transformer) + + del items + + def cycle_item_check(plaintext_item, crypto_config): """Check that cycling (plaintext->encrypted->decrypted) an item has the expected results.""" ciphertext_item = encrypt_python_item(plaintext_item, crypto_config) @@ -527,6 +559,30 @@ def table_cycle_batch_writer_check(materials_provider, initial_actions, initial_ cycle_batch_writer_check(table, e_table, initial_actions, initial_item) +def table_batch_writer_unprocessed_items_check( + materials_provider, initial_actions, initial_item, table_name, region_name=None +): + kwargs = {} + if region_name is not None: + kwargs["region_name"] = region_name + resource = boto3.resource("dynamodb", **kwargs) + table = resource.Table(table_name) + + items = _generate_items(initial_item, _nop_transformer) + request_items = {table_name: [{"PutRequest": {"Item": _item}} for _item in items]} + + with patch.object(table.meta.client, "batch_write_item") as batch_write_mock: + # Check that unprocessed items returned to a BatchWriter are successfully retried + batch_write_mock.side_effect = [{"UnprocessedItems": request_items}, {"UnprocessedItems": {}}] + e_table = EncryptedTable(table=table, materials_provider=materials_provider, attribute_actions=initial_actions) + + with e_table.batch_writer() as writer: + for item 
in items: + writer.put_item(item) + + del items + + def resource_cycle_batch_items_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): kwargs = {} if region_name is not None: @@ -550,6 +606,24 @@ def resource_cycle_batch_items_check(materials_provider, initial_actions, initia assert not e_scan_result["Items"] +def resource_batch_items_unprocessed_check( + materials_provider, initial_actions, initial_item, table_name, region_name=None +): + kwargs = {} + if region_name is not None: + kwargs["region_name"] = region_name + resource = boto3.resource("dynamodb", **kwargs) + + with patch.object(resource, "batch_write_item", return_requestitems_as_unprocessed): + e_resource = EncryptedResource( + resource=resource, materials_provider=materials_provider, attribute_actions=initial_actions + ) + + batch_write_item_unprocessed_check( + encrypted=e_resource, initial_item=initial_item, write_transformer=dict_to_ddb, table_name=table_name + ) + + def client_cycle_single_item_check(materials_provider, initial_actions, initial_item, table_name, region_name=None): check_attribute_actions = initial_actions.copy() check_attribute_actions.set_index_keys(*list(TEST_KEY.keys())) @@ -600,6 +674,24 @@ def client_cycle_batch_items_check(materials_provider, initial_actions, initial_ assert not e_scan_result["Items"] +def client_batch_items_unprocessed_check( + materials_provider, initial_actions, initial_item, table_name, region_name=None +): + kwargs = {} + if region_name is not None: + kwargs["region_name"] = region_name + client = boto3.client("dynamodb", **kwargs) + + with patch.object(client, "batch_write_item", return_requestitems_as_unprocessed): + e_client = EncryptedClient( + client=client, materials_provider=materials_provider, attribute_actions=initial_actions + ) + + batch_write_item_unprocessed_check( + encrypted=e_client, initial_item=initial_item, write_transformer=dict_to_ddb, table_name=table_name + ) + + def 
client_cycle_batch_items_check_paginators( materials_provider, initial_actions, initial_item, table_name, region_name=None ): diff --git a/test/functional/internal/test_utils.py b/test/functional/internal/test_utils.py new file mode 100644 index 00000000..1e52c0fd --- /dev/null +++ b/test/functional/internal/test_utils.py @@ -0,0 +1,152 @@ +# -*- coding: utf-8 -*- +# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""Test suite for ``dynamodb_encryption_sdk.internal.utils``.""" +import copy + +import pytest +from mock import Mock + +from dynamodb_encryption_sdk.encrypted import CryptoConfig +from dynamodb_encryption_sdk.identifiers import CryptoAction +from dynamodb_encryption_sdk.internal.utils import encrypt_batch_write_item +from dynamodb_encryption_sdk.material_providers import CryptographicMaterialsProvider +from dynamodb_encryption_sdk.structures import AttributeActions, EncryptionContext +from dynamodb_encryption_sdk.transform import dict_to_ddb + +from ..functional_test_utils import diverse_item + + +def get_test_item(standard_dict_format, partition_key, sort_key=None): + attributes = diverse_item() + + attributes["partition-key"] = partition_key + if sort_key is not None: + attributes["sort-key"] = sort_key + + if not standard_dict_format: + attributes = dict_to_ddb(attributes) + return attributes + + +def get_test_items(standard_dict_format, table_name="table", with_sort_keys=False): + + if with_sort_keys: + items = [ + 
get_test_item(standard_dict_format, partition_key="key-1", sort_key="sort-1"), + get_test_item(standard_dict_format, partition_key="key-2", sort_key="sort-1"), + get_test_item(standard_dict_format, partition_key="key-2", sort_key="sort-2"), + ] + else: + items = [ + get_test_item(standard_dict_format, partition_key="key-1"), + get_test_item(standard_dict_format, partition_key="key-2"), + ] + + for pos, item in enumerate(items): + item["encrypt-me"] = table_name + str(pos) + + return {table_name: [{"PutRequest": {"Item": item}} for item in items]} + + +def get_dummy_crypto_config(partition_key_name=None, sort_key_name=None, sign_keys=False): + context = EncryptionContext(partition_key_name=partition_key_name, sort_key_name=sort_key_name) + actions = AttributeActions( + default_action=CryptoAction.DO_NOTHING, attribute_actions={"encrypt-me": CryptoAction.ENCRYPT_AND_SIGN} + ) + if sign_keys: + actions.attribute_actions["partition-key"] = CryptoAction.SIGN_ONLY + actions.attribute_actions["sort-key"] = CryptoAction.SIGN_ONLY + + materials = Mock(spec=CryptographicMaterialsProvider) # type: CryptographicMaterialsProvider + return CryptoConfig(materials_provider=materials, encryption_context=context, attribute_actions=actions) + + +def check_encrypt_batch_write_item_call(request_items, crypto_config): + def dummy_encrypt(item, **kwargs): + result = item.copy() + result["encrypt-me"] = "pretend Im encrypted" + return result + + # execute a batch write, but make the write method return ALL the provided items as unprocessed + result = encrypt_batch_write_item( + encrypt_method=dummy_encrypt, + write_method=lambda **kwargs: {"UnprocessedItems": kwargs["RequestItems"]}, + crypto_config_method=lambda **kwargs: crypto_config, + RequestItems=copy.deepcopy(request_items), + ) + + # assert the returned items equal the submitted items + unprocessed = result["UnprocessedItems"] + + assert unprocessed == request_items + + +@pytest.mark.parametrize( + "items", 
(get_test_items(standard_dict_format=True), get_test_items(standard_dict_format=False)) +) +def test_encrypt_batch_write_returns_plaintext_unprocessed_items_with_known_partition_key(items): + crypto_config = get_dummy_crypto_config("partition-key") + check_encrypt_batch_write_item_call(items, crypto_config) + + +@pytest.mark.parametrize( + "items", + ( + get_test_items(standard_dict_format=True, with_sort_keys=True), + get_test_items(standard_dict_format=False, with_sort_keys=True), + ), +) +def test_encrypt_batch_write_returns_plaintext_unprocessed_items_with_known_partition_and_sort_keys(items): + crypto_config = get_dummy_crypto_config("partition-key", "sort-key") + check_encrypt_batch_write_item_call(items, crypto_config) + + +@pytest.mark.parametrize( + "items", + ( + get_test_items(standard_dict_format=True), + get_test_items(standard_dict_format=False), + get_test_items(standard_dict_format=True, with_sort_keys=True), + get_test_items(standard_dict_format=False, with_sort_keys=True), + ), +) +def test_encrypt_batch_write_returns_plaintext_unprocessed_items_with_unknown_keys(items): + crypto_config = get_dummy_crypto_config(None, None) + + check_encrypt_batch_write_item_call(items, crypto_config) + + +@pytest.mark.parametrize( + "items", + ( + get_test_items(standard_dict_format=True), + get_test_items(standard_dict_format=False), + get_test_items(standard_dict_format=True, with_sort_keys=True), + get_test_items(standard_dict_format=False, with_sort_keys=True), + ), +) +def test_encrypt_batch_write_returns_plaintext_unprocessed_items_with_unknown_signed_keys(items): + crypto_config = get_dummy_crypto_config(None, None, sign_keys=True) + + check_encrypt_batch_write_item_call(items, crypto_config) + + +def test_encrypt_batch_write_returns_plaintext_unprocessed_items_over_multiple_tables(): + crypto_config = get_dummy_crypto_config("partition-key", "sort-key") + + items = get_test_items(standard_dict_format=True, table_name="table-one", with_sort_keys=True) + 
more_items = get_test_items(standard_dict_format=False, table_name="table-two", with_sort_keys=True) + items.update(more_items) + + check_encrypt_batch_write_item_call(items, crypto_config)