Skip to content

allow duck-typing of source streams #77

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@
Changelog
*********

1.3.6 -- 2018-08-29
1.3.6 -- 2018-08-xx
===================

Bugfixes
--------
* :class:`StreamEncryptor` and :class:`StreamDecryptor` should always report as readable if they are open.
`#73 <https://github.com/aws/aws-encryption-sdk-python/issues/73>`_
* Allow duck-typing of source streams.
`#75 <https://github.com/aws/aws-encryption-sdk-python/issues/75>`_

1.3.5 -- 2018-08-01
===================
Expand Down
17 changes: 5 additions & 12 deletions src/aws_encryption_sdk/internal/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,24 +127,17 @@ def prepare_data_keys(primary_master_key, master_keys, algorithm, encryption_con
return data_encryption_key, encrypted_data_keys


try:
_FILE_TYPE = file # Python 2
except NameError:
_FILE_TYPE = io.IOBase # Python 3 pylint: disable=invalid-name


def prep_stream_data(data):
"""Takes an input str, bytes, io.IOBase, or file object and returns an appropriate
stream for _EncryptionStream objects.
"""Take an input and prepare it for use as a stream.

:param data: Input data
:type data: str, bytes, io.IOBase, or file
:returns: Prepared stream
:rtype: io.BytesIO
"""
if isinstance(data, (_FILE_TYPE, io.IOBase, six.StringIO)):
return data
return io.BytesIO(to_bytes(data))
if isinstance(data, (six.string_types, six.binary_type)):
return io.BytesIO(to_bytes(data))

return data


def source_data_key_length_check(source_data_key, algorithm):
Expand Down
23 changes: 15 additions & 8 deletions test/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# coding: utf-8
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
Expand All @@ -11,6 +12,7 @@
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.
"""Test suite for aws_encryption_sdk.internal.utils"""
import io
import unittest

import pytest
Expand All @@ -28,6 +30,19 @@
pytestmark = [pytest.mark.unit, pytest.mark.local]


def test_prep_stream_data_passthrough():
test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.not_a_string_or_bytes)

assert test is sentinel.not_a_string_or_bytes


@pytest.mark.parametrize("source", (u"some unicode data ловие", b"\x00\x01\x02"))
def test_prep_stream_data_wrap(source):
test = aws_encryption_sdk.internal.utils.prep_stream_data(source)

assert isinstance(test, io.BytesIO)


class TestUtils(unittest.TestCase):
def setUp(self):
# Set up mock key provider and keys
Expand Down Expand Up @@ -235,14 +250,6 @@ def test_prepare_data_keys(self):
[mock_encrypted_data_encryption_key, sentinel.encrypted_data_key_1, sentinel.encrypted_data_key_2]
)

@patch("aws_encryption_sdk.internal.utils.to_bytes", return_value=sentinel.bytes)
@patch("aws_encryption_sdk.internal.utils.io.BytesIO", return_value=sentinel.bytesio)
def test_prep_stream_data(self, mock_bytesio, mock_to_bytes):
test = aws_encryption_sdk.internal.utils.prep_stream_data(sentinel.data)
mock_to_bytes.assert_called_once_with(sentinel.data)
mock_bytesio.assert_called_once_with(sentinel.bytes)
assert test is sentinel.bytesio

def test_source_data_key_length_check_valid(self):
mock_algorithm = MagicMock()
mock_algorithm.kdf_input_len = 5
Expand Down