Skip to content

Commit 6fa7f91

Browse files
committed
Preallocate the chunk size rather than buffering
Since the chunk size is capped at 4MB now, we can safely preallocate it so that we don't have to buffer each chunk.
1 parent add07bd commit 6fa7f91

File tree

4 files changed

+29
-41
lines changed

4 files changed

+29
-41
lines changed

openpgp/packet/aead_crypter.go

+17-37
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ func (wo *aeadCrypter) incrementIndex() error {
6262
type aeadDecrypter struct {
6363
aeadCrypter // Embedded ciphertext opener
6464
reader io.Reader // 'reader' is a partialLengthReader
65+
chunkBytes []byte
6566
peekedBytes []byte // Used to detect last chunk
66-
eof bool
6767
}
6868

6969
// Read decrypts bytes and reads them into dst. It decrypts when necessary and
@@ -75,22 +75,18 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
7575
return ar.buffer.Read(dst)
7676
}
7777

78-
// Return EOF if we've previously validated the final tag
79-
if ar.eof {
80-
return 0, io.EOF
81-
}
82-
8378
// Read a chunk
8479
tagLen := ar.aead.Overhead()
85-
cipherChunkBuf := new(bytes.Buffer)
86-
_, errRead := io.CopyN(cipherChunkBuf, ar.reader, int64(ar.chunkSize+tagLen))
87-
cipherChunk := cipherChunkBuf.Bytes()
88-
if errRead != nil && errRead != io.EOF {
80+
copy(ar.chunkBytes, ar.peekedBytes) // Copy bytes peeked in previous chunk or in initialization
81+
bytesRead, errRead := io.ReadFull(ar.reader, ar.chunkBytes[tagLen:])
82+
if errRead != nil && errRead != io.EOF && errRead != io.ErrUnexpectedEOF {
8983
return 0, errRead
9084
}
9185

92-
if len(cipherChunk) > 0 {
93-
decrypted, errChunk := ar.openChunk(cipherChunk)
86+
if bytesRead > 0 {
87+
ar.peekedBytes = ar.chunkBytes[bytesRead:bytesRead+tagLen]
88+
89+
decrypted, errChunk := ar.openChunk(ar.chunkBytes[:bytesRead])
9490
if errChunk != nil {
9591
return 0, errChunk
9692
}
@@ -102,28 +98,19 @@ func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
10298
} else {
10399
n = copy(dst, decrypted)
104100
}
101+
return
105102
}
106103

107-
// Check final authentication tag
108-
if errRead == io.EOF {
109-
errChunk := ar.validateFinalTag(ar.peekedBytes)
110-
if errChunk != nil {
111-
return n, errChunk
112-
}
113-
ar.eof = true // Mark EOF for when we've returned all buffered data
114-
}
115-
return
104+
return 0, io.EOF
116105
}
117106

