From f7f25a83032c8e2f189982d26a5c7cc8d8174b18 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 13 Nov 2018 13:42:03 -0800 Subject: [PATCH 1/6] add tests for encrypting, decrypting, and cycling data using streams that only support read() --- .../test_f_aws_encryption_sdk_client.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 08cec1c5e..8971b8770 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -745,3 +745,42 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size): _look_in_logs(caplog, plaintext) _error_check(capsys) + + +class NothingButRead(object): + def __init__(self, data): + self._data = io.BytesIO(data) + + def read(self, size=-1): + return self._data.read(size) + + +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_cycle_nothing_but_read(frame_length): + raw_plaintext = exact_length_plaintext(100) + plaintext = NothingButRead(raw_plaintext) + key_provider = fake_kms_key_provider() + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + ciphertext = NothingButRead(raw_ciphertext) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert raw_plaintext == decrypted + + +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_encrypt_nothing_but_read(frame_length): + raw_plaintext = exact_length_plaintext(100) + plaintext = NothingButRead(raw_plaintext) + key_provider = fake_kms_key_provider() + ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert raw_plaintext == decrypted + + +@pytest.mark.parametrize("frame_length", (0, 1024)) +def test_decrypt_nothing_but_read(frame_length): + plaintext = exact_length_plaintext(100) + key_provider = fake_kms_key_provider() + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + ciphertext = NothingButRead(raw_ciphertext) + decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) + assert plaintext == decrypted From 4314e37e0804ede9f6755d3fdf5130aaa3ddeaf4 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 13 Nov 2018 19:27:47 -0800 Subject: [PATCH 2/6] remove need for source_stream.tell() on decrypt path --- .../internal/formatting/deserialize.py | 28 ++++-- src/aws_encryption_sdk/streaming_client.py | 66 +++++++++----- .../test_f_aws_encryption_sdk_client.py | 22 +++++ test/unit/test_deserialize.py | 29 ++++++ ...test_streaming_client_encryption_stream.py | 6 +- .../test_streaming_client_stream_decryptor.py | 88 +++++++++++-------- test/unit/test_utils.py | 9 +- test/unit/unit_test_utils.py | 13 +++ 8 files changed, 183 insertions(+), 78 deletions(-) diff --git a/src/aws_encryption_sdk/internal/formatting/deserialize.py b/src/aws_encryption_sdk/internal/formatting/deserialize.py index 024ccca28..86fa4c06d 100644 --- a/src/aws_encryption_sdk/internal/formatting/deserialize.py +++ b/src/aws_encryption_sdk/internal/formatting/deserialize.py @@ -282,7 +282,7 @@ def deserialize_header_auth(stream, algorithm, verifier=None): def deserialize_non_framed_values(stream, header, verifier=None): - """Deserializes the IV and Tag from a non-framed stream. + """Deserializes the IV and body length from a non-framed stream. :param stream: Source data stream :type stream: io.BytesIO @@ -290,18 +290,30 @@ def deserialize_non_framed_values(stream, header, verifier=None): :type header: aws_encryption_sdk.structures.MessageHeader :param verifier: Signature verifier object (optional) :type verifier: aws_encryption_sdk.internal.crypto.Verifier - :returns: IV, Tag, and Data Length values for body - :rtype: tuple of bytes, bytes, and int + :returns: IV and Data Length values for body + :rtype: tuple of bytes and int """ _LOGGER.debug("Starting non-framed body iv/tag deserialization") (data_iv, data_length) = unpack_values(">{}sQ".format(header.algorithm.iv_len), stream, verifier) - body_start = stream.tell() - stream.seek(data_length, 1) + return data_iv, data_length + + +def deserialize_tag(stream, header, verifier=None): + """Deserialize the Tag value from a non-framed stream. + + :param stream: Source data stream + :type stream: io.BytesIO + :param header: Deserialized header + :type header: aws_encryption_sdk.structures.MessageHeader + :param verifier: Signature verifier object (optional) + :type verifier: aws_encryption_sdk.internal.crypto.Verifier + :returns: Tag value for body + :rtype: bytes + """ (data_tag,) = unpack_values( - format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=None + format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=verifier ) - stream.seek(body_start, 0) - return data_iv, data_tag, data_length + return data_tag def update_verifier_with_tag(stream, header, verifier): diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 539bdf86d..1d13e448a 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -696,6 +696,7 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called """Prepares necessary initial values.""" self.last_sequence_number = 0 + self.__unframed_bytes_read = 0 def _prep_message(self): """Performs initial message setup.""" @@ -713,6 +714,7 @@ def _read_header(self): :raises CustomMaximumValueExceeded: if frame length is greater than the custom max value """ header, raw_header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(self.source_stream) + self.__unframed_bytes_read += len(raw_header) if ( self.config.max_body_length is not None @@ -751,9 +753,19 @@ def _read_header(self): ) return header, header_auth + @property + def body_start(self): + _LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0") + return self._body_start + + @property + def body_end(self): + _LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0") + return self._body_end + def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" - iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( + self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( stream=self.source_stream, header=self._header, verifier=self.verifier ) @@ -764,24 +776,10 @@ def _prep_non_framed(self): ) ) - aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( - content_type=self._header.content_type, is_final_frame=True - ) - associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad( - message_id=self._header.message_id, - aad_content_string=aad_content_string, - seq_num=1, - length=self.body_length, - ) - self.decryptor = Decryptor( - algorithm=self._header.algorithm, - key=self._derived_data_key, - associated_data=associated_data, - iv=iv, - tag=tag, - ) - self.body_start = self.source_stream.tell() - self.body_end = self.body_start + self.body_length + self.__unframed_bytes_read += self._header.algorithm.iv_len + self.__unframed_bytes_read += 8 # encrypted content length field + self._body_start = self.__unframed_bytes_read + self._body_end = self._body_start + self.body_length def _read_bytes_from_non_framed_body(self, b): """Reads the requested number of bytes from a streaming non-framed message body. @@ -792,7 +790,8 @@ def _read_bytes_from_non_framed_body(self, b): """ _LOGGER.debug("starting non-framed body read") # Always read the entire message for non-framed message bodies. - bytes_to_read = self.body_end - self.source_stream.tell() + bytes_to_read = self.body_length + _LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read) ciphertext = self.source_stream.read(bytes_to_read) @@ -802,11 +801,32 @@ def _read_bytes_from_non_framed_body(self, b): if self.verifier is not None: self.verifier.update(ciphertext) + tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( + stream=self.source_stream, + header=self._header, + verifier=self.verifier, + ) + + aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( + content_type=self._header.content_type, is_final_frame=True + ) + associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad( + message_id=self._header.message_id, + aad_content_string=aad_content_string, + seq_num=1, + length=self.body_length, + ) + self.decryptor = Decryptor( + algorithm=self._header.algorithm, + key=self._derived_data_key, + associated_data=associated_data, + iv=self._unframed_body_iv, + tag=tag, + ) + plaintext = self.decryptor.update(ciphertext) plaintext += self.decryptor.finalize() - aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag( - stream=self.source_stream, header=self._header, verifier=self.verifier - ) + self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer( stream=self.source_stream, verifier=self.verifier ) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 8971b8770..2d6e33405 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -755,6 +755,7 @@ def read(self, size=-1): return self._data.read(size) +@pytest.mark.xfail @pytest.mark.parametrize("frame_length", (0, 1024)) def test_cycle_nothing_but_read(frame_length): raw_plaintext = exact_length_plaintext(100) @@ -766,6 +767,7 @@ def test_cycle_nothing_but_read(frame_length): assert raw_plaintext == decrypted +@pytest.mark.xfail @pytest.mark.parametrize("frame_length", (0, 1024)) def test_encrypt_nothing_but_read(frame_length): raw_plaintext = exact_length_plaintext(100) @@ -776,6 +778,7 @@ def test_encrypt_nothing_but_read(frame_length): assert raw_plaintext == decrypted +@pytest.mark.xfail @pytest.mark.parametrize("frame_length", (0, 1024)) def test_decrypt_nothing_but_read(frame_length): plaintext = exact_length_plaintext(100) @@ -784,3 +787,22 @@ def test_decrypt_nothing_but_read(frame_length): ciphertext = NothingButRead(raw_ciphertext) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert plaintext == decrypted + + +@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0"))) +def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): + caplog.set_level(logging.WARNING) + plaintext = exact_length_plaintext(100) + key_provider = fake_kms_key_provider() + ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0) + with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor: + decrypted = decryptor.read() + + assert decrypted == plaintext + assert hasattr(decryptor, attribute) + watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format( + name=attribute, + version=no_later_than + ) + assert watch_string in caplog.text + assert aws_encryption_sdk.__version__ < no_later_than diff --git a/test/unit/test_deserialize.py b/test/unit/test_deserialize.py index ac3be1bf3..f99d57fb1 100644 --- a/test/unit/test_deserialize.py +++ b/test/unit/test_deserialize.py @@ -12,6 +12,7 @@ # language governing permissions and limitations under the License. """Unit test suite for aws_encryption_sdk.deserialize""" import io +import struct import unittest import pytest @@ -29,6 +30,34 @@ pytestmark = [pytest.mark.unit, pytest.mark.local] +def test_deserialize_non_framed_values(): + iv = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11' + length = 42 + packed = struct.pack(">12sQ", iv, length) + mock_header = MagicMock(algorithm=MagicMock(iv_len=12)) + + parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( + stream=io.BytesIO(packed), + header=mock_header + ) + + assert parsed_iv == iv + assert parsed_length == length + + +def test_deserialize_tag(): + tag = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15' + packed = struct.pack(">16s", tag) + mock_header = MagicMock(algorithm=MagicMock(auth_len=16)) + + parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( + stream=io.BytesIO(packed), + header=mock_header + ) + + assert parsed_tag == tag + + class TestDeserialize(unittest.TestCase): def setUp(self): self.mock_wrapping_algorithm = MagicMock() diff --git a/test/unit/test_streaming_client_encryption_stream.py b/test/unit/test_streaming_client_encryption_stream.py index 22cefffb0..e3a06347a 100644 --- a/test/unit/test_streaming_client_encryption_stream.py +++ b/test/unit/test_streaming_client_encryption_stream.py @@ -20,11 +20,11 @@ import aws_encryption_sdk.exceptions from aws_encryption_sdk.internal.defaults import LINE_LENGTH -from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.key_providers.base import MasterKeyProvider from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream from .test_values import VALUES +from .unit_test_utils import assert_prepped_stream_identity pytestmark = [pytest.mark.unit, pytest.mark.local] @@ -110,7 +110,7 @@ def test_new_with_params(self): ) assert mock_stream.config.source == self.mock_source_stream - assert isinstance(mock_stream.config.source, InsistentReaderBytesIO) + assert_prepped_stream_identity(mock_stream.config.source, object) assert mock_stream.config.key_provider is self.mock_key_provider assert mock_stream.config.mock_read_bytes is sentinel.read_bytes assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE @@ -120,7 +120,7 @@ def test_new_with_params(self): assert mock_stream.output_buffer == b"" assert not mock_stream._message_prepped assert mock_stream.source_stream == self.mock_source_stream - assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO) + assert_prepped_stream_identity(mock_stream.source_stream, object) assert mock_stream._stream_length is mock_int_sentinel assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index 50e981b16..987b31941 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -37,10 +37,12 @@ def setUp(self): data_key=VALUES["data_key_obj"], verification_key=sentinel.verification_key ) self.mock_header = MagicMock() - self.mock_header.algorithm = MagicMock(__class__=Algorithm) + self.mock_header.algorithm = MagicMock(__class__=Algorithm, iv_len=12) self.mock_header.encrypted_data_keys = sentinel.encrypted_data_keys self.mock_header.encryption_context = sentinel.encryption_context + self.mock_raw_header = b'some bytes' + self.mock_input_stream = MagicMock() self.mock_input_stream.__class__ = io.IOBase self.mock_input_stream.tell.side_effect = (0, 500) @@ -50,7 +52,7 @@ def setUp(self): "aws_encryption_sdk.streaming_client.aws_encryption_sdk.internal.formatting.deserialize.deserialize_header" ) self.mock_deserialize_header = self.mock_deserialize_header_patcher.start() - self.mock_deserialize_header.return_value = self.mock_header, sentinel.raw_header + self.mock_deserialize_header.return_value = self.mock_header, self.mock_raw_header # Set up deserialize_header_auth patch self.mock_deserialize_header_auth_patcher = patch( "aws_encryption_sdk.streaming_client" @@ -69,7 +71,14 @@ def setUp(self): ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values" ) self.mock_deserialize_non_framed_values = self.mock_deserialize_non_framed_values_patcher.start() - self.mock_deserialize_non_framed_values.return_value = (sentinel.iv, sentinel.tag, len(VALUES["data_128"])) + self.mock_deserialize_non_framed_values.return_value = (sentinel.iv, len(VALUES["data_128"])) + # Set up deserialize_tag_value patch + self.mock_deserialize_tag_patcher = patch( + "aws_encryption_sdk.streaming_client" + ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag" + ) + self.mock_deserialize_tag = self.mock_deserialize_tag_patcher.start() + self.mock_deserialize_tag.return_value = sentinel.tag # Set up get_aad_content_string patch self.mock_get_aad_content_string_patcher = patch( "aws_encryption_sdk.streaming_client.aws_encryption_sdk.internal.utils.get_aad_content_string" @@ -113,6 +122,7 @@ def tearDown(self): self.mock_deserialize_header_auth_patcher.stop() self.mock_validate_header_patcher.stop() self.mock_deserialize_non_framed_values_patcher.stop() + self.mock_deserialize_tag_patcher.stop() self.mock_get_aad_content_string_patcher.stop() self.mock_assemble_content_aad_patcher.stop() self.mock_decryptor_patcher.stop() @@ -151,11 +161,9 @@ def test_prep_message_non_framed_message(self, mock_read_header, mock_prep_non_f @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") - def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): + def test_read_header(self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): mock_verifier_instance = MagicMock() mock_verifier.from_key_bytes.return_value = mock_verifier_instance - mock_init.return_value = None ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(materials_manager=self.mock_materials_manager, source=ct_stream) test_decryptor.source_stream = ct_stream @@ -175,7 +183,7 @@ def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_material self.mock_materials_manager.decrypt_materials.assert_called_once_with( request=mock_decrypt_materials_request.return_value ) - mock_verifier_instance.update.assert_called_once_with(sentinel.raw_header) + mock_verifier_instance.update.assert_called_once_with(self.mock_raw_header) self.mock_deserialize_header_auth.assert_called_once_with( stream=ct_stream, algorithm=self.mock_header.algorithm, verifier=mock_verifier_instance ) @@ -188,18 +196,16 @@ def test_read_header(self, mock_init, mock_derive_datakey, mock_decrypt_material self.mock_validate_header.assert_called_once_with( header=self.mock_header, header_auth=sentinel.header_auth, - raw_header=sentinel.raw_header, + raw_header=self.mock_raw_header, data_key=mock_derive_datakey.return_value, ) assert test_header is self.mock_header assert test_header_auth is sentinel.header_auth @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") - def test_read_header_frame_too_large(self, mock_init, mock_derive_datakey): + def test_read_header_frame_too_large(self, mock_derive_datakey): self.mock_header.content_type = ContentType.FRAMED_DATA self.mock_header.frame_length = 1024 - mock_init.return_value = None ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream, max_body_length=10) test_decryptor.key_provider = self.mock_key_provider @@ -215,14 +221,12 @@ def test_read_header_frame_too_large(self, mock_init, mock_derive_datakey): @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - @patch("aws_encryption_sdk.streaming_client.StreamDecryptor.__init__") def test_read_header_no_verifier( - self, mock_init, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier + self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier ): self.mock_materials_manager.decrypt_materials.return_value = MagicMock( data_key=VALUES["data_key_obj"], verification_key=None ) - mock_init.return_value = None test_decryptor = StreamDecryptor(materials_manager=self.mock_materials_manager, source=self.mock_input_stream) test_decryptor.key_provider = self.mock_key_provider test_decryptor.source_stream = self.mock_input_stream @@ -264,6 +268,28 @@ def test_prep_non_framed(self): stream=test_decryptor.source_stream, header=self.mock_header, verifier=sentinel.verifier ) assert test_decryptor.body_length == len(VALUES["data_128"]) + assert test_decryptor.body_start == self.mock_header.algorithm.iv_len + 8 + assert test_decryptor.body_end == self.mock_header.algorithm.iv_len + 8 + len(VALUES["data_128"]) + + def test_read_bytes_from_non_framed(self): + ct_stream = io.BytesIO(VALUES["data_128"]) + test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) + test_decryptor.body_length = len(VALUES["data_128"]) + test_decryptor.decryptor = self.mock_decryptor_instance + test_decryptor._header = self.mock_header + test_decryptor.verifier = MagicMock() + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv + self.mock_decryptor_instance.update.return_value = b"1234" + self.mock_decryptor_instance.finalize.return_value = b"5678" + + test = test_decryptor._read_bytes_from_non_framed_body(5) + + self.mock_deserialize_tag.assert_called_once_with( + stream=test_decryptor.source_stream, + header=test_decryptor._header, + verifier=test_decryptor.verifier + ) self.mock_get_aad_content_string.assert_called_once_with( content_type=self.mock_header.content_type, is_final_frame=True ) @@ -277,24 +303,10 @@ def test_prep_non_framed(self): algorithm=self.mock_header.algorithm, key=sentinel.derived_data_key, associated_data=sentinel.associated_data, - iv=sentinel.iv, + iv=sentinel.unframed_body_iv, tag=sentinel.tag, ) assert test_decryptor.decryptor is self.mock_decryptor_instance - assert test_decryptor.body_start == 0 - assert test_decryptor.body_end == len(VALUES["data_128"]) - - def test_read_bytes_from_non_framed(self): - ct_stream = io.BytesIO(VALUES["data_128"]) - test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) - test_decryptor.decryptor = self.mock_decryptor_instance - test_decryptor._header = self.mock_header - test_decryptor.verifier = MagicMock() - self.mock_decryptor_instance.update.return_value = b"1234" - self.mock_decryptor_instance.finalize.return_value = b"5678" - test = test_decryptor._read_bytes_from_non_framed_body(5) test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"]) self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"]) assert test_decryptor.source_stream.closed @@ -303,8 +315,7 @@ def test_read_bytes_from_non_framed(self): def test_read_bytes_from_non_framed_message_body_too_small(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"] * 2) + test_decryptor.body_length = len(VALUES["data_128"] * 2) test_decryptor._header = self.mock_header with six.assertRaisesRegex( self, SerializationError, "Total message body contents less than specified in body description" @@ -314,10 +325,11 @@ def test_read_bytes_from_non_framed_message_body_too_small(self): def test_read_bytes_from_non_framed_no_verifier(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) + test_decryptor.body_length = len(VALUES["data_128"]) test_decryptor.decryptor = self.mock_decryptor_instance test_decryptor._header = self.mock_header + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv test_decryptor.verifier = None self.mock_decryptor_instance.update.return_value = b"1234" test_decryptor._read_bytes_from_non_framed_body(5) @@ -325,19 +337,19 @@ def test_read_bytes_from_non_framed_no_verifier(self): def test_read_bytes_from_non_framed_finalize(self): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.body_start = 0 - test_decryptor.body_length = test_decryptor.body_end = len(VALUES["data_128"]) + test_decryptor.body_length = len(VALUES["data_128"]) test_decryptor.decryptor = self.mock_decryptor_instance test_decryptor.verifier = MagicMock() test_decryptor._header = self.mock_header + test_decryptor._derived_data_key = sentinel.derived_data_key + test_decryptor._unframed_body_iv = sentinel.unframed_body_iv self.mock_decryptor_instance.update.return_value = b"1234" self.mock_decryptor_instance.finalize.return_value = b"5678" + test = test_decryptor._read_bytes_from_non_framed_body(len(VALUES["data_128"]) + 1) + test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"]) self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"]) - self.mock_update_verifier_with_tag.assert_called_once_with( - stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier - ) self.mock_deserialize_footer.assert_called_once_with( stream=test_decryptor.source_stream, verifier=test_decryptor.verifier ) diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index c30247522..a94519b58 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -23,10 +23,10 @@ import aws_encryption_sdk.internal.utils from aws_encryption_sdk.exceptions import InvalidDataKeyError, SerializationError, UnknownIdentityError from aws_encryption_sdk.internal.defaults import MAX_FRAME_SIZE, MESSAGE_ID_LENGTH -from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO from aws_encryption_sdk.structures import DataKey, EncryptedDataKey, MasterKeyInfo, RawDataKey from .test_values import VALUES +from .unit_test_utils import assert_prepped_stream_identity pytestmark = [pytest.mark.unit, pytest.mark.local] @@ -34,17 +34,14 @@ def test_prep_stream_data_passthrough(): test = aws_encryption_sdk.internal.utils.prep_stream_data(io.BytesIO(b"some data")) - assert isinstance(test, InsistentReaderBytesIO) + assert_prepped_stream_identity(test, io.BytesIO) @pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02")) def test_prep_stream_data_wrap(source): test = aws_encryption_sdk.internal.utils.prep_stream_data(source) - # Check the wrapped stream - assert isinstance(test, io.BytesIO) - # Check the wrapping stream - assert isinstance(test, InsistentReaderBytesIO) + assert_prepped_stream_identity(test, io.BytesIO) class TestUtils(unittest.TestCase): diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 7873456c1..3f4412237 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -14,6 +14,7 @@ import copy import io import itertools +from aws_encryption_sdk.internal.utils.streams import FauxCloseStream, LinearTellStream, InsistentReaderBytesIO def all_valid_kwargs(valid_kwargs): @@ -79,3 +80,15 @@ def read(self, size=-1): if self._read_counter >= 2: self.close() return super(ExactlyTwoReads, self).read(size) + + +class FailingTeller(object): + def tell(self): + raise IOError("Tell not allowed!") + + +def assert_prepped_stream_identity(prepped_stream, wrapped_type): + # Check the wrapped stream + assert isinstance(prepped_stream, wrapped_type) + # Check the wrapping streams + assert isinstance(prepped_stream, InsistentReaderBytesIO) From b4bbc1d12c0a11b5531c535fda8db1a70685c97a Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 14 Nov 2018 10:40:29 -0800 Subject: [PATCH 3/6] remove errant imports in unit_test_utils --- test/unit/unit_test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 3f4412237..3084005d4 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -14,7 +14,7 @@ import copy import io import itertools -from aws_encryption_sdk.internal.utils.streams import FauxCloseStream, LinearTellStream, InsistentReaderBytesIO +from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO def all_valid_kwargs(valid_kwargs): From 4f3f339a69c252f42e096144880187add3a38831 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 14 Nov 2018 11:31:11 -0800 Subject: [PATCH 4/6] add docstrings to deprecated properties --- src/aws_encryption_sdk/streaming_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 1d13e448a..302e2bf56 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -755,11 +755,13 @@ def _read_header(self): @property def body_start(self): + """Log deprecation warning when body_start is accessed.""" _LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0") return self._body_start @property def body_end(self): + """Log deprecation warning when body_end is accessed.""" _LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0") return self._body_end From 9945e6ce34f1126b198a134a113f067f395b3139 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 14 Nov 2018 11:31:32 -0800 Subject: [PATCH 5/6] autoformat --- src/aws_encryption_sdk/streaming_client.py | 4 +--- .../test_f_aws_encryption_sdk_client.py | 15 ++++++++++----- test/unit/test_deserialize.py | 10 ++++------ .../test_streaming_client_stream_decryptor.py | 13 ++++--------- test/unit/unit_test_utils.py | 1 + 5 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 302e2bf56..650c89144 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -804,9 +804,7 @@ def _read_bytes_from_non_framed_body(self, b): self.verifier.update(ciphertext) tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( - stream=self.source_stream, - header=self._header, - verifier=self.verifier, + stream=self.source_stream, header=self._header, verifier=self.verifier ) aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string( diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 2d6e33405..a0bd32675 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -761,7 +761,9 @@ def test_cycle_nothing_but_read(frame_length): raw_plaintext = exact_length_plaintext(100) plaintext = NothingButRead(raw_plaintext) key_provider = fake_kms_key_provider() - raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) ciphertext = NothingButRead(raw_ciphertext) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert raw_plaintext == decrypted @@ -773,7 +775,9 @@ def test_encrypt_nothing_but_read(frame_length): raw_plaintext = exact_length_plaintext(100) plaintext = NothingButRead(raw_plaintext) key_provider = fake_kms_key_provider() - ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert raw_plaintext == decrypted @@ -783,7 +787,9 @@ def test_encrypt_nothing_but_read(frame_length): def test_decrypt_nothing_but_read(frame_length): plaintext = exact_length_plaintext(100) key_provider = fake_kms_key_provider() - raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=frame_length) + raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt( + source=plaintext, key_provider=key_provider, frame_length=frame_length + ) ciphertext = NothingButRead(raw_ciphertext) decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider) assert plaintext == decrypted @@ -801,8 +807,7 @@ def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than): assert decrypted == plaintext assert hasattr(decryptor, attribute) watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format( - name=attribute, - version=no_later_than + name=attribute, version=no_later_than ) assert watch_string in caplog.text assert aws_encryption_sdk.__version__ < no_later_than diff --git a/test/unit/test_deserialize.py b/test/unit/test_deserialize.py index f99d57fb1..b591159e0 100644 --- a/test/unit/test_deserialize.py +++ b/test/unit/test_deserialize.py @@ -31,14 +31,13 @@ def test_deserialize_non_framed_values(): - iv = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11' + iv = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11" length = 42 packed = struct.pack(">12sQ", iv, length) mock_header = MagicMock(algorithm=MagicMock(iv_len=12)) parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( - stream=io.BytesIO(packed), - header=mock_header + stream=io.BytesIO(packed), header=mock_header ) assert parsed_iv == iv @@ -46,13 +45,12 @@ def test_deserialize_non_framed_values(): def test_deserialize_tag(): - tag = b'\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15' + tag = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15" packed = struct.pack(">16s", tag) mock_header = MagicMock(algorithm=MagicMock(auth_len=16)) parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag( - stream=io.BytesIO(packed), - header=mock_header + stream=io.BytesIO(packed), header=mock_header ) assert parsed_tag == tag diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index 987b31941..0c9f53ee7 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -41,7 +41,7 @@ def setUp(self): self.mock_header.encrypted_data_keys = sentinel.encrypted_data_keys self.mock_header.encryption_context = sentinel.encryption_context - self.mock_raw_header = b'some bytes' + self.mock_raw_header = b"some bytes" self.mock_input_stream = MagicMock() self.mock_input_stream.__class__ = io.IOBase @@ -74,8 +74,7 @@ def setUp(self): self.mock_deserialize_non_framed_values.return_value = (sentinel.iv, len(VALUES["data_128"])) # Set up deserialize_tag_value patch self.mock_deserialize_tag_patcher = patch( - "aws_encryption_sdk.streaming_client" - ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag" + "aws_encryption_sdk.streaming_client" ".aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag" ) self.mock_deserialize_tag = self.mock_deserialize_tag_patcher.start() self.mock_deserialize_tag.return_value = sentinel.tag @@ -221,9 +220,7 @@ def test_read_header_frame_too_large(self, mock_derive_datakey): @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @patch("aws_encryption_sdk.streaming_client.derive_data_encryption_key") - def test_read_header_no_verifier( - self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier - ): + def test_read_header_no_verifier(self, mock_derive_datakey, mock_decrypt_materials_request, mock_verifier): self.mock_materials_manager.decrypt_materials.return_value = MagicMock( data_key=VALUES["data_key_obj"], verification_key=None ) @@ -286,9 +283,7 @@ def test_read_bytes_from_non_framed(self): test = test_decryptor._read_bytes_from_non_framed_body(5) self.mock_deserialize_tag.assert_called_once_with( - stream=test_decryptor.source_stream, - header=test_decryptor._header, - verifier=test_decryptor.verifier + stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier ) self.mock_get_aad_content_string.assert_called_once_with( content_type=self.mock_header.content_type, is_final_frame=True diff --git a/test/unit/unit_test_utils.py b/test/unit/unit_test_utils.py index 3084005d4..6b0a84bdc 100644 --- a/test/unit/unit_test_utils.py +++ b/test/unit/unit_test_utils.py @@ -14,6 +14,7 @@ import copy import io import itertools + from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO From 8f7f215f57d4b1a5c48d9b3ee477c10561cd7363 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Wed, 5 Dec 2018 11:16:03 -0800 Subject: [PATCH 6/6] temporary linting line disable pending more complex fix in next PR --- src/aws_encryption_sdk/streaming_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 650c89144..faadc6515 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -767,7 +767,7 @@ def body_end(self): def _prep_non_framed(self): """Prepare the opening data for a non-framed message.""" - self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( + self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( # noqa # pylint: disable=line-too-long stream=self.source_stream, header=self._header, verifier=self.verifier )