Skip to content

remove now unnecessary use of source_stream.close() #112

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ def _read_bytes_to_non_framed_body(self, b):

if len(plaintext) < b:
_LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b)
self.source_stream.close()
self.__unframed_plaintext_cache.close()

closing = self.encryptor.finalize()

if self.signer is not None:
Expand Down Expand Up @@ -625,7 +624,6 @@ def _read_bytes_to_framed_body(self, b):
if self.signer is not None:
output += serialize_footer(self.signer)
self.__message_complete = True
self.source_stream.close()
return output

def _read_bytes(self, b):
Expand Down Expand Up @@ -856,7 +854,6 @@ def _read_bytes_from_non_framed_body(self, b):
plaintext += self.decryptor.finalize()

self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier)
self.source_stream.close()
return plaintext

def _read_bytes_from_framed_body(self, b):
Expand Down Expand Up @@ -898,7 +895,7 @@ def _read_bytes_from_framed_body(self, b):
if final_frame:
_LOGGER.debug("Reading footer")
self.footer = deserialize_footer(stream=self.source_stream, verifier=self.verifier)
self.source_stream.close()

return plaintext

def _read_bytes(self, b):
Expand Down
30 changes: 3 additions & 27 deletions test/functional/test_f_aws_encryption_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,15 +770,7 @@ def close(self):
raise NotImplementedError("NoClose does not close().")


@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
NoClosed,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead))
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
raw_plaintext = exact_length_plaintext(100)
Expand All @@ -792,15 +784,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
assert raw_plaintext == decrypted


@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
NoClosed,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead))
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
raw_plaintext = exact_length_plaintext(100)
Expand All @@ -813,15 +797,7 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
assert raw_plaintext == decrypted


@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
NoClosed,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("wrapping_class", (NoTell, NoClosed, NoClose, NothingButRead))
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class):
plaintext = exact_length_plaintext(100)
Expand Down
4 changes: 0 additions & 4 deletions test/unit/test_streaming_client_stream_decryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,6 @@ def test_read_bytes_from_non_framed(self):
assert test_decryptor.decryptor is self.mock_decryptor_instance
test_decryptor.verifier.update.assert_called_once_with(VALUES["data_128"])
self.mock_decryptor_instance.update.assert_called_once_with(VALUES["data_128"])
assert test_decryptor.source_stream.closed
assert test == b"12345678"

def test_read_bytes_from_non_framed_message_body_too_small(self):
Expand Down Expand Up @@ -324,7 +323,6 @@ def test_read_bytes_from_non_framed_finalize(self):
self.mock_deserialize_footer.assert_called_once_with(
stream=test_decryptor.source_stream, verifier=test_decryptor.verifier
)
assert test_decryptor.source_stream.closed
assert test == b"12345678"

def test_read_bytes_from_framed_body_multi_frame_finalize(self):
Expand Down Expand Up @@ -459,7 +457,6 @@ def test_read_bytes_from_framed_body_multi_frame_finalize(self):
self.mock_deserialize_footer.assert_called_once_with(
stream=test_decryptor.source_stream, verifier=test_decryptor.verifier
)
assert test_decryptor.source_stream.closed
assert test == b"1234567890-="

def test_read_bytes_from_framed_body_single_frame(self):
Expand All @@ -484,7 +481,6 @@ def test_read_bytes_from_framed_body_single_frame(self):
stream=test_decryptor.source_stream, header=test_decryptor._header, verifier=test_decryptor.verifier
)
assert not self.mock_deserialize_footer.called
assert not test_decryptor.source_stream.closed
assert test == b"1234"

def test_read_bytes_from_framed_body_bad_sequence_number(self):
Expand Down
6 changes: 0 additions & 6 deletions test/unit/test_streaming_client_stream_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ def test_read_bytes_to_non_framed_body(self):

test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5])
test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext)
assert not test_encryptor.source_stream.closed
assert test is sentinel.ciphertext

def test_read_bytes_to_non_framed_body_too_large(self):
Expand All @@ -414,7 +413,6 @@ def test_read_bytes_to_non_framed_body_close(self):
test = test_encryptor._read_bytes_to_non_framed_body(len(self.plaintext) + 1)

test_encryptor.signer.update.assert_has_calls(calls=(call(b"123"), call(b"456")), any_order=False)
assert test_encryptor.source_stream.closed
test_encryptor.encryptor.finalize.assert_called_once_with()
self.mock_serialize_non_framed_close.assert_called_once_with(
tag=test_encryptor.encryptor.tag, signer=test_encryptor.signer
Expand Down Expand Up @@ -514,7 +512,6 @@ def test_read_bytes_to_framed_body_single_frame_read(self):
signer=sentinel.signer,
)
assert not self.mock_serialize_footer.called
assert not test_encryptor.source_stream.closed
assert test == b"1234"

def test_read_bytes_to_framed_body_single_frame_with_final(self):
Expand Down Expand Up @@ -633,7 +630,6 @@ def test_read_bytes_to_framed_body_multi_frame_read(self):
any_order=False,
)
self.mock_serialize_footer.assert_called_once_with(sentinel.signer)
assert test_encryptor.source_stream.closed
assert test == b"1234567890-=FINAL/*-"

def test_read_bytes_to_framed_body_close(self):
Expand All @@ -651,7 +647,6 @@ def test_read_bytes_to_framed_body_close(self):
test_encryptor._read_bytes_to_framed_body(len(self.plaintext) + 1)

self.mock_serialize_footer.assert_called_once_with(sentinel.signer)
assert test_encryptor.source_stream.closed

def test_read_bytes_to_framed_body_close_no_signer(self):
self.mock_serialize_frame.return_value = (b"1234", b"")
Expand All @@ -670,7 +665,6 @@ def test_read_bytes_to_framed_body_close_no_signer(self):
test_encryptor._read_bytes_to_framed_body(len(self.plaintext) + 1)

assert not self.mock_serialize_footer.called
assert test_encryptor.source_stream.closed

@patch("aws_encryption_sdk.streaming_client._EncryptionStream.close")
def test_close(self, mock_close):
Expand Down