Skip to content

Commit be3aef0

Browse files
authored
Merge pull request #259 from ProtonMail/less-memory-large-msgs
Reduce memory usage when AEAD en/decrypting large messages
2 parents b01f065 + 1fd5ec8 commit be3aef0

File tree

7 files changed

+217
-127
lines changed

7 files changed

+217
-127
lines changed

internal/byteutil/byteutil.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,16 +49,16 @@ func ShiftNBytesLeft(dst, x []byte, n int) {
4949
dst = append(dst, make([]byte, n/8)...)
5050
}
5151

52-
// XorBytesMut assumes equal input length, replaces X with X XOR Y
52+
// XorBytesMut replaces X with X XOR Y. len(X) must be >= len(Y).
5353
func XorBytesMut(X, Y []byte) {
54-
for i := 0; i < len(X); i++ {
54+
for i := 0; i < len(Y); i++ {
5555
X[i] ^= Y[i]
5656
}
5757
}
5858

59-
// XorBytes assumes equal input length, puts X XOR Y into Z
59+
// XorBytes puts X XOR Y into Z. len(Z) and len(X) must be >= len(Y).
6060
func XorBytes(Z, X, Y []byte) {
61-
for i := 0; i < len(X); i++ {
61+
for i := 0; i < len(Y); i++ {
6262
Z[i] = X[i] ^ Y[i]
6363
}
6464
}

ocb/ocb.go

Lines changed: 25 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,10 @@ func (o *ocb) Seal(dst, nonce, plaintext, adata []byte) []byte {
109109
if len(nonce) > o.nonceSize {
110110
panic("crypto/ocb: Incorrect nonce length given to OCB")
111111
}
112-
ret, out := byteutil.SliceForAppend(dst, len(plaintext)+o.tagSize)
113-
o.crypt(enc, out, nonce, adata, plaintext)
112+
sep := len(plaintext)
113+
ret, out := byteutil.SliceForAppend(dst, sep+o.tagSize)
114+
tag := o.crypt(enc, out[:sep], nonce, adata, plaintext)
115+
copy(out[sep:], tag)
114116
return ret
115117
}
116118

@@ -122,12 +124,10 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
122124
return nil, ocbError("Ciphertext shorter than tag length")
123125
}
124126
sep := len(ciphertext) - o.tagSize
125-
ret, out := byteutil.SliceForAppend(dst, len(ciphertext))
127+
ret, out := byteutil.SliceForAppend(dst, sep)
126128
ciphertextData := ciphertext[:sep]
127-
tag := ciphertext[sep:]
128-
o.crypt(dec, out, nonce, adata, ciphertextData)
129-
if subtle.ConstantTimeCompare(ret[sep:], tag) == 1 {
130-
ret = ret[:sep]
129+
tag := o.crypt(dec, out, nonce, adata, ciphertextData)
130+
if subtle.ConstantTimeCompare(tag, ciphertext[sep:]) == 1 {
131131
return ret, nil
132132
}
133133
for i := range out {
@@ -137,7 +137,8 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
137137
}
138138

139139
// On instruction enc (resp. dec), crypt is the encrypt (resp. decrypt)
140-
// function. It returns the resulting plain/ciphertext with the tag appended.
140+
// function. It writes the resulting plain/ciphertext into Y and returns
141+
// the tag.
141142
func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
142143
//
143144
// Consider X as a sequence of 128-bit blocks
@@ -194,13 +195,14 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
194195
byteutil.XorBytesMut(offset, o.mask.L[bits.TrailingZeros(uint(i+1))])
195196
blockX := X[i*blockSize : (i+1)*blockSize]
196197
blockY := Y[i*blockSize : (i+1)*blockSize]
197-
byteutil.XorBytes(blockY, blockX, offset)
198198
switch instruction {
199199
case enc:
200+
byteutil.XorBytesMut(checksum, blockX)
201+
byteutil.XorBytes(blockY, blockX, offset)
200202
o.block.Encrypt(blockY, blockY)
201203
byteutil.XorBytesMut(blockY, offset)
202-
byteutil.XorBytesMut(checksum, blockX)
203204
case dec:
205+
byteutil.XorBytes(blockY, blockX, offset)
204206
o.block.Decrypt(blockY, blockY)
205207
byteutil.XorBytesMut(blockY, offset)
206208
byteutil.XorBytesMut(checksum, blockY)
@@ -216,31 +218,24 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
216218
o.block.Encrypt(pad, offset)
217219
chunkX := X[blockSize*m:]
218220
chunkY := Y[blockSize*m : len(X)]
219-
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
220-
// P_* || bit(1) || zeroes(127) - len(P_*)
221221
switch instruction {
222222
case enc:
223-
paddedY := append(chunkX, byte(128))
224-
paddedY = append(paddedY, make([]byte, blockSize-len(chunkX)-1)...)
225-
byteutil.XorBytesMut(checksum, paddedY)
223+
byteutil.XorBytesMut(checksum, chunkX)
224+
checksum[len(chunkX)] ^= 128
225+
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
226+
// P_* || bit(1) || zeroes(127) - len(P_*)
226227
case dec:
227-
paddedX := append(chunkY, byte(128))
228-
paddedX = append(paddedX, make([]byte, blockSize-len(chunkY)-1)...)
229-
byteutil.XorBytesMut(checksum, paddedX)
228+
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
229+
// P_* || bit(1) || zeroes(127) - len(P_*)
230+
byteutil.XorBytesMut(checksum, chunkY)
231+
checksum[len(chunkY)] ^= 128
230232
}
231-
byteutil.XorBytes(tag, checksum, offset)
232-
byteutil.XorBytesMut(tag, o.mask.lDol)
233-
o.block.Encrypt(tag, tag)
234-
byteutil.XorBytesMut(tag, o.hash(adata))
235-
copy(Y[blockSize*m+len(chunkY):], tag[:o.tagSize])
236-
} else {
237-
byteutil.XorBytes(tag, checksum, offset)
238-
byteutil.XorBytesMut(tag, o.mask.lDol)
239-
o.block.Encrypt(tag, tag)
240-
byteutil.XorBytesMut(tag, o.hash(adata))
241-
copy(Y[blockSize*m:], tag[:o.tagSize])
242233
}
243-
return Y
234+
byteutil.XorBytes(tag, checksum, offset)
235+
byteutil.XorBytesMut(tag, o.mask.lDol)
236+
o.block.Encrypt(tag, tag)
237+
byteutil.XorBytesMut(tag, o.hash(adata))
238+
return tag[:o.tagSize]
244239
}
245240

246241
// This hash function is used to compute the tag. Per design, on empty input it

ocb/ocb_test.go

Lines changed: 108 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,20 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
127127
adata, _ := hex.DecodeString(test.header)
128128
targetPt, _ := hex.DecodeString(test.plaintext)
129129
targetCt, _ := hex.DecodeString(test.ciphertext)
130-
ct := ocbInstance.Seal(nil, nonce, targetPt, adata)
131130
// Encrypt
131+
ct := ocbInstance.Seal(nil, nonce, targetPt, adata)
132+
if !bytes.Equal(ct, targetCt) {
133+
t.Errorf(
134+
`RFC7253 Test vectors Encrypt error (ciphertexts don't match):
135+
Got:
136+
%X
137+
Want:
138+
%X`, ct, targetCt)
139+
}
140+
// Encrypt reusing buffer
141+
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
142+
copy(pt, targetPt)
143+
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
132144
if !bytes.Equal(ct, targetCt) {
133145
t.Errorf(
134146
`RFC7253 Test vectors Encrypt error (ciphertexts don't match):
@@ -138,14 +150,14 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
138150
%X`, ct, targetCt)
139151
}
140152
// Decrypt
141-
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
153+
pt, err := ocbInstance.Open(nil, nonce, ct, adata)
142154
if err != nil {
143155
t.Errorf(
144156
`RFC7253 Valid ciphertext was refused decryption:
145157
plaintext %X
146158
nonce %X
147159
header %X
148-
ciphertext %X`, targetPt, nonce, adata, targetCt)
160+
ciphertext %X`, targetPt, nonce, adata, ct)
149161
}
150162
if !bytes.Equal(pt, targetPt) {
151163
t.Errorf(
@@ -155,6 +167,24 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
155167
Want:
156168
%X`, pt, targetPt)
157169
}
170+
// Decrypt reusing buffer
171+
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
172+
if err != nil {
173+
t.Errorf(
174+
`RFC7253 Valid ciphertext was refused decryption:
175+
plaintext %X
176+
nonce %X
177+
header %X
178+
ciphertext %X`, targetPt, nonce, adata, ct)
179+
}
180+
if !bytes.Equal(pt, targetPt) {
181+
t.Errorf(
182+
`RFC7253 test vectors Decrypt error (plaintexts don't match):
183+
Got:
184+
%X
185+
Want:
186+
%X`, targetPt, pt)
187+
}
158188
}
159189
}
160190

