Merge pull request #5622 from Roasbeef/wire-serdes-optimization

peer+brontide: when decrypting re-use the allocated ciphertext buf & ensure buf pool buf doesn't escape
This commit is contained in:
Olaoluwa Osuntokun
2021-08-30 11:45:57 -07:00
committed by GitHub
4 changed files with 110 additions and 18 deletions

66
brontide/bench_test.go Normal file
View File

@@ -0,0 +1,66 @@
package brontide
import (
"bytes"
"math"
"math/rand"
"testing"
"time"
"github.com/stretchr/testify/require"
)
func BenchmarkReadHeaderAndBody(t *testing.B) {
// Create a test connection, grabbing either side of the connection
// into local variables. If the initial crypto handshake fails, then
// we'll get a non-nil error here.
localConn, remoteConn, cleanUp, err := establishTestConnection()
require.NoError(t, err, "unable to establish test connection: %v", err)
defer cleanUp()
rand.Seed(time.Now().Unix())
noiseRemoteConn := remoteConn.(*Conn)
noiseLocalConn := localConn.(*Conn)
// Now that we have a local and remote side (to set up the initial
// handshake state, we'll have the remote side write out something
// similar to a large message in the protocol.
const pktSize = 60_000
msg := bytes.Repeat([]byte("a"), pktSize)
err = noiseRemoteConn.WriteMessage(msg)
require.NoError(t, err, "unable to write encrypted message: %v", err)
cipherHeader := noiseRemoteConn.noise.nextHeaderSend
cipherMsg := noiseRemoteConn.noise.nextBodySend
var (
benchErr error
msgBuf [math.MaxUint16]byte
)
t.ReportAllocs()
t.ResetTimer()
nonceValue := noiseLocalConn.noise.recvCipher.nonce
for i := 0; i < t.N; i++ {
pktLen, benchErr := noiseLocalConn.noise.ReadHeader(
bytes.NewReader(cipherHeader),
)
require.NoError(
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
)
_, benchErr = noiseLocalConn.noise.ReadBody(
bytes.NewReader(cipherMsg), msgBuf[:pktLen],
)
require.NoError(
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
)
// We reset the internal nonce each time as otherwise, we'd
// continue to increment it which would cause a decryption
// failure.
noiseLocalConn.noise.recvCipher.nonce = nonceValue
}
require.NoError(t, benchErr)
}

View File

@@ -854,8 +854,11 @@ func (b *Machine) ReadHeader(r io.Reader) (uint32, error) {
} }
// Attempt to decrypt+auth the packet length present in the stream. // Attempt to decrypt+auth the packet length present in the stream.
//
// By passing in `nextCipherHeader` as the destination, we avoid making
// the library allocate a new buffer to decode the plaintext.
pktLenBytes, err := b.recvCipher.Decrypt( pktLenBytes, err := b.recvCipher.Decrypt(
nil, nil, b.nextCipherHeader[:], nil, b.nextCipherHeader[:0], b.nextCipherHeader[:],
) )
if err != nil { if err != nil {
return 0, err return 0, err
@@ -880,10 +883,13 @@ func (b *Machine) ReadBody(r io.Reader, buf []byte) ([]byte, error) {
return nil, err return nil, err
} }
// Finally, decrypt the message held in the buffer, and return a // Finally, decrypt the message held in the buffer, and return a new
// new byte slice containing the plaintext. // byte slice containing the plaintext.
// TODO(roasbeef): modify to let pass in slice //
return b.recvCipher.Decrypt(nil, nil, buf) // By passing in the buf (the ciphertext) as the first argument, we end
// up re-using it as we don't force the library to allocate a new
// buffer to decode the plaintext.
return b.recvCipher.Decrypt(nil, buf[:0], buf)
} }
// SetCurveToNil sets the 'Curve' parameter to nil on the handshakeState keys. // SetCurveToNil sets the 'Curve' parameter to nil on the handshakeState keys.

View File

@@ -200,7 +200,13 @@ you.
transaction each time a private key needs to be derived for signing or ECDH transaction each time a private key needs to be derived for signing or ECDH
operations]https://github.com/lightningnetwork/lnd/pull/5629). This results operations]https://github.com/lightningnetwork/lnd/pull/5629). This results
in a massive performance improvement across several routine operations at the in a massive performance improvement across several routine operations at the
cost of a small amount of memory allocated for a new cache.
* [When decrypting incoming encrypted brontide messages on the wire, we'll now
properly re-use the buffer that was allocated for the ciphertext to store the
plaintext]https://github.com/lightningnetwork/lnd/pull/5622). When combined
with the buffer pool, this ensures that we no longer need to allocate a new
buffer each time we decrypt an incoming message, as we
recycle these buffers in the peer.
## Log system ## Log system

View File

@@ -951,7 +951,10 @@ func (p *Brontide) readNextMessage() (lnwire.Message, error) {
// reading incrementally from the stream as the Lightning wire protocol // reading incrementally from the stream as the Lightning wire protocol
// is message oriented and allows nodes to pad on additional data to // is message oriented and allows nodes to pad on additional data to
// the message stream. // the message stream.
var rawMsg []byte var (
nextMsg lnwire.Message
msgLen uint64
)
err = p.cfg.ReadPool.Submit(func(buf *buffer.Read) error { err = p.cfg.ReadPool.Submit(func(buf *buffer.Read) error {
// Before reading the body of the message, set the read timeout // Before reading the body of the message, set the read timeout
// accordingly to ensure we don't block other readers using the // accordingly to ensure we don't block other readers using the
@@ -964,18 +967,29 @@ func (p *Brontide) readNextMessage() (lnwire.Message, error) {
return readErr return readErr
} }
rawMsg, readErr = noiseConn.ReadNextBody(buf[:pktLen]) // The ReadNextBody method will actually end up re-using the
return readErr // buffer, so within this closure, we can continue to use
}) // rawMsg as it's just a slice into the buf from the buffer
atomic.AddUint64(&p.bytesReceived, uint64(len(rawMsg))) // pool.
if err != nil { rawMsg, readErr := noiseConn.ReadNextBody(buf[:pktLen])
return nil, err if readErr != nil {
} return readErr
}
msgLen = uint64(len(rawMsg))
// Next, create a new io.Reader implementation from the raw message, // Next, create a new io.Reader implementation from the raw
// and use this to decode the message directly from. // message, and use this to decode the message directly from.
msgReader := bytes.NewReader(rawMsg) msgReader := bytes.NewReader(rawMsg)
nextMsg, err := lnwire.ReadMessage(msgReader, 0) nextMsg, err = lnwire.ReadMessage(msgReader, 0)
if err != nil {
return err
}
// At this point, rawMsg and buf will be returned back to the
// buffer pool for re-use.
return nil
})
atomic.AddUint64(&p.bytesReceived, msgLen)
if err != nil { if err != nil {
return nil, err return nil, err
} }