From f4ddcf82e6e62ba5a693296328f2f3122b406d69 Mon Sep 17 00:00:00 2001 From: mattsb42-aws Date: Thu, 30 Aug 2018 14:15:59 -0700 Subject: [PATCH] allow duck-typing of source streams --- CHANGELOG.rst | 4 +++- .../internal/utils/__init__.py | 17 ++++---------- test/unit/test_utils.py | 23 ++++++++++++------- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index a9847f62e..b4a1ceeee 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,13 +2,15 @@ Changelog ********* -1.3.6 -- 2018-08-29 +1.3.6 -- 2018-08-xx =================== Bugfixes -------- * :class:`StreamEncryptor` and :class:`StreamDecryptor` should always report as readable if they are open. `#73 `_ +* Allow duck-typing of source streams. + `#75 `_ 1.3.5 -- 2018-08-01 =================== diff --git a/src/aws_encryption_sdk/internal/utils/__init__.py b/src/aws_encryption_sdk/internal/utils/__init__.py index f9d705e2b..065722d0d 100644 --- a/src/aws_encryption_sdk/internal/utils/__init__.py +++ b/src/aws_encryption_sdk/internal/utils/__init__.py @@ -127,24 +127,17 @@ def prepare_data_keys(primary_master_key, master_keys, algorithm, encryption_con return data_encryption_key, encrypted_data_keys -try: - _FILE_TYPE = file # Python 2 -except NameError: - _FILE_TYPE = io.IOBase # Python 3 pylint: disable=invalid-name - - def prep_stream_data(data): - """Takes an input str, bytes, io.IOBase, or file object and returns an appropriate - stream for _EncryptionStream objects. + """Take an input and prepare it for use as a stream. :param data: Input data - :type data: str, bytes, io.IOBase, or file :returns: Prepared stream :rtype: io.BytesIO """ - if isinstance(data, (_FILE_TYPE, io.IOBase, six.StringIO)): - return data - return io.BytesIO(to_bytes(data)) + if isinstance(data, (six.string_types, six.binary_type)): + return io.BytesIO(to_bytes(data)) + + return data def source_data_key_length_check(source_data_key, algorithm): diff --git a/test/unit/test_utils.py b/test/unit/test_utils.py index 1883f6324..d74499787 100644 --- a/test/unit/test_utils.py +++ b/test/unit/test_utils.py @@ -1,3 +1,4 @@ +# coding: utf-8 # Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"). You @@ -11,6 +12,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. """Test suite for aws_encryption_sdk.internal.utils""" +import io import unittest import pytest @@ -28,6 +30,19 @@ pytestmark = [pytest.mark.unit, pytest.mark.local] +def test_prep_stream_data_passthrough(): + test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.not_a_string_or_bytes) + + assert test is sentinel.not_a_string_or_bytes + + +@pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02")) +def test_prep_stream_data_wrap(source): + test = aws_encryption_sdk.internal.utils.prep_stream_data(source) + + assert isinstance(test, io.BytesIO) + + class TestUtils(unittest.TestCase): def setUp(self): # Set up mock key provider and keys @@ -235,14 +250,6 @@ def test_prepare_data_keys(self): [mock_encrypted_data_encryption_key, sentinel.encrypted_data_key_1, sentinel.encrypted_data_key_2] ) - @patch("aws_encryption_sdk.internal.utils.to_bytes", return_value=sentinel.bytes) - @patch("aws_encryption_sdk.internal.utils.io.BytesIO", return_value=sentinel.bytesio) - def test_prep_stream_data(self, mock_bytesio, mock_to_bytes): - test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.data) - mock_to_bytes.assert_called_once_with(sentinel.data) - mock_bytesio.assert_called_once_with(sentinel.bytes) - assert test is sentinel.bytesio - def test_source_data_key_length_check_valid(self): mock_algorithm = MagicMock() mock_algorithm.kdf_input_len = 5