118-
// Close is noOp. The final authentication tag of the stream was already
119-
// checked in the last Read call. In the future, this function could be used to
120-
// wipe the reader and peeked, decrypted bytes, if necessary.
107+
// Close checks the final authentication tag of the stream.
108+
// In the future, this function could also be used to wipe the reader
109+
// and peeked & decrypted bytes, if necessary.
121110
func (ar *aeadDecrypter) Close() (err error) {
122-
if !ar.eof {
123-
errChunk := ar.validateFinalTag(ar.peekedBytes)
124-
if errChunk != nil {
125-
return errChunk
126-
}
111+
errChunk := ar.validateFinalTag(ar.peekedBytes)
112+
if errChunk != nil {
113+
return errChunk
127114
}
128115
return nil
129116
}
@@ -132,20 +119,13 @@ func (ar *aeadDecrypter) Close() (err error) {
132119
// the underlying plaintext and an error. It accesses peeked bytes from next
133120
// chunk, to identify the last chunk and decrypt/validate accordingly.
134121
func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) {
135-
tagLen := ar.aead.Overhead()
136-
// Restore carried bytes from last call
137-
chunkExtra := append(ar.peekedBytes, data...)
138-
// 'chunk' contains encrypted bytes, followed by an authentication tag.
139-
chunk := chunkExtra[:len(chunkExtra)-tagLen]
140-
ar.peekedBytes = chunkExtra[len(chunkExtra)-tagLen:]
141-
142122
adata := ar.associatedData
143123
if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted {
144124
adata = append(ar.associatedData, ar.chunkIndex...)
145125
}
146126

147127
nonce := ar.computeNextNonce()
148-
plainChunk, err := ar.aead.Open(nil, nonce, chunk, adata)
128+
plainChunk, err := ar.aead.Open(nil, nonce, data, adata)
149129
if err != nil {
150130
return nil, errors.ErrAEADTagVerification
151131
}

openpgp/packet/aead_encrypted.go

+7-3
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,15 @@ func (ae *AEADEncrypted) decrypt(key []byte) (io.ReadCloser, error) {
6565
blockCipher := ae.cipher.new(key)
6666
aead := ae.mode.new(blockCipher)
6767
// Carry the first tagLen bytes
68+
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)
6869
tagLen := ae.mode.TagLength()
69-
peekedBytes := make([]byte, tagLen)
70+
chunkBytes := make([]byte, chunkSize+tagLen*2)
71+
peekedBytes := chunkBytes[chunkSize+tagLen:]
7072
n, err := io.ReadFull(ae.Contents, peekedBytes)
7173
if n < tagLen || (err != nil && err != io.EOF) {
7274
return nil, errors.AEADError("Not enough data to decrypt:" + err.Error())
7375
}
74-
chunkSize := decodeAEADChunkSize(ae.chunkSizeByte)
76+
7577
return &aeadDecrypter{
7678
aeadCrypter: aeadCrypter{
7779
aead: aead,
@@ -82,7 +84,9 @@ func (ae *AEADEncrypted) decrypt(key []byte) (io.ReadCloser, error) {
8284
packetTag: packetTypeAEADEncrypted,
8385
},
8486
reader: ae.Contents,
85-
peekedBytes: peekedBytes}, nil
87+
chunkBytes: chunkBytes,
88+
peekedBytes: peekedBytes,
89+
}, nil
8690
}
8791

8892
// associatedData for chunks: tag, version, cipher, mode, chunk size byte

openpgp/packet/aead_encrypted_test.go

+1
Original file line numberDiff line numberDiff line change
@@ -407,6 +407,7 @@ func readDecryptedStream(rc io.ReadCloser) (got []byte, err error) {
407407
}
408408
}
409409
}
410+
err = rc.Close()
410411
return got, err
411412
}
412413

openpgp/packet/symmetrically_encrypted_aead.go

+4-1
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,10 @@ func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, e
7070

7171
aead, nonce := getSymmetricallyEncryptedAeadInstance(se.Cipher, se.Mode, inputKey, se.Salt[:], se.associatedData())
7272
// Carry the first tagLen bytes
73+
chunkSize := decodeAEADChunkSize(se.ChunkSizeByte)
7374
tagLen := se.Mode.TagLength()
74-
peekedBytes := make([]byte, tagLen)
75+
chunkBytes := make([]byte, chunkSize+tagLen*2)
76+
peekedBytes := chunkBytes[chunkSize+tagLen:]
7577
n, err := io.ReadFull(se.Contents, peekedBytes)
7678
if n < tagLen || (err != nil && err != io.EOF) {
7779
return nil, errors.StructuralError("not enough data to decrypt:" + err.Error())
@@ -87,6 +89,7 @@ func (se *SymmetricallyEncrypted) decryptAead(inputKey []byte) (io.ReadCloser, e
8789
packetTag: packetTypeSymmetricallyEncryptedIntegrityProtected,
8890
},
8991
reader: se.Contents,
92+
chunkBytes: chunkBytes,
9093
peekedBytes: peekedBytes,
9194
}, nil
9295
}

0 commit comments

Comments
 (0)