mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-03 20:53:04 +02:00
brontide: revise
This commit is contained in:
@@ -31,8 +31,9 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
||||
err = noiseRemoteConn.WriteMessage(msg)
|
||||
require.NoError(t, err, "unable to write encrypted message: %v", err)
|
||||
|
||||
cipherHeader := noiseRemoteConn.noise.nextHeaderSend
|
||||
cipherMsg := noiseRemoteConn.noise.nextBodySend
|
||||
noise := noiseRemoteConn.noise.Load()
|
||||
cipherHeader := noise.nextHeaderSend
|
||||
cipherMsg := noise.nextBodySend
|
||||
|
||||
var (
|
||||
benchErr error
|
||||
@@ -42,15 +43,16 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
||||
t.ReportAllocs()
|
||||
t.ResetTimer()
|
||||
|
||||
nonceValue := noiseLocalConn.noise.recvCipher.nonce
|
||||
localNoise := noiseLocalConn.noise.Load()
|
||||
nonceValue := localNoise.recvCipher.nonce
|
||||
for i := 0; i < t.N; i++ {
|
||||
pktLen, benchErr := noiseLocalConn.noise.ReadHeader(
|
||||
pktLen, benchErr := localNoise.ReadHeader(
|
||||
bytes.NewReader(cipherHeader),
|
||||
)
|
||||
require.NoError(
|
||||
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
|
||||
)
|
||||
_, benchErr = noiseLocalConn.noise.ReadBody(
|
||||
_, benchErr = localNoise.ReadBody(
|
||||
bytes.NewReader(cipherMsg), msgBuf[:pktLen],
|
||||
)
|
||||
require.NoError(
|
||||
@@ -60,7 +62,7 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
||||
// We reset the internal nonce each time as otherwise, we'd
|
||||
// continue to increment it which would cause a decryption
|
||||
// failure.
|
||||
noiseLocalConn.noise.recvCipher.nonce = nonceValue
|
||||
localNoise.recvCipher.nonce = nonceValue
|
||||
}
|
||||
require.NoError(t, benchErr)
|
||||
}
|
||||
@@ -87,15 +89,16 @@ func BenchmarkWriteMessage(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
noise := noiseLocalConn.noise.Load()
|
||||
for i := 0; i < b.N; i++ {
|
||||
// Write our massive message, then call flush to actually write
|
||||
// the encrypted message This simulates a full write operation
|
||||
// to a network.
|
||||
err := noiseLocalConn.noise.WriteMessage(largeMsg)
|
||||
err := noise.WriteMessage(largeMsg)
|
||||
if err != nil {
|
||||
b.Fatalf("WriteMessage failed: %v", err)
|
||||
}
|
||||
_, err = noiseLocalConn.noise.Flush(discard)
|
||||
_, err = noise.Flush(discard)
|
||||
if err != nil {
|
||||
b.Fatalf("Flush failed: %v", err)
|
||||
}
|
||||
|
127
brontide/conn.go
127
brontide/conn.go
@@ -2,9 +2,11 @@ package brontide
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"math"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
@@ -13,6 +15,9 @@ import (
|
||||
"github.com/lightningnetwork/lnd/tor"
|
||||
)
|
||||
|
||||
// ErrConnClosed is returned when operations are attempted on a closed connection.
|
||||
var ErrConnClosed = errors.New("brontide: connection closed")
|
||||
|
||||
// Conn is an implementation of net.Conn which enforces an authenticated key
|
||||
// exchange and message encryption protocol dubbed "Brontide" after initial TCP
|
||||
// connection establishment. In the case of a successful handshake, all
|
||||
@@ -22,9 +27,15 @@ import (
|
||||
type Conn struct {
|
||||
conn net.Conn
|
||||
|
||||
noise *Machine
|
||||
// noise is stored as an atomic pointer to allow safe cleanup on Close()
|
||||
// while preventing nil pointer dereferences from concurrent operations.
|
||||
noise atomic.Pointer[Machine]
|
||||
|
||||
readBuf bytes.Buffer
|
||||
|
||||
// closed is an atomic flag that tracks whether Close() has been called.
|
||||
// This prevents nil pointer dereferences if methods are called after Close().
|
||||
closed atomic.Uint32
|
||||
}
|
||||
|
||||
// A compile-time assertion to ensure that Conn meets the net.Conn interface.
|
||||
@@ -46,12 +57,13 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
||||
}
|
||||
|
||||
b := &Conn{
|
||||
conn: conn,
|
||||
noise: NewBrontideMachine(true, local, netAddr.IdentityKey),
|
||||
conn: conn,
|
||||
}
|
||||
b.noise.Store(NewBrontideMachine(true, local, netAddr.IdentityKey))
|
||||
|
||||
// Initiate the handshake by sending the first act to the receiver.
|
||||
actOne, err := b.noise.GenActOne()
|
||||
noise := b.noise.Load()
|
||||
actOne, err := noise.GenActOne()
|
||||
if err != nil {
|
||||
b.conn.Close()
|
||||
return nil, err
|
||||
@@ -79,14 +91,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
||||
b.conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
if err := b.noise.RecvActTwo(actTwo); err != nil {
|
||||
if err := noise.RecvActTwo(actTwo); err != nil {
|
||||
b.conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Finally, complete the handshake by sending over our encrypted static
|
||||
// key and execute the final ECDH operation.
|
||||
actThree, err := b.noise.GenActThree()
|
||||
actThree, err := noise.GenActThree()
|
||||
if err != nil {
|
||||
b.conn.Close()
|
||||
return nil, err
|
||||
@@ -116,7 +128,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
||||
// appropriately, it is preferred that they use the split ReadNextHeader and
|
||||
// ReadNextBody methods so that the deadlines can be set appropriately on each.
|
||||
func (c *Conn) ReadNextMessage() ([]byte, error) {
|
||||
return c.noise.ReadMessage(c.conn)
|
||||
if c.closed.Load() == 1 {
|
||||
return nil, ErrConnClosed
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return nil, ErrConnClosed
|
||||
}
|
||||
return noise.ReadMessage(c.conn)
|
||||
}
|
||||
|
||||
// ReadNextHeader uses the connection to read the next header from the brontide
|
||||
@@ -124,7 +143,14 @@ func (c *Conn) ReadNextMessage() ([]byte, error) {
|
||||
// return the packet length (including MAC overhead) that is expected from the
|
||||
// subsequent call to ReadNextBody.
|
||||
func (c *Conn) ReadNextHeader() (uint32, error) {
|
||||
return c.noise.ReadHeader(c.conn)
|
||||
if c.closed.Load() == 1 {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
return noise.ReadHeader(c.conn)
|
||||
}
|
||||
|
||||
// ReadNextBody uses the connection to read the next message body from the
|
||||
@@ -132,7 +158,14 @@ func (c *Conn) ReadNextHeader() (uint32, error) {
|
||||
// and return the decrypted payload. The provided buffer MUST be the packet
|
||||
// length returned by the preceding call to ReadNextHeader.
|
||||
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
||||
return c.noise.ReadBody(c.conn, buf)
|
||||
if c.closed.Load() == 1 {
|
||||
return nil, ErrConnClosed
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return nil, ErrConnClosed
|
||||
}
|
||||
return noise.ReadBody(c.conn, buf)
|
||||
}
|
||||
|
||||
// Read reads data from the connection. Read can be made to time out and
|
||||
@@ -141,13 +174,21 @@ func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
||||
//
|
||||
// Part of the net.Conn interface.
|
||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
if c.closed.Load() == 1 {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
|
||||
// In order to reconcile the differences between the record abstraction
|
||||
// of our AEAD connection, and the stream abstraction of TCP, we
|
||||
// maintain an intermediate read buffer. If this buffer becomes
|
||||
// depleted, then we read the next record, and feed it into the
|
||||
// buffer. Otherwise, we read directly from the buffer.
|
||||
if c.readBuf.Len() == 0 {
|
||||
plaintext, err := c.noise.ReadMessage(c.conn)
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
plaintext, err := noise.ReadMessage(c.conn)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -166,14 +207,22 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
||||
//
|
||||
// Part of the net.Conn interface.
|
||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
if c.closed.Load() == 1 {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
|
||||
// If the message doesn't require any chunking, then we can go ahead
|
||||
// with a single write.
|
||||
if len(b) <= math.MaxUint16 {
|
||||
err = c.noise.WriteMessage(b)
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
err = noise.WriteMessage(b)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return c.noise.Flush(c.conn)
|
||||
return noise.Flush(c.conn)
|
||||
}
|
||||
|
||||
// If we need to split the message into fragments, then we'll write
|
||||
@@ -192,11 +241,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// Slice off the next chunk to be written based on our running
|
||||
// counter and next chunk size.
|
||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||
if err := c.noise.WriteMessage(chunk); err != nil {
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return bytesWritten, ErrConnClosed
|
||||
}
|
||||
if err := noise.WriteMessage(chunk); err != nil {
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
n, err := c.noise.Flush(c.conn)
|
||||
n, err := noise.Flush(c.conn)
|
||||
bytesWritten += n
|
||||
if err != nil {
|
||||
return bytesWritten, err
|
||||
@@ -214,7 +267,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
||||
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||
// call to Flush to ensure the message is written.
|
||||
func (c *Conn) WriteMessage(b []byte) error {
|
||||
return c.noise.WriteMessage(b)
|
||||
if c.closed.Load() == 1 {
|
||||
return ErrConnClosed
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return ErrConnClosed
|
||||
}
|
||||
return noise.WriteMessage(b)
|
||||
}
|
||||
|
||||
// Flush attempts to write a message buffered using WriteMessage to the
|
||||
@@ -226,7 +286,14 @@ func (c *Conn) WriteMessage(b []byte) error {
|
||||
//
|
||||
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||
func (c *Conn) Flush() (int, error) {
|
||||
return c.noise.Flush(c.conn)
|
||||
if c.closed.Load() == 1 {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return 0, ErrConnClosed
|
||||
}
|
||||
return noise.Flush(c.conn)
|
||||
}
|
||||
|
||||
// Close closes the connection. Any blocked Read or Write operations will be
|
||||
@@ -234,8 +301,16 @@ func (c *Conn) Flush() (int, error) {
|
||||
//
|
||||
// Part of the net.Conn interface.
|
||||
func (c *Conn) Close() error {
|
||||
// Use compare-and-swap to ensure Close is only executed once.
|
||||
if !c.closed.CompareAndSwap(0, 1) {
|
||||
return ErrConnClosed
|
||||
}
|
||||
|
||||
// Clear the state we created to be able to handle this connection.
|
||||
c.noise = nil
|
||||
// We atomically swap the noise pointer to nil, which allows the ~64KB
|
||||
// buffers to be garbage collected immediately while preventing any
|
||||
// nil pointer dereferences from concurrent operations.
|
||||
c.noise.Store(nil)
|
||||
c.readBuf = bytes.Buffer{}
|
||||
|
||||
return c.conn.Close()
|
||||
@@ -283,10 +358,24 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
||||
|
||||
// RemotePub returns the remote peer's static public key.
|
||||
func (c *Conn) RemotePub() *btcec.PublicKey {
|
||||
return c.noise.remoteStatic
|
||||
if c.closed.Load() == 1 {
|
||||
return nil
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return nil
|
||||
}
|
||||
return noise.remoteStatic
|
||||
}
|
||||
|
||||
// LocalPub returns the local peer's static public key.
|
||||
func (c *Conn) LocalPub() *btcec.PublicKey {
|
||||
return c.noise.localStatic.PubKey()
|
||||
if c.closed.Load() == 1 {
|
||||
return nil
|
||||
}
|
||||
noise := c.noise.Load()
|
||||
if noise == nil {
|
||||
return nil
|
||||
}
|
||||
return noise.localStatic.PubKey()
|
||||
}
|
||||
|
@@ -116,9 +116,9 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
|
||||
brontideConn := &Conn{
|
||||
conn: conn,
|
||||
noise: NewBrontideMachine(false, l.localStatic, nil),
|
||||
conn: conn,
|
||||
}
|
||||
brontideConn.noise.Store(NewBrontideMachine(false, l.localStatic, nil))
|
||||
|
||||
// We'll ensure that we get ActOne from the remote peer in a timely
|
||||
// manner. If they don't respond within handshakeReadTimeout, then
|
||||
@@ -139,7 +139,8 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||
return
|
||||
}
|
||||
if err := brontideConn.noise.RecvActOne(actOne); err != nil {
|
||||
noise := brontideConn.noise.Load()
|
||||
if err := noise.RecvActOne(actOne); err != nil {
|
||||
brontideConn.conn.Close()
|
||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||
return
|
||||
@@ -147,7 +148,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
||||
|
||||
// Next, progress the handshake processes by sending over our ephemeral
|
||||
// key for the session along with an authenticating tag.
|
||||
actTwo, err := brontideConn.noise.GenActTwo()
|
||||
actTwo, err := noise.GenActTwo()
|
||||
if err != nil {
|
||||
brontideConn.conn.Close()
|
||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||
@@ -184,7 +185,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||
return
|
||||
}
|
||||
if err := brontideConn.noise.RecvActThree(actThree); err != nil {
|
||||
if err := noise.RecvActThree(actThree); err != nil {
|
||||
brontideConn.conn.Close()
|
||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||
return
|
||||
|
Reference in New Issue
Block a user