diff --git a/src/aws_encryption_sdk/streaming_client.py b/src/aws_encryption_sdk/streaming_client.py index 516a21913..55e6b779b 100644 --- a/src/aws_encryption_sdk/streaming_client.py +++ b/src/aws_encryption_sdk/streaming_client.py @@ -243,10 +243,11 @@ def read(self, b=-1): output.write(self.output_buffer[:b]) self.output_buffer = self.output_buffer[b:] else: - while not self.source_stream.closed: - self._read_bytes(LINE_LENGTH) - output.write(self.output_buffer) - self.output_buffer = b"" + while True: + line = self.readline() + if not line: + break + output.write(line) self.bytes_read += output.tell() _LOGGER.debug("Returning %d bytes of %d bytes requested", output.tell(), b) @@ -294,10 +295,13 @@ def next(self): if self.closed: _LOGGER.debug("stream is closed") raise StopIteration() - if self.source_stream.closed and not self.output_buffer: + + line = self.readline() + if not line: _LOGGER.debug("nothing more to read") raise StopIteration() - return self.readline() + + return line #: Provides hook for Python3 iterator functionality. __next__ = next @@ -400,6 +404,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not- raise SerializationError("Source too large for non-framed message") self.__unframed_plaintext_cache = io.BytesIO() + self.__message_complete = False def ciphertext_length(self): """Returns the length of the resulting ciphertext message in bytes. @@ -552,6 +557,7 @@ def _read_bytes_to_non_framed_body(self, b): if self.signer is not None: closing += serialize_footer(self.signer) + self.__message_complete = True return ciphertext + closing return ciphertext @@ -618,6 +624,7 @@ def _read_bytes_to_framed_body(self, b): _LOGGER.debug("Writing footer") if self.signer is not None: output += serialize_footer(self.signer) + self.__message_complete = True self.source_stream.close() return output @@ -628,7 +635,7 @@ def _read_bytes(self, b): :raises NotSupportedError: if content type is not supported """ _LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type) - if 0 <= b <= len(self.output_buffer) or self.source_stream.closed: + if 0 <= b <= len(self.output_buffer) or self.__message_complete: _LOGGER.debug("No need to read from source stream or source stream closed") return @@ -900,8 +907,8 @@ def _read_bytes(self, b): :param int b: Number of bytes to read :raises NotSupportedError: if content type is not supported """ - if self.source_stream.closed: - _LOGGER.debug("Source stream closed") + if hasattr(self, "footer"): + _LOGGER.debug("Source stream processing complete") return buffer_length = len(self.output_buffer) diff --git a/test/functional/test_f_aws_encryption_sdk_client.py b/test/functional/test_f_aws_encryption_sdk_client.py index 74e78b382..f92f2b3a4 100644 --- a/test/functional/test_f_aws_encryption_sdk_client.py +++ b/test/functional/test_f_aws_encryption_sdk_client.py @@ -761,9 +761,11 @@ def tell(self): raise NotImplementedError("NoTell does not tell().") -class NoClose(ObjectProxy): +class NoClosed(ObjectProxy): closed = NotImplemented + +class NoClose(ObjectProxy): def close(self): raise NotImplementedError("NoClose does not close().") @@ -772,6 +774,7 @@ def close(self): "wrapping_class", ( NoTell, + NoClosed, pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), ), @@ -793,6 +796,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class): "wrapping_class", ( NoTell, + NoClosed, pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), ), @@ -813,6 +817,7 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class): "wrapping_class", ( NoTell, + NoClosed, pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)), pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)), ), diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index c59ae4beb..e479070b4 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -502,10 +502,10 @@ def test_read_bytes_from_framed_body_bad_sequence_number(self): @patch("aws_encryption_sdk.streaming_client.StreamDecryptor._read_bytes_from_non_framed_body") @patch("aws_encryption_sdk.streaming_client.StreamDecryptor._read_bytes_from_framed_body") - def test_read_bytes_closed(self, mock_read_frame, mock_read_block): + def test_read_bytes_completed(self, mock_read_frame, mock_read_block): ct_stream = io.BytesIO(VALUES["data_128"]) test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) - test_decryptor.source_stream.close() + test_decryptor.footer = None test_decryptor._read_bytes(5) assert not mock_read_frame.called assert not mock_read_block.called diff --git a/test/unit/test_streaming_client_stream_encryptor.py b/test/unit/test_streaming_client_stream_encryptor.py index 8435d2bb3..e0ade654c 100644 --- a/test/unit/test_streaming_client_stream_encryptor.py +++ b/test/unit/test_streaming_client_stream_encryptor.py @@ -451,10 +451,10 @@ def test_read_bytes_less_than_buffer(self, mock_read_non_framed, mock_read_frame @patch("aws_encryption_sdk.streaming_client.StreamEncryptor._read_bytes_to_framed_body") @patch("aws_encryption_sdk.streaming_client.StreamEncryptor._read_bytes_to_non_framed_body") - def test_read_bytes_closed(self, mock_read_non_framed, mock_read_framed): + def test_read_bytes_completed(self, mock_read_non_framed, mock_read_framed): pt_stream = io.BytesIO(self.plaintext) test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider) - test_encryptor.source_stream.close() + test_encryptor._StreamEncryptor__message_complete = True test_encryptor._read_bytes(5) assert not mock_read_non_framed.called assert not mock_read_framed.called