12
12
# language governing permissions and limitations under the License.
13
13
"""Components for handling AWS Encryption SDK message deserialization."""
14
14
from __future__ import division
15
+ import io
15
16
import logging
16
17
import struct
17
18
29
30
EncryptedData , MessageFooter ,
30
31
MessageFrameBody , MessageHeaderAuthentication
31
32
)
33
+ from aws_encryption_sdk .internal .utils .streams import TeeStream
32
34
from aws_encryption_sdk .structures import EncryptedDataKey , MasterKeyInfo , MessageHeader
33
35
34
36
_LOGGER = logging .getLogger (__name__ )
35
37
36
38
37
- def validate_header (header , header_auth , stream , header_start , header_end , data_key ):
39
+ def validate_header (header , header_auth , raw_header , data_key ):
38
40
"""Validates the header using the header authentication data.
39
41
40
42
:param header: Deserialized header
41
43
:type header: aws_encryption_sdk.structures.MessageHeader
42
44
:param header_auth: Deserialized header auth
43
45
:type header_auth: aws_encryption_sdk.internal.structures.MessageHeaderAuthentication
44
- :param stream: Stream containing serialized message
45
46
:type stream: io.BytesIO
46
- :param int header_start: Position in stream of start of serialized header
47
- :param int header_end: Position in stream of end of serialized header
47
+ :param bytes raw_header: Raw header bytes
48
48
:param bytes data_key: Data key with which to perform validation
49
49
:raises SerializationError: if header authorization fails
50
50
"""
51
51
_LOGGER .debug ('Starting header validation' )
52
- current_position = stream .tell ()
53
- stream .seek (header_start )
54
52
try :
55
53
decrypt (
56
54
algorithm = header .algorithm ,
57
55
key = data_key ,
58
56
encrypted_data = EncryptedData (header_auth .iv , b'' , header_auth .tag ),
59
- associated_data = stream . read ( header_end - header_start )
57
+ associated_data = raw_header
60
58
)
61
59
except InvalidTag :
62
60
raise SerializationError ('Header authorization failed' )
63
- stream .seek (current_position )
64
61
65
62
66
63
def deserialize_header (stream ):
@@ -69,13 +66,15 @@ def deserialize_header(stream):
69
66
:param stream: Source data stream
70
67
:type stream: io.BytesIO
71
68
:returns: Deserialized MessageHeader object
72
- :rtype: aws_encryption_sdk.structures.MessageHeader
69
+ :rtype: :class:` aws_encryption_sdk.structures.MessageHeader` and bytes
73
70
:raises NotSupportedError: if unsupported data types are found
74
71
:raises UnknownIdentityError: if unknown data types are found
75
72
:raises SerializationError: if IV length does not match algorithm
76
73
"""
77
74
_LOGGER .debug ('Starting header deserialization' )
78
- version_id , message_type_id = unpack_values ('>BB' , stream )
75
+ tee = io .BytesIO ()
76
+ tee_stream = TeeStream (stream , tee )
77
+ version_id , message_type_id = unpack_values ('>BB' , tee_stream )
79
78
try :
80
79
message_type = ObjectType (message_type_id )
81
80
except ValueError as error :
@@ -89,7 +88,7 @@ def deserialize_header(stream):
89
88
raise NotSupportedError ('Unsupported version {}' .format (version_id ), error )
90
89
header = {'version' : version , 'type' : message_type }
91
90
92
- algorithm_id , message_id , ser_encryption_context_length = unpack_values ('>H16sH' , stream )
91
+ algorithm_id , message_id , ser_encryption_context_length = unpack_values ('>H16sH' , tee_stream )
93
92
94
93
try :
95
94
alg = Algorithm .get_by_id (algorithm_id )
@@ -101,24 +100,24 @@ def deserialize_header(stream):
101
100
header ['message_id' ] = message_id
102
101
103
102
header ['encryption_context' ] = deserialize_encryption_context (
104
- stream .read (ser_encryption_context_length )
103
+ tee_stream .read (ser_encryption_context_length )
105
104
)
106
- (encrypted_data_key_count ,) = unpack_values ('>H' , stream )
105
+ (encrypted_data_key_count ,) = unpack_values ('>H' , tee_stream )
107
106
108
107
encrypted_data_keys = set ([])
109
108
for _ in range (encrypted_data_key_count ):
110
- (key_provider_length ,) = unpack_values ('>H' , stream )
109
+ (key_provider_length ,) = unpack_values ('>H' , tee_stream )
111
110
(key_provider_identifier ,) = unpack_values (
112
111
'>{}s' .format (key_provider_length ),
113
- stream
112
+ tee_stream
114
113
)
115
- (key_provider_information_length ,) = unpack_values ('>H' , stream )
114
+ (key_provider_information_length ,) = unpack_values ('>H' , tee_stream )
116
115
(key_provider_information ,) = unpack_values (
117
116
'>{}s' .format (key_provider_information_length ),
118
- stream
117
+ tee_stream
119
118
)
120
- (encrypted_data_key_length ,) = unpack_values ('>H' , stream )
121
- encrypted_data_key = stream .read (encrypted_data_key_length )
119
+ (encrypted_data_key_length ,) = unpack_values ('>H' , tee_stream )
120
+ encrypted_data_key = tee_stream .read (encrypted_data_key_length )
122
121
encrypted_data_keys .add (EncryptedDataKey (
123
122
key_provider = MasterKeyInfo (
124
123
provider_id = to_str (key_provider_identifier ),
@@ -128,7 +127,7 @@ def deserialize_header(stream):
128
127
))
129
128
header ['encrypted_data_keys' ] = encrypted_data_keys
130
129
131
- (content_type_id ,) = unpack_values ('>B' , stream )
130
+ (content_type_id ,) = unpack_values ('>B' , tee_stream )
132
131
try :
133
132
content_type = ContentType (content_type_id )
134
133
except ValueError as error :
@@ -138,14 +137,14 @@ def deserialize_header(stream):
138
137
)
139
138
header ['content_type' ] = content_type
140
139
141
- (content_aad_length ,) = unpack_values ('>I' , stream )
140
+ (content_aad_length ,) = unpack_values ('>I' , tee_stream )
142
141
if content_aad_length != 0 :
143
142
raise SerializationError (
144
143
'Content AAD length field is currently unused, its value must be always 0'
145
144
)
146
145
header ['content_aad_length' ] = 0
147
146
148
- (iv_length ,) = unpack_values ('>B' , stream )
147
+ (iv_length ,) = unpack_values ('>B' , tee_stream )
149
148
if iv_length != alg .iv_len :
150
149
raise SerializationError (
151
150
'Specified IV length ({length}) does not match algorithm IV length ({alg})' .format (
@@ -155,7 +154,7 @@ def deserialize_header(stream):
155
154
)
156
155
header ['header_iv_length' ] = iv_length
157
156
158
- (frame_length ,) = unpack_values ('>I' , stream )
157
+ (frame_length ,) = unpack_values ('>I' , tee_stream )
159
158
if content_type == ContentType .FRAMED_DATA and frame_length > MAX_FRAME_SIZE :
160
159
raise SerializationError ('Specified frame length larger than allowed maximum: {found} > {max}' .format (
161
160
found = frame_length ,
@@ -165,7 +164,7 @@ def deserialize_header(stream):
165
164
raise SerializationError ('Non-zero frame length found for non-framed message' )
166
165
header ['frame_length' ] = frame_length
167
166
168
- return MessageHeader (** header )
167
+ return MessageHeader (** header ), tee . getvalue ()
169
168
170
169
171
170
def deserialize_header_auth (stream , algorithm , verifier = None ):
0 commit comments