|
| 1 | +# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"). You |
| 4 | +# may not use this file except in compliance with the License. A copy of |
| 5 | +# the License is located at |
| 6 | +# |
| 7 | +# http://aws.amazon.com/apache2.0/ |
| 8 | +# |
| 9 | +# or in the "license" file accompanying this file. This file is |
| 10 | +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
| 11 | +# ANY KIND, either express or implied. See the License for the specific |
| 12 | +# language governing permissions and limitations under the License. |
| 13 | +"""Basic sanity check for ``aws_encryption_sdk`` client behavior when threading.""" |
| 14 | +from __future__ import division |
| 15 | + |
| 16 | +import copy |
| 17 | +from random import SystemRandom |
| 18 | +import threading |
| 19 | +import time |
| 20 | + |
| 21 | +import pytest |
| 22 | +from six.moves import queue # six.moves confuses pylint: disable=import-error |
| 23 | + |
| 24 | +import aws_encryption_sdk |
| 25 | +from .integration_test_utils import setup_kms_master_key_provider, SKIP_MESSAGE, skip_tests |
| 26 | + |
| 27 | + |
| 28 | +PLAINTEXT = ( |
| 29 | + b'\xa3\xf6\xbc\x89\x95\x15(\xc8}\\\x8d=zu^{JA\xc1\xe9\xf0&m\xe6TD\x03' |
| 30 | + b'\x165F\x85\xae\x96\xd9~ \xa6\x13\x88\xf8\xdb\xc9\x0c\xd8\xd8\xd4\xe0' |
| 31 | + b'\x02\xe9\xdb+\xd4l\xeaq\xf6\xba.cg\xda\xe4V\xd9\x9a\x96\xe8\xf4:\xf5' |
| 32 | + b'\xfd\xd7\xa6\xfa\xd1\x85\xa7o\xf5\x94\xbcE\x14L\xa1\x87\xd9T\xa6\x95' |
| 33 | + b'eZVv\xfe[\xeeJ$a<9\x1f\x97\xe1\xd6\x9dQc\x8b7n\x0f\x1e\xbd\xf5\xba' |
| 34 | + b'\x0e\xae|%\xd8L]\xa2\xa2\x08\x1f' |
| 35 | +) |
| 36 | + |
| 37 | + |
| 38 | +def crypto_thread_worker(crypto_function, start_pause, input_value, output_queue, cache=None): |
| 39 | + """Pauses for ``start_pause`` seconds, then calls ``crypto_function`` with ``input_value`` as source, |
| 40 | + sending output to ``output_queue``. |
| 41 | +
|
| 42 | + :param callable crypto_function: AWS Encryption SDK crypto function to call in each thread |
| 43 | + :param float start_pause: Seconds to pause before running thread (introduces some variability |
| 44 | + to ensure multiple threads run simultaneously) |
| 45 | + :param input_value: Value to pass to ``crypto_function`` as source |
| 46 | + :param output_queue: Queue into which to put output of ``crypto_function`` (ciphertext or decrypted plaintext) |
| 47 | + :param cache: Cache to use with master key provider (optional) |
| 48 | + """ |
| 49 | + time.sleep(start_pause) |
| 50 | + kms_master_key_provider = setup_kms_master_key_provider() |
| 51 | + if cache is None: |
| 52 | + # For simplicity, always use a caching CMM; just use a null cache if no cache is specified. |
| 53 | + cache = aws_encryption_sdk.NullCryptoMaterialsCache() |
| 54 | + materials_manager = aws_encryption_sdk.CachingCryptoMaterialsManager( |
| 55 | + master_key_provider=kms_master_key_provider, |
| 56 | + cache=cache, |
| 57 | + max_age=60.0 |
| 58 | + ) |
| 59 | + output_value, _header = crypto_function( |
| 60 | + source=input_value, |
| 61 | + materials_manager=materials_manager |
| 62 | + ) |
| 63 | + output_queue.put(output_value) |
| 64 | + |
| 65 | + |
| 66 | +def get_all_thread_outputs(crypto_function, thread_inputs): |
| 67 | + """Spawn a thread with ``crypto_function`` for each of ``thread_inputs``, |
| 68 | + collecting and returning all outputs. |
| 69 | +
|
| 70 | + :param callable crypto_function: AWS Encryption SDK crypto function to call in each thread |
| 71 | + :param list thread_inputs: List of inputs and pause times to feed to ``crypto_function`` as source. |
| 72 | + :retuns: Outputs (ciphertext or decrypted plaintext) from all threads in no particular order |
| 73 | + :rtype: list |
| 74 | + """ |
| 75 | + active_threads = [] |
| 76 | + output_queue = queue.Queue() |
| 77 | + for values in thread_inputs: |
| 78 | + _thread = threading.Thread( |
| 79 | + target=crypto_thread_worker, |
| 80 | + kwargs=dict( |
| 81 | + crypto_function=crypto_function, |
| 82 | + output_queue=output_queue, |
| 83 | + **values |
| 84 | + ) |
| 85 | + ) |
| 86 | + _thread.start() |
| 87 | + active_threads.append(_thread) |
| 88 | + output_values = [] |
| 89 | + for _thread in active_threads: |
| 90 | + _thread.join() |
| 91 | + output_values.append(output_queue.get()) |
| 92 | + return output_values |
| 93 | + |
| 94 | + |
| 95 | +def random_pause_time(max_seconds=3): |
| 96 | + """Generates a random pause time between 0.0 and 10.0, limited by max_seconds. |
| 97 | +
|
| 98 | + :param int max_seconds: Maximum pause time (default: 3) |
| 99 | + :rtype: float |
| 100 | + """ |
| 101 | + return SystemRandom().random() * 10 % max_seconds |
| 102 | + |
| 103 | + |
| 104 | +@pytest.mark.skipif(skip_tests(), reason=SKIP_MESSAGE) |
| 105 | +def test_threading_loop(): |
| 106 | + """Test thread safety of client.""" |
| 107 | + rounds = 20 |
| 108 | + plaintext_inputs = [ |
| 109 | + dict(input_value=copy.copy(PLAINTEXT), start_pause=random_pause_time()) |
| 110 | + for _round in range(rounds) |
| 111 | + ] |
| 112 | + |
| 113 | + ciphertext_values = get_all_thread_outputs( |
| 114 | + crypto_function=aws_encryption_sdk.encrypt, |
| 115 | + thread_inputs=plaintext_inputs |
| 116 | + ) |
| 117 | + ciphertext_inputs = [ |
| 118 | + dict(input_value=ciphertext, start_pause=random_pause_time()) |
| 119 | + for ciphertext in ciphertext_values |
| 120 | + ] |
| 121 | + |
| 122 | + decrypted_values = get_all_thread_outputs( |
| 123 | + crypto_function=aws_encryption_sdk.decrypt, |
| 124 | + thread_inputs=ciphertext_inputs |
| 125 | + ) |
| 126 | + |
| 127 | + assert all(value == PLAINTEXT for value in decrypted_values) |
| 128 | + |
| 129 | + |
| 130 | +@pytest.mark.skipif(skip_tests(), reason=SKIP_MESSAGE) |
| 131 | +def test_threading_loop_with_common_cache(): |
| 132 | + """Test thread safety of client while using common cryptographic materials cache across all threads.""" |
| 133 | + rounds = 20 |
| 134 | + cache = aws_encryption_sdk.LocalCryptoMaterialsCache(capacity=40) |
| 135 | + plaintext_inputs = [ |
| 136 | + dict(input_value=copy.copy(PLAINTEXT), start_pause=random_pause_time(), cache=cache) |
| 137 | + for _round in range(rounds) |
| 138 | + ] |
| 139 | + |
| 140 | + ciphertext_values = get_all_thread_outputs( |
| 141 | + crypto_function=aws_encryption_sdk.encrypt, |
| 142 | + thread_inputs=plaintext_inputs |
| 143 | + ) |
| 144 | + ciphertext_inputs = [ |
| 145 | + dict(input_value=ciphertext, start_pause=random_pause_time(), cache=cache) |
| 146 | + for ciphertext in ciphertext_values |
| 147 | + ] |
| 148 | + |
| 149 | + decrypted_values = get_all_thread_outputs( |
| 150 | + crypto_function=aws_encryption_sdk.decrypt, |
| 151 | + thread_inputs=ciphertext_inputs |
| 152 | + ) |
| 153 | + |
| 154 | + assert all(value == PLAINTEXT for value in decrypted_values) |
0 commit comments