Skip to content

Commit

Permalink
Merge pull request #259 from ProtonMail/less-memory-large-msgs
Browse files Browse the repository at this point in the history
Reduce memory usage when AEAD en/decrypting large messages
  • Loading branch information
twiss authored Dec 16, 2024
2 parents b01f065 + 1fd5ec8 commit be3aef0
Show file tree
Hide file tree
Showing 7 changed files with 217 additions and 127 deletions.
8 changes: 4 additions & 4 deletions internal/byteutil/byteutil.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,16 +49,16 @@ func ShiftNBytesLeft(dst, x []byte, n int) {
dst = append(dst, make([]byte, n/8)...)
}

// XorBytesMut assumes equal input length, replaces X with X XOR Y
// XorBytesMut replaces X with X XOR Y. len(X) must be >= len(Y).
func XorBytesMut(X, Y []byte) {
for i := 0; i < len(X); i++ {
for i := 0; i < len(Y); i++ {
X[i] ^= Y[i]
}
}

// XorBytes assumes equal input length, puts X XOR Y into Z
// XorBytes puts X XOR Y into Z. len(Z) and len(X) must be >= len(Y).
func XorBytes(Z, X, Y []byte) {
for i := 0; i < len(X); i++ {
for i := 0; i < len(Y); i++ {
Z[i] = X[i] ^ Y[i]
}
}
Expand Down
55 changes: 25 additions & 30 deletions ocb/ocb.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,10 @@ func (o *ocb) Seal(dst, nonce, plaintext, adata []byte) []byte {
if len(nonce) > o.nonceSize {
panic("crypto/ocb: Incorrect nonce length given to OCB")
}
ret, out := byteutil.SliceForAppend(dst, len(plaintext)+o.tagSize)
o.crypt(enc, out, nonce, adata, plaintext)
sep := len(plaintext)
ret, out := byteutil.SliceForAppend(dst, sep+o.tagSize)
tag := o.crypt(enc, out[:sep], nonce, adata, plaintext)
copy(out[sep:], tag)
return ret
}

Expand All @@ -122,12 +124,10 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
return nil, ocbError("Ciphertext shorter than tag length")
}
sep := len(ciphertext) - o.tagSize
ret, out := byteutil.SliceForAppend(dst, len(ciphertext))
ret, out := byteutil.SliceForAppend(dst, sep)
ciphertextData := ciphertext[:sep]
tag := ciphertext[sep:]
o.crypt(dec, out, nonce, adata, ciphertextData)
if subtle.ConstantTimeCompare(ret[sep:], tag) == 1 {
ret = ret[:sep]
tag := o.crypt(dec, out, nonce, adata, ciphertextData)
if subtle.ConstantTimeCompare(tag, ciphertext[sep:]) == 1 {
return ret, nil
}
for i := range out {
Expand All @@ -137,7 +137,8 @@ func (o *ocb) Open(dst, nonce, ciphertext, adata []byte) ([]byte, error) {
}

// On instruction enc (resp. dec), crypt is the encrypt (resp. decrypt)
// function. It returns the resulting plain/ciphertext with the tag appended.
// function. It writes the resulting plain/ciphertext into Y and returns
// the tag.
func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
//
// Consider X as a sequence of 128-bit blocks
Expand Down Expand Up @@ -194,13 +195,14 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
byteutil.XorBytesMut(offset, o.mask.L[bits.TrailingZeros(uint(i+1))])
blockX := X[i*blockSize : (i+1)*blockSize]
blockY := Y[i*blockSize : (i+1)*blockSize]
byteutil.XorBytes(blockY, blockX, offset)
switch instruction {
case enc:
byteutil.XorBytesMut(checksum, blockX)
byteutil.XorBytes(blockY, blockX, offset)
o.block.Encrypt(blockY, blockY)
byteutil.XorBytesMut(blockY, offset)
byteutil.XorBytesMut(checksum, blockX)
case dec:
byteutil.XorBytes(blockY, blockX, offset)
o.block.Decrypt(blockY, blockY)
byteutil.XorBytesMut(blockY, offset)
byteutil.XorBytesMut(checksum, blockY)
Expand All @@ -216,31 +218,24 @@ func (o *ocb) crypt(instruction int, Y, nonce, adata, X []byte) []byte {
o.block.Encrypt(pad, offset)
chunkX := X[blockSize*m:]
chunkY := Y[blockSize*m : len(X)]
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
// P_* || bit(1) || zeroes(127) - len(P_*)
switch instruction {
case enc:
paddedY := append(chunkX, byte(128))
paddedY = append(paddedY, make([]byte, blockSize-len(chunkX)-1)...)
byteutil.XorBytesMut(checksum, paddedY)
byteutil.XorBytesMut(checksum, chunkX)
checksum[len(chunkX)] ^= 128
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
// P_* || bit(1) || zeroes(127) - len(P_*)
case dec:
paddedX := append(chunkY, byte(128))
paddedX = append(paddedX, make([]byte, blockSize-len(chunkY)-1)...)
byteutil.XorBytesMut(checksum, paddedX)
byteutil.XorBytes(chunkY, chunkX, pad[:len(chunkX)])
// P_* || bit(1) || zeroes(127) - len(P_*)
byteutil.XorBytesMut(checksum, chunkY)
checksum[len(chunkY)] ^= 128
}
byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
copy(Y[blockSize*m+len(chunkY):], tag[:o.tagSize])
} else {
byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
copy(Y[blockSize*m:], tag[:o.tagSize])
}
return Y
byteutil.XorBytes(tag, checksum, offset)
byteutil.XorBytesMut(tag, o.mask.lDol)
o.block.Encrypt(tag, tag)
byteutil.XorBytesMut(tag, o.hash(adata))
return tag[:o.tagSize]
}

// This hash function is used to compute the tag. Per design, on empty input it
Expand Down
114 changes: 108 additions & 6 deletions ocb/ocb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,20 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
adata, _ := hex.DecodeString(test.header)
targetPt, _ := hex.DecodeString(test.plaintext)
targetCt, _ := hex.DecodeString(test.ciphertext)
ct := ocbInstance.Seal(nil, nonce, targetPt, adata)
// Encrypt
ct := ocbInstance.Seal(nil, nonce, targetPt, adata)
if !bytes.Equal(ct, targetCt) {
t.Errorf(
`RFC7253 Test vectors Encrypt error (ciphertexts don't match):
Got:
%X
Want:
%X`, ct, targetCt)
}
// Encrypt reusing buffer
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
copy(pt, targetPt)
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
if !bytes.Equal(ct, targetCt) {
t.Errorf(
`RFC7253 Test vectors Encrypt error (ciphertexts don't match):
Expand All @@ -138,14 +150,14 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
%X`, ct, targetCt)
}
// Decrypt
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
pt, err := ocbInstance.Open(nil, nonce, ct, adata)
if err != nil {
t.Errorf(
`RFC7253 Valid ciphertext was refused decryption:
plaintext %X
nonce %X
header %X
ciphertext %X`, targetPt, nonce, adata, targetCt)
ciphertext %X`, targetPt, nonce, adata, ct)
}
if !bytes.Equal(pt, targetPt) {
t.Errorf(
Expand All @@ -155,6 +167,24 @@ func TestEncryptDecryptRFC7253TestVectors(t *testing.T) {
Want:
%X`, pt, targetPt)
}
// Decrypt reusing buffer
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
if err != nil {
t.Errorf(
`RFC7253 Valid ciphertext was refused decryption:
plaintext %X
nonce %X
header %X
ciphertext %X`, targetPt, nonce, adata, ct)
}
if !bytes.Equal(pt, targetPt) {
t.Errorf(
`RFC7253 test vectors Decrypt error (plaintexts don't match):
Got:
%X
Want:
%X`, targetPt, pt)
}
}
}

