From 3615a351acf42dd526aaddcfb48b3983362d6c77 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 10 Sep 2025 18:22:52 -0700 Subject: [PATCH] brontide: revise --- brontide/bench_test.go | 19 +++--- brontide/conn.go | 127 +++++++++++++++++++++++++++++++++++------ brontide/listener.go | 11 ++-- 3 files changed, 125 insertions(+), 32 deletions(-) diff --git a/brontide/bench_test.go b/brontide/bench_test.go index 0605c54f5..d1daae378 100644 --- a/brontide/bench_test.go +++ b/brontide/bench_test.go @@ -31,8 +31,9 @@ func BenchmarkReadHeaderAndBody(t *testing.B) { err = noiseRemoteConn.WriteMessage(msg) require.NoError(t, err, "unable to write encrypted message: %v", err) - cipherHeader := noiseRemoteConn.noise.nextHeaderSend - cipherMsg := noiseRemoteConn.noise.nextBodySend + noise := noiseRemoteConn.noise.Load() + cipherHeader := noise.nextHeaderSend + cipherMsg := noise.nextBodySend var ( benchErr error @@ -42,15 +43,16 @@ func BenchmarkReadHeaderAndBody(t *testing.B) { t.ReportAllocs() t.ResetTimer() - nonceValue := noiseLocalConn.noise.recvCipher.nonce + localNoise := noiseLocalConn.noise.Load() + nonceValue := localNoise.recvCipher.nonce for i := 0; i < t.N; i++ { - pktLen, benchErr := noiseLocalConn.noise.ReadHeader( + pktLen, benchErr := localNoise.ReadHeader( bytes.NewReader(cipherHeader), ) require.NoError( t, benchErr, "#%v: failed decryption: %v", i, benchErr, ) - _, benchErr = noiseLocalConn.noise.ReadBody( + _, benchErr = localNoise.ReadBody( bytes.NewReader(cipherMsg), msgBuf[:pktLen], ) require.NoError( @@ -60,7 +62,7 @@ func BenchmarkReadHeaderAndBody(t *testing.B) { // We reset the internal nonce each time as otherwise, we'd // continue to increment it which would cause a decryption // failure. - noiseLocalConn.noise.recvCipher.nonce = nonceValue + localNoise.recvCipher.nonce = nonceValue } require.NoError(t, benchErr) } @@ -87,15 +89,16 @@ func BenchmarkWriteMessage(b *testing.B) { b.ReportAllocs() b.ResetTimer() + noise := noiseLocalConn.noise.Load() for i := 0; i < b.N; i++ { // Write our massive message, then call flush to actually write // the encrypted message This simulates a full write operation // to a network. - err := noiseLocalConn.noise.WriteMessage(largeMsg) + err := noise.WriteMessage(largeMsg) if err != nil { b.Fatalf("WriteMessage failed: %v", err) } - _, err = noiseLocalConn.noise.Flush(discard) + _, err = noise.Flush(discard) if err != nil { b.Fatalf("Flush failed: %v", err) } diff --git a/brontide/conn.go b/brontide/conn.go index 0c45932ec..a5491baad 100644 --- a/brontide/conn.go +++ b/brontide/conn.go @@ -2,9 +2,11 @@ package brontide import ( "bytes" + "errors" "io" "math" "net" + "sync/atomic" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -13,6 +15,9 @@ import ( "github.com/lightningnetwork/lnd/tor" ) +// ErrConnClosed is returned when operations are attempted on a closed connection. +var ErrConnClosed = errors.New("brontide: connection closed") + // Conn is an implementation of net.Conn which enforces an authenticated key // exchange and message encryption protocol dubbed "Brontide" after initial TCP // connection establishment. In the case of a successful handshake, all @@ -22,9 +27,15 @@ import ( type Conn struct { conn net.Conn - noise *Machine + // noise is stored as an atomic pointer to allow safe cleanup on Close() + // while preventing nil pointer dereferences from concurrent operations. + noise atomic.Pointer[Machine] readBuf bytes.Buffer + + // closed is an atomic flag that tracks whether Close() has been called. + // This prevents nil pointer dereferences if methods are called after Close(). + closed atomic.Uint32 } // A compile-time assertion to ensure that Conn meets the net.Conn interface. @@ -46,12 +57,13 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, } b := &Conn{ - conn: conn, - noise: NewBrontideMachine(true, local, netAddr.IdentityKey), + conn: conn, } + b.noise.Store(NewBrontideMachine(true, local, netAddr.IdentityKey)) // Initiate the handshake by sending the first act to the receiver. - actOne, err := b.noise.GenActOne() + noise := b.noise.Load() + actOne, err := noise.GenActOne() if err != nil { b.conn.Close() return nil, err @@ -79,14 +91,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, b.conn.Close() return nil, err } - if err := b.noise.RecvActTwo(actTwo); err != nil { + if err := noise.RecvActTwo(actTwo); err != nil { b.conn.Close() return nil, err } // Finally, complete the handshake by sending over our encrypted static // key and execute the final ECDH operation. - actThree, err := b.noise.GenActThree() + actThree, err := noise.GenActThree() if err != nil { b.conn.Close() return nil, err @@ -116,7 +128,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress, // appropriately, it is preferred that they use the split ReadNextHeader and // ReadNextBody methods so that the deadlines can be set appropriately on each. func (c *Conn) ReadNextMessage() ([]byte, error) { - return c.noise.ReadMessage(c.conn) + if c.closed.Load() == 1 { + return nil, ErrConnClosed + } + noise := c.noise.Load() + if noise == nil { + return nil, ErrConnClosed + } + return noise.ReadMessage(c.conn) } // ReadNextHeader uses the connection to read the next header from the brontide @@ -124,7 +143,14 @@ func (c *Conn) ReadNextMessage() ([]byte, error) { // return the packet length (including MAC overhead) that is expected from the // subsequent call to ReadNextBody. func (c *Conn) ReadNextHeader() (uint32, error) { - return c.noise.ReadHeader(c.conn) + if c.closed.Load() == 1 { + return 0, ErrConnClosed + } + noise := c.noise.Load() + if noise == nil { + return 0, ErrConnClosed + } + return noise.ReadHeader(c.conn) } // ReadNextBody uses the connection to read the next message body from the @@ -132,7 +158,14 @@ func (c *Conn) ReadNextHeader() (uint32, error) { // and return the decrypted payload. The provided buffer MUST be the packet // length returned by the preceding call to ReadNextHeader. func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) { - return c.noise.ReadBody(c.conn, buf) + if c.closed.Load() == 1 { + return nil, ErrConnClosed + } + noise := c.noise.Load() + if noise == nil { + return nil, ErrConnClosed + } + return noise.ReadBody(c.conn, buf) } // Read reads data from the connection. Read can be made to time out and @@ -141,13 +174,21 @@ func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) { // // Part of the net.Conn interface. func (c *Conn) Read(b []byte) (n int, err error) { + if c.closed.Load() == 1 { + return 0, ErrConnClosed + } + // In order to reconcile the differences between the record abstraction // of our AEAD connection, and the stream abstraction of TCP, we // maintain an intermediate read buffer. If this buffer becomes // depleted, then we read the next record, and feed it into the // buffer. Otherwise, we read directly from the buffer. if c.readBuf.Len() == 0 { - plaintext, err := c.noise.ReadMessage(c.conn) + noise := c.noise.Load() + if noise == nil { + return 0, ErrConnClosed + } + plaintext, err := noise.ReadMessage(c.conn) if err != nil { return 0, err } @@ -166,14 +207,22 @@ func (c *Conn) Read(b []byte) (n int, err error) { // // Part of the net.Conn interface. func (c *Conn) Write(b []byte) (n int, err error) { + if c.closed.Load() == 1 { + return 0, ErrConnClosed + } + // If the message doesn't require any chunking, then we can go ahead // with a single write. if len(b) <= math.MaxUint16 { - err = c.noise.WriteMessage(b) + noise := c.noise.Load() + if noise == nil { + return 0, ErrConnClosed + } + err = noise.WriteMessage(b) if err != nil { return 0, err } - return c.noise.Flush(c.conn) + return noise.Flush(c.conn) } // If we need to split the message into fragments, then we'll write @@ -192,11 +241,15 @@ func (c *Conn) Write(b []byte) (n int, err error) { // Slice off the next chunk to be written based on our running // counter and next chunk size. chunk := b[bytesWritten : bytesWritten+chunkSize] - if err := c.noise.WriteMessage(chunk); err != nil { + noise := c.noise.Load() + if noise == nil { + return bytesWritten, ErrConnClosed + } + if err := noise.WriteMessage(chunk); err != nil { return bytesWritten, err } - n, err := c.noise.Flush(c.conn) + n, err := noise.Flush(c.conn) bytesWritten += n if err != nil { return bytesWritten, err @@ -214,7 +267,14 @@ func (c *Conn) Write(b []byte) (n int, err error) { // NOTE: This DOES NOT write the message to the wire, it should be followed by a // call to Flush to ensure the message is written. func (c *Conn) WriteMessage(b []byte) error { - return c.noise.WriteMessage(b) + if c.closed.Load() == 1 { + return ErrConnClosed + } + noise := c.noise.Load() + if noise == nil { + return ErrConnClosed + } + return noise.WriteMessage(b) } // Flush attempts to write a message buffered using WriteMessage to the @@ -226,7 +286,14 @@ func (c *Conn) WriteMessage(b []byte) error { // // NOTE: It is safe to call this method again iff a timeout error is returned. func (c *Conn) Flush() (int, error) { - return c.noise.Flush(c.conn) + if c.closed.Load() == 1 { + return 0, ErrConnClosed + } + noise := c.noise.Load() + if noise == nil { + return 0, ErrConnClosed + } + return noise.Flush(c.conn) } // Close closes the connection. Any blocked Read or Write operations will be @@ -234,8 +301,16 @@ func (c *Conn) Flush() (int, error) { // // Part of the net.Conn interface. func (c *Conn) Close() error { + // Use compare-and-swap to ensure Close is only executed once. + if !c.closed.CompareAndSwap(0, 1) { + return ErrConnClosed + } + // Clear the state we created to be able to handle this connection. - c.noise = nil + // We atomically swap the noise pointer to nil, which allows the ~64KB + // buffers to be garbage collected immediately while preventing any + // nil pointer dereferences from concurrent operations. + c.noise.Store(nil) c.readBuf = bytes.Buffer{} return c.conn.Close() @@ -283,10 +358,24 @@ func (c *Conn) SetWriteDeadline(t time.Time) error { // RemotePub returns the remote peer's static public key. func (c *Conn) RemotePub() *btcec.PublicKey { - return c.noise.remoteStatic + if c.closed.Load() == 1 { + return nil + } + noise := c.noise.Load() + if noise == nil { + return nil + } + return noise.remoteStatic } // LocalPub returns the local peer's static public key. func (c *Conn) LocalPub() *btcec.PublicKey { - return c.noise.localStatic.PubKey() + if c.closed.Load() == 1 { + return nil + } + noise := c.noise.Load() + if noise == nil { + return nil + } + return noise.localStatic.PubKey() } diff --git a/brontide/listener.go b/brontide/listener.go index 4d386c1c5..7ef27186a 100644 --- a/brontide/listener.go +++ b/brontide/listener.go @@ -116,9 +116,9 @@ func (l *Listener) doHandshake(conn net.Conn) { remoteAddr := conn.RemoteAddr().String() brontideConn := &Conn{ - conn: conn, - noise: NewBrontideMachine(false, l.localStatic, nil), + conn: conn, } + brontideConn.noise.Store(NewBrontideMachine(false, l.localStatic, nil)) // We'll ensure that we get ActOne from the remote peer in a timely // manner. If they don't respond within handshakeReadTimeout, then @@ -139,7 +139,8 @@ func (l *Listener) doHandshake(conn net.Conn) { l.rejectConn(rejectedConnErr(err, remoteAddr)) return } - if err := brontideConn.noise.RecvActOne(actOne); err != nil { + noise := brontideConn.noise.Load() + if err := noise.RecvActOne(actOne); err != nil { brontideConn.conn.Close() l.rejectConn(rejectedConnErr(err, remoteAddr)) return @@ -147,7 +148,7 @@ func (l *Listener) doHandshake(conn net.Conn) { // Next, progress the handshake processes by sending over our ephemeral // key for the session along with an authenticating tag. - actTwo, err := brontideConn.noise.GenActTwo() + actTwo, err := noise.GenActTwo() if err != nil { brontideConn.conn.Close() l.rejectConn(rejectedConnErr(err, remoteAddr)) @@ -184,7 +185,7 @@ func (l *Listener) doHandshake(conn net.Conn) { l.rejectConn(rejectedConnErr(err, remoteAddr)) return } - if err := brontideConn.noise.RecvActThree(actThree); err != nil { + if err := noise.RecvActThree(actThree); err != nil { brontideConn.conn.Close() l.rejectConn(rejectedConnErr(err, remoteAddr)) return