Skip to content

Commit f417067

Browse files
committed
remove source_stream.closed from StreamEncryptor._read_bytes()
1 parent 785f33f commit f417067

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

src/aws_encryption_sdk/streaming_client.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-
404404
raise SerializationError("Source too large for non-framed message")
405405

406406
self.__unframed_plaintext_cache = io.BytesIO()
407+
self.__message_complete = False
407408

408409
def ciphertext_length(self):
409410
"""Returns the length of the resulting ciphertext message in bytes.
@@ -556,6 +557,7 @@ def _read_bytes_to_non_framed_body(self, b):
556557

557558
if self.signer is not None:
558559
closing += serialize_footer(self.signer)
560+
self.__message_complete = True
559561
return ciphertext + closing
560562

561563
return ciphertext
@@ -622,6 +624,7 @@ def _read_bytes_to_framed_body(self, b):
622624
_LOGGER.debug("Writing footer")
623625
if self.signer is not None:
624626
output += serialize_footer(self.signer)
627+
self.__message_complete = True
625628
self.source_stream.close()
626629
return output
627630

@@ -632,7 +635,7 @@ def _read_bytes(self, b):
632635
:raises NotSupportedError: if content type is not supported
633636
"""
634637
_LOGGER.debug("%d bytes requested from stream with content type: %s", b, self.content_type)
635-
if 0 <= b <= len(self.output_buffer) or self.source_stream.closed:
638+
if 0 <= b <= len(self.output_buffer) or self.__message_complete:
636639
_LOGGER.debug("No need to read from source stream or source stream closed")
637640
return
638641

test/functional/test_f_aws_encryption_sdk_client.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -761,9 +761,11 @@ def tell(self):
761761
raise NotImplementedError("NoTell does not tell().")
762762

763763

764-
class NoClose(ObjectProxy):
764+
class NoClosed(ObjectProxy):
765765
closed = NotImplemented
766766

767+
768+
class NoClose(ObjectProxy):
767769
def close(self):
768770
raise NotImplementedError("NoClose does not close().")
769771

@@ -772,6 +774,7 @@ def close(self):
772774
"wrapping_class",
773775
(
774776
NoTell,
777+
NoClosed,
775778
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
776779
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
777780
),
@@ -793,6 +796,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
793796
"wrapping_class",
794797
(
795798
NoTell,
799+
NoClosed,
796800
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
797801
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
798802
),
@@ -813,6 +817,7 @@ def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
813817
"wrapping_class",
814818
(
815819
NoTell,
820+
NoClosed,
816821
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
817822
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
818823
),

test/unit/test_streaming_client_stream_encryptor.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -451,10 +451,10 @@ def test_read_bytes_less_than_buffer(self, mock_read_non_framed, mock_read_frame
451451

452452
@patch("aws_encryption_sdk.streaming_client.StreamEncryptor._read_bytes_to_framed_body")
453453
@patch("aws_encryption_sdk.streaming_client.StreamEncryptor._read_bytes_to_non_framed_body")
454-
def test_read_bytes_closed(self, mock_read_non_framed, mock_read_framed):
454+
def test_read_bytes_completed(self, mock_read_non_framed, mock_read_framed):
455455
pt_stream = io.BytesIO(self.plaintext)
456456
test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider)
457-
test_encryptor.source_stream.close()
457+
test_encryptor._StreamEncryptor__message_complete = True
458458
test_encryptor._read_bytes(5)
459459
assert not mock_read_non_framed.called
460460
assert not mock_read_framed.called

0 commit comments

Comments
 (0)