@@ -182,7 +212,30 @@ func TestEncryptDecryptRFC7253TagLen96(t *testing.T) {
182212
Want:
183213
%X`, ct, targetCt)
184214
}
185-
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
215+
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
216+
copy(pt, targetPt)
217+
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
218+
if !bytes.Equal(ct, targetCt) {
219+
t.Errorf(
220+
`RFC7253 test tagLen96 error (ciphertexts don't match):
221+
Got:
222+
%X
223+
Want:
224+
%X`, ct, targetCt)
225+
}
226+
pt, err = ocbInstance.Open(nil, nonce, ct, adata)
227+
if err != nil {
228+
t.Errorf(`RFC7253 test tagLen96 was refused decryption`)
229+
}
230+
if !bytes.Equal(pt, targetPt) {
231+
t.Errorf(
232+
`RFC7253 test tagLen96 error (plaintexts don't match):
233+
Got:
234+
%X
235+
Want:
236+
%X`, pt, targetPt)
237+
}
238+
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
186239
if err != nil {
187240
t.Errorf(`RFC7253 test tagLen96 was refused decryption`)
188241
}
@@ -274,15 +327,47 @@ func TestEncryptDecryptGoTestVectors(t *testing.T) {
274327
%X`, ct, targetCt)
275328
}
276329

330+
// Encrypt reusing buffer
331+
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
332+
copy(pt, targetPt)
333+
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
334+
if !bytes.Equal(ct, targetCt) {
335+
t.Errorf(
336+
`Go Test vectors Encrypt error (ciphertexts don't match):
337+
Got:
338+
%X
339+
Want:
340+
%X`, ct, targetCt)
341+
}
342+
277343
// Decrypt
278-
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
344+
pt, err = ocbInstance.Open(nil, nonce, ct, adata)
279345
if err != nil {
280346
t.Errorf(
281347
`Valid Go ciphertext was refused decryption:
282348
plaintext %X
283349
nonce %X
284350
header %X
285-
ciphertext %X`, targetPt, nonce, adata, targetCt)
351+
ciphertext %X`, targetPt, nonce, adata, ct)
352+
}
353+
if !bytes.Equal(pt, targetPt) {
354+
t.Errorf(
355+
`Go Test vectors Decrypt error (plaintexts don't match):
356+
Got:
357+
%X
358+
Want:
359+
%X`, pt, targetPt)
360+
}
361+
362+
// Decrypt reusing buffer
363+
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
364+
if err != nil {
365+
t.Errorf(
366+
`Valid Go ciphertext was refused decryption:
367+
plaintext %X
368+
nonce %X
369+
header %X
370+
ciphertext %X`, targetPt, nonce, adata, ct)
286371
}
287372
if !bytes.Equal(pt, targetPt) {
288373
t.Errorf(
@@ -333,6 +418,17 @@ func TestEncryptDecryptVectorsWithPreviousDataRandomizeSlow(t *testing.T) {
333418
`Random Encrypt/Decrypt error (plaintexts don't match)`)
334419
break
335420
}
421+
decrypted, err = ocb.Open(ct[:0], nonce, ct, header)
422+
if err != nil {
423+
t.Errorf(
424+
`Decrypt refused valid tag (not displaying long output)`)
425+
break
426+
}
427+
if !bytes.Equal(pt, decrypted) {
428+
t.Errorf(
429+
`Random Encrypt/Decrypt error (plaintexts don't match)`)
430+
break
431+
}
336432
}
337433
}
338434

@@ -369,6 +465,12 @@ func TestRejectTamperedCiphertextRandomizeSlow(t *testing.T) {
369465
"Tampered ciphertext was not refused decryption (OCB did not return an error)")
370466
return
371467
}
468+
_, err = ocb.Open(tampered[:0], nonce, tampered, header)
469+
if err == nil {
470+
t.Errorf(
471+
"Tampered ciphertext was not refused decryption (OCB did not return an error)")
472+
return
473+
}
372474
}
373475

374476
func TestParameters(t *testing.T) {

0 commit comments

Comments
 (0)