brontide: revise

This commit is contained in:
Olaoluwa Osuntokun
2025-09-10 18:22:52 -07:00
parent a58e329b2c
commit 3615a351ac
3 changed files with 125 additions and 32 deletions

View File

@@ -31,8 +31,9 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
err = noiseRemoteConn.WriteMessage(msg) err = noiseRemoteConn.WriteMessage(msg)
require.NoError(t, err, "unable to write encrypted message: %v", err) require.NoError(t, err, "unable to write encrypted message: %v", err)
cipherHeader := noiseRemoteConn.noise.nextHeaderSend noise := noiseRemoteConn.noise.Load()
cipherMsg := noiseRemoteConn.noise.nextBodySend cipherHeader := noise.nextHeaderSend
cipherMsg := noise.nextBodySend
var ( var (
benchErr error benchErr error
@@ -42,15 +43,16 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
t.ReportAllocs() t.ReportAllocs()
t.ResetTimer() t.ResetTimer()
nonceValue := noiseLocalConn.noise.recvCipher.nonce localNoise := noiseLocalConn.noise.Load()
nonceValue := localNoise.recvCipher.nonce
for i := 0; i < t.N; i++ { for i := 0; i < t.N; i++ {
pktLen, benchErr := noiseLocalConn.noise.ReadHeader( pktLen, benchErr := localNoise.ReadHeader(
bytes.NewReader(cipherHeader), bytes.NewReader(cipherHeader),
) )
require.NoError( require.NoError(
t, benchErr, "#%v: failed decryption: %v", i, benchErr, t, benchErr, "#%v: failed decryption: %v", i, benchErr,
) )
_, benchErr = noiseLocalConn.noise.ReadBody( _, benchErr = localNoise.ReadBody(
bytes.NewReader(cipherMsg), msgBuf[:pktLen], bytes.NewReader(cipherMsg), msgBuf[:pktLen],
) )
require.NoError( require.NoError(
@@ -60,7 +62,7 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
// We reset the internal nonce each time as otherwise, we'd // We reset the internal nonce each time as otherwise, we'd
// continue to increment it which would cause a decryption // continue to increment it which would cause a decryption
// failure. // failure.
noiseLocalConn.noise.recvCipher.nonce = nonceValue localNoise.recvCipher.nonce = nonceValue
} }
require.NoError(t, benchErr) require.NoError(t, benchErr)
} }
@@ -87,15 +89,16 @@ func BenchmarkWriteMessage(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
b.ResetTimer() b.ResetTimer()
noise := noiseLocalConn.noise.Load()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
// Write our massive message, then call flush to actually write // Write our massive message, then call flush to actually write
// the encrypted message This simulates a full write operation // the encrypted message This simulates a full write operation
// to a network. // to a network.
err := noiseLocalConn.noise.WriteMessage(largeMsg) err := noise.WriteMessage(largeMsg)
if err != nil { if err != nil {
b.Fatalf("WriteMessage failed: %v", err) b.Fatalf("WriteMessage failed: %v", err)
} }
_, err = noiseLocalConn.noise.Flush(discard) _, err = noise.Flush(discard)
if err != nil { if err != nil {
b.Fatalf("Flush failed: %v", err) b.Fatalf("Flush failed: %v", err)
} }

View File

@@ -2,9 +2,11 @@ package brontide
import ( import (
"bytes" "bytes"
"errors"
"io" "io"
"math" "math"
"net" "net"
"sync/atomic"
"time" "time"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
@@ -13,6 +15,9 @@ import (
"github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/tor"
) )
// ErrConnClosed is returned when operations are attempted on a closed connection.
var ErrConnClosed = errors.New("brontide: connection closed")
// Conn is an implementation of net.Conn which enforces an authenticated key // Conn is an implementation of net.Conn which enforces an authenticated key
// exchange and message encryption protocol dubbed "Brontide" after initial TCP // exchange and message encryption protocol dubbed "Brontide" after initial TCP
// connection establishment. In the case of a successful handshake, all // connection establishment. In the case of a successful handshake, all
@@ -22,9 +27,15 @@ import (
type Conn struct { type Conn struct {
conn net.Conn conn net.Conn
noise *Machine // noise is stored as an atomic pointer to allow safe cleanup on Close()
// while preventing nil pointer dereferences from concurrent operations.
noise atomic.Pointer[Machine]
readBuf bytes.Buffer readBuf bytes.Buffer
// closed is an atomic flag that tracks whether Close() has been called.
// This prevents nil pointer dereferences if methods are called after Close().
closed atomic.Uint32
} }
// A compile-time assertion to ensure that Conn meets the net.Conn interface. // A compile-time assertion to ensure that Conn meets the net.Conn interface.
@@ -46,12 +57,13 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
} }
b := &Conn{ b := &Conn{
conn: conn, conn: conn,
noise: NewBrontideMachine(true, local, netAddr.IdentityKey),
} }
b.noise.Store(NewBrontideMachine(true, local, netAddr.IdentityKey))
// Initiate the handshake by sending the first act to the receiver. // Initiate the handshake by sending the first act to the receiver.
actOne, err := b.noise.GenActOne() noise := b.noise.Load()
actOne, err := noise.GenActOne()
if err != nil { if err != nil {
b.conn.Close() b.conn.Close()
return nil, err return nil, err
@@ -79,14 +91,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
b.conn.Close() b.conn.Close()
return nil, err return nil, err
} }
if err := b.noise.RecvActTwo(actTwo); err != nil { if err := noise.RecvActTwo(actTwo); err != nil {
b.conn.Close() b.conn.Close()
return nil, err return nil, err
} }
// Finally, complete the handshake by sending over our encrypted static // Finally, complete the handshake by sending over our encrypted static
// key and execute the final ECDH operation. // key and execute the final ECDH operation.
actThree, err := b.noise.GenActThree() actThree, err := noise.GenActThree()
if err != nil { if err != nil {
b.conn.Close() b.conn.Close()
return nil, err return nil, err
@@ -116,7 +128,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
// appropriately, it is preferred that they use the split ReadNextHeader and // appropriately, it is preferred that they use the split ReadNextHeader and
// ReadNextBody methods so that the deadlines can be set appropriately on each. // ReadNextBody methods so that the deadlines can be set appropriately on each.
func (c *Conn) ReadNextMessage() ([]byte, error) { func (c *Conn) ReadNextMessage() ([]byte, error) {
return c.noise.ReadMessage(c.conn) if c.closed.Load() == 1 {
return nil, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return nil, ErrConnClosed
}
return noise.ReadMessage(c.conn)
} }
// ReadNextHeader uses the connection to read the next header from the brontide // ReadNextHeader uses the connection to read the next header from the brontide
@@ -124,7 +143,14 @@ func (c *Conn) ReadNextMessage() ([]byte, error) {
// return the packet length (including MAC overhead) that is expected from the // return the packet length (including MAC overhead) that is expected from the
// subsequent call to ReadNextBody. // subsequent call to ReadNextBody.
func (c *Conn) ReadNextHeader() (uint32, error) { func (c *Conn) ReadNextHeader() (uint32, error) {
return c.noise.ReadHeader(c.conn) if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
return noise.ReadHeader(c.conn)
} }
// ReadNextBody uses the connection to read the next message body from the // ReadNextBody uses the connection to read the next message body from the
@@ -132,7 +158,14 @@ func (c *Conn) ReadNextHeader() (uint32, error) {
// and return the decrypted payload. The provided buffer MUST be the packet // and return the decrypted payload. The provided buffer MUST be the packet
// length returned by the preceding call to ReadNextHeader. // length returned by the preceding call to ReadNextHeader.
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) { func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
return c.noise.ReadBody(c.conn, buf) if c.closed.Load() == 1 {
return nil, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return nil, ErrConnClosed
}
return noise.ReadBody(c.conn, buf)
} }
// Read reads data from the connection. Read can be made to time out and // Read reads data from the connection. Read can be made to time out and
@@ -141,13 +174,21 @@ func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
// //
// Part of the net.Conn interface. // Part of the net.Conn interface.
func (c *Conn) Read(b []byte) (n int, err error) { func (c *Conn) Read(b []byte) (n int, err error) {
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
// In order to reconcile the differences between the record abstraction // In order to reconcile the differences between the record abstraction
// of our AEAD connection, and the stream abstraction of TCP, we // of our AEAD connection, and the stream abstraction of TCP, we
// maintain an intermediate read buffer. If this buffer becomes // maintain an intermediate read buffer. If this buffer becomes
// depleted, then we read the next record, and feed it into the // depleted, then we read the next record, and feed it into the
// buffer. Otherwise, we read directly from the buffer. // buffer. Otherwise, we read directly from the buffer.
if c.readBuf.Len() == 0 { if c.readBuf.Len() == 0 {
plaintext, err := c.noise.ReadMessage(c.conn) noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
plaintext, err := noise.ReadMessage(c.conn)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@@ -166,14 +207,22 @@ func (c *Conn) Read(b []byte) (n int, err error) {
// //
// Part of the net.Conn interface. // Part of the net.Conn interface.
func (c *Conn) Write(b []byte) (n int, err error) { func (c *Conn) Write(b []byte) (n int, err error) {
if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
// If the message doesn't require any chunking, then we can go ahead // If the message doesn't require any chunking, then we can go ahead
// with a single write. // with a single write.
if len(b) <= math.MaxUint16 { if len(b) <= math.MaxUint16 {
err = c.noise.WriteMessage(b) noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
err = noise.WriteMessage(b)
if err != nil { if err != nil {
return 0, err return 0, err
} }
return c.noise.Flush(c.conn) return noise.Flush(c.conn)
} }
// If we need to split the message into fragments, then we'll write // If we need to split the message into fragments, then we'll write
@@ -192,11 +241,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// Slice off the next chunk to be written based on our running // Slice off the next chunk to be written based on our running
// counter and next chunk size. // counter and next chunk size.
chunk := b[bytesWritten : bytesWritten+chunkSize] chunk := b[bytesWritten : bytesWritten+chunkSize]
if err := c.noise.WriteMessage(chunk); err != nil { noise := c.noise.Load()
if noise == nil {
return bytesWritten, ErrConnClosed
}
if err := noise.WriteMessage(chunk); err != nil {
return bytesWritten, err return bytesWritten, err
} }
n, err := c.noise.Flush(c.conn) n, err := noise.Flush(c.conn)
bytesWritten += n bytesWritten += n
if err != nil { if err != nil {
return bytesWritten, err return bytesWritten, err
@@ -214,7 +267,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
// NOTE: This DOES NOT write the message to the wire, it should be followed by a // NOTE: This DOES NOT write the message to the wire, it should be followed by a
// call to Flush to ensure the message is written. // call to Flush to ensure the message is written.
func (c *Conn) WriteMessage(b []byte) error { func (c *Conn) WriteMessage(b []byte) error {
return c.noise.WriteMessage(b) if c.closed.Load() == 1 {
return ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return ErrConnClosed
}
return noise.WriteMessage(b)
} }
// Flush attempts to write a message buffered using WriteMessage to the // Flush attempts to write a message buffered using WriteMessage to the
@@ -226,7 +286,14 @@ func (c *Conn) WriteMessage(b []byte) error {
// //
// NOTE: It is safe to call this method again iff a timeout error is returned. // NOTE: It is safe to call this method again iff a timeout error is returned.
func (c *Conn) Flush() (int, error) { func (c *Conn) Flush() (int, error) {
return c.noise.Flush(c.conn) if c.closed.Load() == 1 {
return 0, ErrConnClosed
}
noise := c.noise.Load()
if noise == nil {
return 0, ErrConnClosed
}
return noise.Flush(c.conn)
} }
// Close closes the connection. Any blocked Read or Write operations will be // Close closes the connection. Any blocked Read or Write operations will be
@@ -234,8 +301,16 @@ func (c *Conn) Flush() (int, error) {
// //
// Part of the net.Conn interface. // Part of the net.Conn interface.
func (c *Conn) Close() error { func (c *Conn) Close() error {
// Use compare-and-swap to ensure Close is only executed once.
if !c.closed.CompareAndSwap(0, 1) {
return ErrConnClosed
}
// Clear the state we created to be able to handle this connection. // Clear the state we created to be able to handle this connection.
c.noise = nil // We atomically swap the noise pointer to nil, which allows the ~64KB
// buffers to be garbage collected immediately while preventing any
// nil pointer dereferences from concurrent operations.
c.noise.Store(nil)
c.readBuf = bytes.Buffer{} c.readBuf = bytes.Buffer{}
return c.conn.Close() return c.conn.Close()
@@ -283,10 +358,24 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
// RemotePub returns the remote peer's static public key. // RemotePub returns the remote peer's static public key.
func (c *Conn) RemotePub() *btcec.PublicKey { func (c *Conn) RemotePub() *btcec.PublicKey {
return c.noise.remoteStatic if c.closed.Load() == 1 {
return nil
}
noise := c.noise.Load()
if noise == nil {
return nil
}
return noise.remoteStatic
} }
// LocalPub returns the local peer's static public key. // LocalPub returns the local peer's static public key.
func (c *Conn) LocalPub() *btcec.PublicKey { func (c *Conn) LocalPub() *btcec.PublicKey {
return c.noise.localStatic.PubKey() if c.closed.Load() == 1 {
return nil
}
noise := c.noise.Load()
if noise == nil {
return nil
}
return noise.localStatic.PubKey()
} }

View File

@@ -116,9 +116,9 @@ func (l *Listener) doHandshake(conn net.Conn) {
remoteAddr := conn.RemoteAddr().String() remoteAddr := conn.RemoteAddr().String()
brontideConn := &Conn{ brontideConn := &Conn{
conn: conn, conn: conn,
noise: NewBrontideMachine(false, l.localStatic, nil),
} }
brontideConn.noise.Store(NewBrontideMachine(false, l.localStatic, nil))
// We'll ensure that we get ActOne from the remote peer in a timely // We'll ensure that we get ActOne from the remote peer in a timely
// manner. If they don't respond within handshakeReadTimeout, then // manner. If they don't respond within handshakeReadTimeout, then
@@ -139,7 +139,8 @@ func (l *Listener) doHandshake(conn net.Conn) {
l.rejectConn(rejectedConnErr(err, remoteAddr)) l.rejectConn(rejectedConnErr(err, remoteAddr))
return return
} }
if err := brontideConn.noise.RecvActOne(actOne); err != nil { noise := brontideConn.noise.Load()
if err := noise.RecvActOne(actOne); err != nil {
brontideConn.conn.Close() brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr)) l.rejectConn(rejectedConnErr(err, remoteAddr))
return return
@@ -147,7 +148,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
// Next, progress the handshake processes by sending over our ephemeral // Next, progress the handshake processes by sending over our ephemeral
// key for the session along with an authenticating tag. // key for the session along with an authenticating tag.
actTwo, err := brontideConn.noise.GenActTwo() actTwo, err := noise.GenActTwo()
if err != nil { if err != nil {
brontideConn.conn.Close() brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr)) l.rejectConn(rejectedConnErr(err, remoteAddr))
@@ -184,7 +185,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
l.rejectConn(rejectedConnErr(err, remoteAddr)) l.rejectConn(rejectedConnErr(err, remoteAddr))
return return
} }
if err := brontideConn.noise.RecvActThree(actThree); err != nil { if err := noise.RecvActThree(actThree); err != nil {
brontideConn.conn.Close() brontideConn.conn.Close()
l.rejectConn(rejectedConnErr(err, remoteAddr)) l.rejectConn(rejectedConnErr(err, remoteAddr))
return return