Expand Down Expand Up @@ -182,7 +212,30 @@ func TestEncryptDecryptRFC7253TagLen96(t *testing.T) {
Want:
%X`, ct, targetCt)
}
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
copy(pt, targetPt)
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
if !bytes.Equal(ct, targetCt) {
t.Errorf(
`RFC7253 test tagLen96 error (ciphertexts don't match):
Got:
%X
Want:
%X`, ct, targetCt)
}
pt, err = ocbInstance.Open(nil, nonce, ct, adata)
if err != nil {
t.Errorf(`RFC7253 test tagLen96 was refused decryption`)
}
if !bytes.Equal(pt, targetPt) {
t.Errorf(
`RFC7253 test tagLen96 error (plaintexts don't match):
Got:
%X
Want:
%X`, pt, targetPt)
}
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
if err != nil {
t.Errorf(`RFC7253 test tagLen96 was refused decryption`)
}
Expand Down Expand Up @@ -274,15 +327,47 @@ func TestEncryptDecryptGoTestVectors(t *testing.T) {
%X`, ct, targetCt)
}

// Encrypt reusing buffer
pt := make([]byte, len(targetPt) + ocbInstance.Overhead())
copy(pt, targetPt)
ct = ocbInstance.Seal(pt[:0], nonce, pt[:len(targetPt)], adata)
if !bytes.Equal(ct, targetCt) {
t.Errorf(
`Go Test vectors Encrypt error (ciphertexts don't match):
Got:
%X
Want:
%X`, ct, targetCt)
}

// Decrypt
pt, err := ocbInstance.Open(nil, nonce, targetCt, adata)
pt, err = ocbInstance.Open(nil, nonce, ct, adata)
if err != nil {
t.Errorf(
`Valid Go ciphertext was refused decryption:
plaintext %X
nonce %X
header %X
ciphertext %X`, targetPt, nonce, adata, targetCt)
ciphertext %X`, targetPt, nonce, adata, ct)
}
if !bytes.Equal(pt, targetPt) {
t.Errorf(
`Go Test vectors Decrypt error (plaintexts don't match):
Got:
%X
Want:
%X`, pt, targetPt)
}

// Decrypt reusing buffer
pt, err = ocbInstance.Open(ct[:0], nonce, ct, adata)
if err != nil {
t.Errorf(
`Valid Go ciphertext was refused decryption:
plaintext %X
nonce %X
header %X
ciphertext %X`, targetPt, nonce, adata, ct)
}
if !bytes.Equal(pt, targetPt) {
t.Errorf(
Expand Down Expand Up @@ -333,6 +418,17 @@ func TestEncryptDecryptVectorsWithPreviousDataRandomizeSlow(t *testing.T) {
`Random Encrypt/Decrypt error (plaintexts don't match)`)
break
}
decrypted, err = ocb.Open(ct[:0], nonce, ct, header)
if err != nil {
t.Errorf(
`Decrypt refused valid tag (not displaying long output)`)
break
}
if !bytes.Equal(pt, decrypted) {
t.Errorf(
`Random Encrypt/Decrypt error (plaintexts don't match)`)
break
}
}
}

Expand Down Expand Up @@ -369,6 +465,12 @@ func TestRejectTamperedCiphertextRandomizeSlow(t *testing.T) {
"Tampered ciphertext was not refused decryption (OCB did not return an error)")
return
}
_, err = ocb.Open(tampered[:0], nonce, tampered, header)
if err == nil {
t.Errorf(
"Tampered ciphertext was not refused decryption (OCB did not return an error)")
return
}
}

func TestParameters(t *testing.T) {
Expand Down
Loading

0 comments on commit be3aef0

Please sign in to comment.