From 9ab4afb181e12d6c7a693972c8e5819e45edcd75 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Tue, 4 Dec 2018 19:05:56 -0800 Subject: [PATCH] remove now unnecessary use of source_stream.close() --- src/aws_encryption_sdk/streaming_client.py | 7 ++--- .../test_f_aws_encryption_sdk_client.py | 30 ++----------------- .../test_streaming_client_stream_decryptor.py | 4 --- .../test_streaming_client_stream_encryptor.py | 6 ---- 4 files changed, 5 insertions(+), 42 deletions(-) diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 55e6b779b..2132ea177 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -546,8 +546,7 @@ def _read_bytes_to_non_framed_body(self, b): if len(plaintext) < b: _LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b) - self.source_stream.close() - self.__unframed_plaintext_cache.close() + closing = self.encryptor.finalize() if self.signer is not None: @@ -625,7 +624,6 @@ def _read_bytes_to_framed_body(self, b): if self.signer is not None: output += serialize_footer(self.signer) self.__message_complete = True - self.source_stream.close() return output def _read_bytes(self, b): @@ -856,7 +854,6 @@ def _read_bytes_from_non_framed_body(self, b): plaintext += self.decryptor.finalize() self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) - self.source_stream.close() return plaintext def _read_bytes_from_framed_body(self, b): @@ -898,7 +895,7 @@ def _read_bytes_from_framed_body(self, b): if final_frame: _LOGGER.debug("Reading footer") self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier) - self.source_stream.close() + return plaintext def _read_bytes(self, b): diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index f92f2b3a4..3686ccead 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -770,15 +770,7 @@ def close(self): raise NotImplementedError("NoClose does not close().") -@pytest.mark.parametrize( - "wrapping_class", - ( - NoTell, - NoClosed, - pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), - pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), - ), -) +@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead)) @pytest.mark.parametrize("frame_length", (0, 1024)) def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): raw_plaintext = exact_length_plaintext(100) @@ -792,15 +784,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): assert raw_plaintext == decrypted -@pytest.mark.parametrize( - "wrapping_class", - ( - NoTell, - NoClosed, - pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), - pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), - ), -) +@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead)) @pytest.mark.parametrize("frame_length", (0, 1024)) def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): raw_plaintext = exact_length_plaintext(100) @@ -813,15 +797,7 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): assert raw_plaintext == decrypted -@pytest.mark.parametrize( - "wrapping_class", - ( - NoTell, - NoClosed, - pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), - pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), - ), -) +@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead)) @pytest.mark.parametrize("frame_length", (0, 1024)) def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class): plaintext = exact_length_plaintext(100) diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index e479070b4..467bf0faa 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -280,7 +280,6 @@ def test_read_bytes_from_non_framed(self): assert test_decryptor.decryptor is self.mock_decryptor_instance test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"]) self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"]) - assert test_decryptor.source_stream.closed assert test == b"12345678" def test_read_bytes_from_non_framed_message_body_too_small(self): @@ -324,7 +323,6 @@ def test_read_bytes_from_non_framed_finalize(self): self.mock_deserialize_footer.assert_called_once_with( stream=test_decryptor.source_stream, verifier=test_decryptor.verifier ) - assert test_decryptor.source_stream.closed assert test == b"12345678" def test_read_bytes_from_framed_body_multi_frame_finalize(self): @@ -459,7 +457,6 @@ def test_read_bytes_from_framed_body_multi_frame_finalize(self): self.mock_deserialize_footer.assert_called_once_with( stream=test_decryptor.source_stream, verifier=test_decryptor.verifier ) - assert test_decryptor.source_stream.closed assert test == b"1234567890-=" def test_read_bytes_from_framed_body_single_frame(self): @@ -484,7 +481,6 @@ def test_read_bytes_from_framed_body_single_frame(self): stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier ) assert not self.mock_deserialize_footer.called - assert not test_decryptor.source_stream.closed assert test == b"1234" def test_read_bytes_from_framed_body_bad_sequence_number(self): diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index e0ade654c..06f7ef3d9 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -388,7 +388,6 @@ def test_read_bytes_to_non_framed_body(self): test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5]) test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext) - assert not test_encryptor.source_stream.closed assert test is sentinel.ciphertext def test_read_bytes_to_non_framed_body_too_large(self): @@ -414,7 +413,6 @@ def test_read_bytes_to_non_framed_body_close(self): test = test_encryptor._read_bytes_to_non_framed_body(len(self.plaintext) + 1) test_encryptor.signer.update.assert_has_calls(calls=(call(b"123"), call(b"456")), any_order=False) - assert test_encryptor.source_stream.closed test_encryptor.encryptor.finalize.assert_called_once_with() self.mock_serialize_non_framed_close.assert_called_once_with( tag=test_encryptor.encryptor.tag, signer=test_encryptor.signer @@ -514,7 +512,6 @@ def test_read_bytes_to_framed_body_single_frame_read(self): signer=sentinel.signer, ) assert not self.mock_serialize_footer.called - assert not test_encryptor.source_stream.closed assert test == b"1234" def test_read_bytes_to_framed_body_single_frame_with_final(self): @@ -633,7 +630,6 @@ def test_read_bytes_to_framed_body_multi_frame_read(self): any_order=False, ) self.mock_serialize_footer.assert_called_once_with(sentinel.signer) - assert test_encryptor.source_stream.closed assert test == b"1234567890-=FINAL/*-" def test_read_bytes_to_framed_body_close(self): @@ -651,7 +647,6 @@ def test_read_bytes_to_framed_body_close(self): test_encryptor._read_bytes_to_framed_body(len(self.plaintext) + 1) self.mock_serialize_footer.assert_called_once_with(sentinel.signer) - assert test_encryptor.source_stream.closed def test_read_bytes_to_framed_body_close_no_signer(self): self.mock_serialize_frame.return_value = (b"1234", b"") @@ -670,7 +665,6 @@ def test_read_bytes_to_framed_body_close_no_signer(self): test_encryptor._read_bytes_to_framed_body(len(self.plaintext) + 1) assert not self.mock_serialize_footer.called - assert test_encryptor.source_stream.closed @patch("aws_encryption_sdk.streaming_client._EncryptionStream.close") def test_close(self, mock_close):