Skip to content

Commit c767187

Browse files
authored
Merge pull request #15 from mattsb42-aws/threadsafe-tests
Threadsafety tests
2 parents 0a5736c + a55e01f commit c767187

File tree

2 files changed

+161
-9
lines changed

2 files changed

+161
-9
lines changed

test/integration/test_i_aws_encrytion_sdk_client.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,19 @@
1414
import io
1515
import unittest
1616

17-
import six
18-
1917
import aws_encryption_sdk
2018
from aws_encryption_sdk.identifiers import Algorithm
2119
from .integration_test_utils import setup_kms_master_key_provider, SKIP_MESSAGE, skip_tests
2220

2321

2422
VALUES = {
25-
'plaintext_128': six.b(
26-
'\xa3\xf6\xbc\x89\x95\x15(\xc8}\\\x8d=zu^{JA\xc1\xe9\xf0&m\xe6TD\x03'
27-
'\x165F\x85\xae\x96\xd9~ \xa6\x13\x88\xf8\xdb\xc9\x0c\xd8\xd8\xd4\xe0'
28-
'\x02\xe9\xdb+\xd4l\xeaq\xf6\xba.cg\xda\xe4V\xd9\x9a\x96\xe8\xf4:\xf5'
29-
'\xfd\xd7\xa6\xfa\xd1\x85\xa7o\xf5\x94\xbcE\x14L\xa1\x87\xd9T\xa6\x95'
30-
'eZVv\xfe[\xeeJ$a<9\x1f\x97\xe1\xd6\x9dQc\x8b7n\x0f\x1e\xbd\xf5\xba'
31-
'\x0e\xae|%\xd8L]\xa2\xa2\x08\x1f'
23+
'plaintext_128': (
24+
b'\xa3\xf6\xbc\x89\x95\x15(\xc8}\\\x8d=zu^{JA\xc1\xe9\xf0&m\xe6TD\x03'
25+
b'\x165F\x85\xae\x96\xd9~ \xa6\x13\x88\xf8\xdb\xc9\x0c\xd8\xd8\xd4\xe0'
26+
b'\x02\xe9\xdb+\xd4l\xeaq\xf6\xba.cg\xda\xe4V\xd9\x9a\x96\xe8\xf4:\xf5'
27+
b'\xfd\xd7\xa6\xfa\xd1\x85\xa7o\xf5\x94\xbcE\x14L\xa1\x87\xd9T\xa6\x95'
28+
b'eZVv\xfe[\xeeJ$a<9\x1f\x97\xe1\xd6\x9dQc\x8b7n\x0f\x1e\xbd\xf5\xba'
29+
b'\x0e\xae|%\xd8L]\xa2\xa2\x08\x1f'
3230
),
3331
'encryption_context': {
3432
'key_a': 'value_a',
+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You
4+
# may not use this file except in compliance with the License. A copy of
5+
# the License is located at
6+
#
7+
# http://aws.amazon.com/apache2.0/
8+
#
9+
# or in the "license" file accompanying this file. This file is
10+
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
11+
# ANY KIND, either express or implied. See the License for the specific
12+
# language governing permissions and limitations under the License.
13+
"""Basic sanity check for ``aws_encryption_sdk`` client behavior when threading."""
14+
from __future__ import division
15+
16+
import copy
17+
from random import SystemRandom
18+
import threading
19+
import time
20+
21+
import pytest
22+
from six.moves import queue # six.moves confuses pylint: disable=import-error
23+
24+
import aws_encryption_sdk
25+
from .integration_test_utils import setup_kms_master_key_provider, SKIP_MESSAGE, skip_tests
26+
27+
28+
PLAINTEXT = (
29+
b'\xa3\xf6\xbc\x89\x95\x15(\xc8}\\\x8d=zu^{JA\xc1\xe9\xf0&m\xe6TD\x03'
30+
b'\x165F\x85\xae\x96\xd9~ \xa6\x13\x88\xf8\xdb\xc9\x0c\xd8\xd8\xd4\xe0'
31+
b'\x02\xe9\xdb+\xd4l\xeaq\xf6\xba.cg\xda\xe4V\xd9\x9a\x96\xe8\xf4:\xf5'
32+
b'\xfd\xd7\xa6\xfa\xd1\x85\xa7o\xf5\x94\xbcE\x14L\xa1\x87\xd9T\xa6\x95'
33+
b'eZVv\xfe[\xeeJ$a<9\x1f\x97\xe1\xd6\x9dQc\x8b7n\x0f\x1e\xbd\xf5\xba'
34+
b'\x0e\xae|%\xd8L]\xa2\xa2\x08\x1f'
35+
)
36+
37+
38+
def crypto_thread_worker(crypto_function, start_pause, input_value, output_queue, cache=None):
39+
"""Pauses for ``start_pause`` seconds, then calls ``crypto_function`` with ``input_value`` as source,
40+
sending output to ``output_queue``.
41+
42+
:param callable crypto_function: AWS Encryption SDK crypto function to call in each thread
43+
:param float start_pause: Seconds to pause before running thread (introduces some variability
44+
to ensure multiple threads run simultaneously)
45+
:param input_value: Value to pass to ``crypto_function`` as source
46+
:param output_queue: Queue into which to put output of ``crypto_function`` (ciphertext or decrypted plaintext)
47+
:param cache: Cache to use with master key provider (optional)
48+
"""
49+
time.sleep(start_pause)
50+
kms_master_key_provider = setup_kms_master_key_provider()
51+
if cache is None:
52+
# For simplicity, always use a caching CMM; just use a null cache if no cache is specified.
53+
cache = aws_encryption_sdk.NullCryptoMaterialsCache()
54+
materials_manager = aws_encryption_sdk.CachingCryptoMaterialsManager(
55+
master_key_provider=kms_master_key_provider,
56+
cache=cache,
57+
max_age=60.0
58+
)
59+
output_value, _header = crypto_function(
60+
source=input_value,
61+
materials_manager=materials_manager
62+
)
63+
output_queue.put(output_value)
64+
65+
66+
def get_all_thread_outputs(crypto_function, thread_inputs):
67+
"""Spawn a thread with ``crypto_function`` for each of ``thread_inputs``,
68+
collecting and returning all outputs.
69+
70+
:param callable crypto_function: AWS Encryption SDK crypto function to call in each thread
71+
:param list thread_inputs: List of inputs and pause times to feed to ``crypto_function`` as source.
72+
:retuns: Outputs (ciphertext or decrypted plaintext) from all threads in no particular order
73+
:rtype: list
74+
"""
75+
active_threads = []
76+
output_queue = queue.Queue()
77+
for values in thread_inputs:
78+
_thread = threading.Thread(
79+
target=crypto_thread_worker,
80+
kwargs=dict(
81+
crypto_function=crypto_function,
82+
output_queue=output_queue,
83+
**values
84+
)
85+
)
86+
_thread.start()
87+
active_threads.append(_thread)
88+
output_values = []
89+
for _thread in active_threads:
90+
_thread.join()
91+
output_values.append(output_queue.get())
92+
return output_values
93+
94+
95+
def random_pause_time(max_seconds=3):
96+
"""Generates a random pause time between 0.0 and 10.0, limited by max_seconds.
97+
98+
:param int max_seconds: Maximum pause time (default: 3)
99+
:rtype: float
100+
"""
101+
return SystemRandom().random() * 10 % max_seconds
102+
103+
104+
@pytest.mark.skipif(skip_tests(), reason=SKIP_MESSAGE)
105+
def test_threading_loop():
106+
"""Test thread safety of client."""
107+
rounds = 20
108+
plaintext_inputs = [
109+
dict(input_value=copy.copy(PLAINTEXT), start_pause=random_pause_time())
110+
for _round in range(rounds)
111+
]
112+
113+
ciphertext_values = get_all_thread_outputs(
114+
crypto_function=aws_encryption_sdk.encrypt,
115+
thread_inputs=plaintext_inputs
116+
)
117+
ciphertext_inputs = [
118+
dict(input_value=ciphertext, start_pause=random_pause_time())
119+
for ciphertext in ciphertext_values
120+
]
121+
122+
decrypted_values = get_all_thread_outputs(
123+
crypto_function=aws_encryption_sdk.decrypt,
124+
thread_inputs=ciphertext_inputs
125+
)
126+
127+
assert all(value == PLAINTEXT for value in decrypted_values)
128+
129+
130+
@pytest.mark.skipif(skip_tests(), reason=SKIP_MESSAGE)
131+
def test_threading_loop_with_common_cache():
132+
"""Test thread safety of client while using common cryptographic materials cache across all threads."""
133+
rounds = 20
134+
cache = aws_encryption_sdk.LocalCryptoMaterialsCache(capacity=40)
135+
plaintext_inputs = [
136+
dict(input_value=copy.copy(PLAINTEXT), start_pause=random_pause_time(), cache=cache)
137+
for _round in range(rounds)
138+
]
139+
140+
ciphertext_values = get_all_thread_outputs(
141+
crypto_function=aws_encryption_sdk.encrypt,
142+
thread_inputs=plaintext_inputs
143+
)
144+
ciphertext_inputs = [
145+
dict(input_value=ciphertext, start_pause=random_pause_time(), cache=cache)
146+
for ciphertext in ciphertext_values
147+
]
148+
149+
decrypted_values = get_all_thread_outputs(
150+
crypto_function=aws_encryption_sdk.decrypt,
151+
thread_inputs=ciphertext_inputs
152+
)
153+
154+
assert all(value == PLAINTEXT for value in decrypted_values)

0 commit comments

Comments
 (0)