brontide: revise

This commit is contained in:
Olaoluwa Osuntokun
2025-09-10 18:22:52 -07:00
parent a58e329b2c
commit 3615a351ac
3 changed files with 125 additions and 32 deletions

View File

@@ -31,8 +31,9 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
err = noiseRemoteConn.WriteMessage(msg)
require.NoError(t, err, "unable to write encrypted message: %v", err)
cipherHeader := noiseRemoteConn.noise.nextHeaderSend
cipherMsg := noiseRemoteConn.noise.nextBodySend
noise := noiseRemoteConn.noise.Load()
cipherHeader := noise.nextHeaderSend
cipherMsg := noise.nextBodySend
var (
benchErr error
@@ -42,15 +43,16 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
t.ReportAllocs()
t.ResetTimer()
nonceValue := noiseLocalConn.noise.recvCipher.nonce
localNoise := noiseLocalConn.noise.Load()
nonceValue := localNoise.recvCipher.nonce
for i := 0; i < t.N; i++ {
pktLen, benchErr := noiseLocalConn.noise.ReadHeader(
pktLen, benchErr := localNoise.ReadHeader(
bytes.NewReader(cipherHeader),
)
require.NoError(
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
)
_, benchErr = noiseLocalConn.noise.ReadBody(
_, benchErr = localNoise.ReadBody(
bytes.NewReader(cipherMsg), msgBuf[:pktLen],
)
require.NoError(
@@ -60,7 +62,7 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
// We reset the internal nonce each time as otherwise, we'd
// continue to increment it which would cause a decryption
// failure.
noiseLocalConn.noise.recvCipher.nonce = nonceValue
localNoise.recvCipher.nonce = nonceValue
}
require.NoError(t, benchErr)
}
@@ -87,15 +89,16 @@ func BenchmarkWriteMessage(b *testing.B) {
b.ReportAllocs()
b.ResetTimer()
noise := noiseLocalConn.noise.Load()
for i := 0; i < b.N; i++ {
// Write our massive message, then call flush to actually write
// the encrypted message This simulates a full write operation
// to a network.
err := noiseLocalConn.noise.WriteMessage(largeMsg)
err := noise.WriteMessage(largeMsg)
if err != nil {
b.Fatalf("WriteMessage failed: %v", err)
}
_, err = noiseLocalConn.noise.Flush(discard)
_, err = noise.Flush(discard)
if err != nil {
b.Fatalf("Flush failed: %v", err)
}

View File

@@ -2,9 +2,11 @@ package brontide
import (
"bytes"
"errors"
"io"
"math"
"net"
"sync/atomic"
"time"
"github.com/btcsuite/btcd/btcec/v2"
@@ -13,6 +15,9 @@ import (
"github.com/lightningnetwork/lnd/tor"
)
// ErrConnClosed is returned when operations are attempted on a closed connection.
var ErrConnClosed = errors.New("brontide: connection closed")
// Conn is an implementation of net.Conn which enforces an authenticated key
// exchange and message encryption protocol dubbed "Brontide" after initial TCP
// connection establishment. In the case of a successful handshake, all
@@ -22,9 +27,15 @@ import (
type Conn struct {
conn net.Conn
noise *Machine
// noise is stored as an atomic pointer to allow safe cleanup on Close()
// while preventing nil pointer dereferences from concurrent operations.
noise atomic.Pointer[Machine]
readBuf bytes.Buffer
// closed is an atomic flag that tracks whether Close() has been called.
// This prevents nil pointer dereferences if methods are called after Close().
closed atomic.Uint32
}
// A compile-time assertion to ensure that Conn meets the net.Conn interface.
@@ -46,12 +57,13 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
}
b := &Conn{
conn: conn,
noise: NewBrontideMachine(true, local, netAddr.IdentityKey),
conn: conn,
}
b.noise.Store(NewBrontideMachine(true, local, netAddr.IdentityKey))
// Initiate the handshake by sending the first act to the receiver.
actOne, err := b.noise.GenActOne()
noise := b.noise.Load()
actOne, err := noise.GenActOne()
if err != nil {
b.conn.Close()
return nil, err
@@ -79,14 +91,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
b.conn.Close()
return nil, err
}
if err := b.noise.RecvActTwo(actTwo); err != nil {
if err := noise.RecvActTwo(actTwo); err != nil {
b.conn.Close()
return nil, err
}
// Finally, complete the handshake by sending over our encrypted static
// key and execute the final ECDH operation.
actThree, err := b.noise.GenActThree()
actThree, err := noise.GenActThree()
if err != nil {
b.conn.Close()
return nil, err
@@ -116,7 +128,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
// appropriately, it is preferred that they use the split ReadNextHeader and
// ReadNextBody methods so that the deadlines can be set appropriately on each.
func (c *Conn) ReadNextMessage() ([]byte, error) {
return c.noise.ReadMessage(c.conn)
if c.closed.Load() == 1 {
return nil, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return nil, ErrConnClosed
}
return noise.ReadMessage(c.conn)
}
// ReadNextHeader uses the connection to read the next header from the brontide
@@ -124,7 +143,14 @@ func (c *Conn) ReadNextMessage() ([]byte, error) {
// return the packet length (including MAC overhead) that is expected from the
// subsequent call to ReadNextBody.
func (c *Conn) ReadNextHeader() (uint32, error) {
return c.noise.ReadHeader(c.conn)
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
return noise.ReadHeader(c.conn)
}
// ReadNextBody uses the connection to read the next message body from the
@@ -132,7 +158,14 @@ func (c *Conn) ReadNextHeader() (uint32, error) {
// and return the decrypted payload. The provided buffer MUST be the packet
// length returned by the preceding call to ReadNextHeader.
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
return c.noise.ReadBody(c.conn, buf)
if c.closed.Load() == 1 {
return nil, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return nil, ErrConnClosed
}
return noise.ReadBody(c.conn, buf)
}
// Read reads data from the connection. Read can be made to time out and
@@ -141,13 +174,21 @@ func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
//
// Part of the net.Conn interface.
func (c *Conn) Read(b []byte) (n int, err error) {
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
// In order to reconcile the differences between the record abstraction
// of our AEAD connection, and the stream abstraction of TCP, we
// maintain an intermediate read buffer. If this buffer becomes
// depleted, then we read the next record, and feed it into the
// buffer. Otherwise, we read directly from the buffer.
if c.readBuf.Len() == 0 {
plaintext, err := c.noise.ReadMessage(c.conn)
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
plaintext, err := noise.ReadMessage(c.conn)
if err != nil {
return 0, err
}
@@ -166,14 +207,22 @@ func (c *Conn) Read(b []byte) (n int, err error) {
//
// Part of the net.Conn interface.
func (c *Conn) Write(b []byte) (n int, err error) {
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
// If the message doesn't require any chunking, then we can go ahead
// with a single write.
if len(b) <= math.MaxUint16 {
err = c.noise.WriteMessage(b)
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
err = noise.WriteMessage(b)
if err != nil {
return 0, err
}
return c.noise.Flush(c.conn)
return noise.Flush(c.conn)
}
// If we need to split the message into fragments, then we'll write
@@ -192,11 +241,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// Slice off the next chunk to be written based on our running
// counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(chunk); err != nil {
noise := c.noise.Load()
if noise == nil {
return bytesWritten, ErrConnClosed
}
if err := noise.WriteMessage(chunk); err != nil {
return bytesWritten, err
}
n, err := c.noise.Flush(c.conn)
n, err := noise.Flush(c.conn)
bytesWritten += n
if err != nil {
return bytesWritten, err
@@ -214,7 +267,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written.
func (c *Conn) WriteMessage(b []byte) error {
return c.noise.WriteMessage(b)
if c.closed.Load() == 1 {
return ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return ErrConnClosed
}
return noise.WriteMessage(b)
}
// Flush attempts to write a message buffered using WriteMessage to the
@@ -226,7 +286,14 @@ func (c *Conn) WriteMessage(b []byte) error {
//
// NOTE: It is safe to call this method again iff a timeout error is returned.
func (c *Conn) Flush() (int, error) {
return c.noise.Flush(c.conn)
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
return noise.Flush(c.conn)
}
// Close closes the connection. Any blocked Read or Write operations will be
@@ -234,8 +301,16 @@ func (c *Conn) Flush() (int, error) {
//
// Part of the net.Conn interface.
func (c *Conn) Close() error {
// Use compare-and-swap to ensure Close is only executed once.
if !c.closed.CompareAndSwap(0, 1) {
return ErrConnClosed
}
// Clear the state we created to be able to handle this connection.
c.noise = nil
// We atomically swap the noise pointer to nil, which allows the ~64KB
// buffers to be garbage collected immediately while preventing any
// nil pointer dereferences from concurrent operations.
c.noise.Store(nil)
c.readBuf = bytes.Buffer{}
return c.conn.Close()
@@ -283,10 +358,24 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// RemotePub returns the remote peer's static public key.
func (c *Conn) RemotePub() *btcec.PublicKey {
return c.noise.remoteStatic
if c.closed.Load() == 1 {
return nil
}
noise := c.noise.Load()
if noise == nil {
return nil
}
return noise.remoteStatic
}
// LocalPub returns the local peer's static public key.
func (c *Conn) LocalPub() *btcec.PublicKey {
return c.noise.localStatic.PubKey()
if c.closed.Load() == 1 {
return nil
}
noise := c.noise.Load()
if noise == nil {
return nil
}
return noise.localStatic.PubKey()
}

View File

@@ -116,9 +116,9 @@ func (l *Listener) doHandshake(conn net.Conn) {
remoteAddr := conn.RemoteAddr().String()
brontideConn := &Conn{
conn: conn,
noise: NewBrontideMachine(false, l.localStatic, nil),
conn: conn,
}
brontideConn.noise.Store(NewBrontideMachine(false, l.localStatic, nil))
// We'll ensure that we get ActOne from the remote peer in a timely
// manner. If they don't respond within handshakeReadTimeout, then
@@ -139,7 +139,8 @@ func (l *Listener) doHandshake(conn net.Conn) {
l.rejectConn(rejectedConnErr(err, remoteAddr))
return
}
if err := brontideConn.noise.RecvActOne(actOne); err != nil {
noise := brontideConn.noise.Load()
if err := noise.RecvActOne(actOne); err != nil {
brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr))
return
@@ -147,7 +148,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
// Next, progress the handshake processes by sending over our ephemeral
// key for the session along with an authenticating tag.
actTwo, err := brontideConn.noise.GenActTwo()
actTwo, err := noise.GenActTwo()
if err != nil {
brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr))
@@ -184,7 +185,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
l.rejectConn(rejectedConnErr(err, remoteAddr))
return
}
if err := brontideConn.noise.RecvActThree(actThree); err != nil {
if err := noise.RecvActThree(actThree); err != nil {
brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr))
return