mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-09 19:54:04 +02:00
brontide: revise
This commit is contained in:
@@ -31,8 +31,9 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
|||||||
err = noiseRemoteConn.WriteMessage(msg)
|
err = noiseRemoteConn.WriteMessage(msg)
|
||||||
require.NoError(t, err, "unable to write encrypted message: %v", err)
|
require.NoError(t, err, "unable to write encrypted message: %v", err)
|
||||||
|
|
||||||
cipherHeader := noiseRemoteConn.noise.nextHeaderSend
|
noise := noiseRemoteConn.noise.Load()
|
||||||
cipherMsg := noiseRemoteConn.noise.nextBodySend
|
cipherHeader := noise.nextHeaderSend
|
||||||
|
cipherMsg := noise.nextBodySend
|
||||||
|
|
||||||
var (
|
var (
|
||||||
benchErr error
|
benchErr error
|
||||||
@@ -42,15 +43,16 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
|||||||
t.ReportAllocs()
|
t.ReportAllocs()
|
||||||
t.ResetTimer()
|
t.ResetTimer()
|
||||||
|
|
||||||
nonceValue := noiseLocalConn.noise.recvCipher.nonce
|
localNoise := noiseLocalConn.noise.Load()
|
||||||
|
nonceValue := localNoise.recvCipher.nonce
|
||||||
for i := 0; i < t.N; i++ {
|
for i := 0; i < t.N; i++ {
|
||||||
pktLen, benchErr := noiseLocalConn.noise.ReadHeader(
|
pktLen, benchErr := localNoise.ReadHeader(
|
||||||
bytes.NewReader(cipherHeader),
|
bytes.NewReader(cipherHeader),
|
||||||
)
|
)
|
||||||
require.NoError(
|
require.NoError(
|
||||||
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
|
t, benchErr, "#%v: failed decryption: %v", i, benchErr,
|
||||||
)
|
)
|
||||||
_, benchErr = noiseLocalConn.noise.ReadBody(
|
_, benchErr = localNoise.ReadBody(
|
||||||
bytes.NewReader(cipherMsg), msgBuf[:pktLen],
|
bytes.NewReader(cipherMsg), msgBuf[:pktLen],
|
||||||
)
|
)
|
||||||
require.NoError(
|
require.NoError(
|
||||||
@@ -60,7 +62,7 @@ func BenchmarkReadHeaderAndBody(t *testing.B) {
|
|||||||
// We reset the internal nonce each time as otherwise, we'd
|
// We reset the internal nonce each time as otherwise, we'd
|
||||||
// continue to increment it which would cause a decryption
|
// continue to increment it which would cause a decryption
|
||||||
// failure.
|
// failure.
|
||||||
noiseLocalConn.noise.recvCipher.nonce = nonceValue
|
localNoise.recvCipher.nonce = nonceValue
|
||||||
}
|
}
|
||||||
require.NoError(t, benchErr)
|
require.NoError(t, benchErr)
|
||||||
}
|
}
|
||||||
@@ -87,15 +89,16 @@ func BenchmarkWriteMessage(b *testing.B) {
|
|||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|
||||||
|
noise := noiseLocalConn.noise.Load()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
// Write our massive message, then call flush to actually write
|
// Write our massive message, then call flush to actually write
|
||||||
// the encrypted message This simulates a full write operation
|
// the encrypted message This simulates a full write operation
|
||||||
// to a network.
|
// to a network.
|
||||||
err := noiseLocalConn.noise.WriteMessage(largeMsg)
|
err := noise.WriteMessage(largeMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("WriteMessage failed: %v", err)
|
b.Fatalf("WriteMessage failed: %v", err)
|
||||||
}
|
}
|
||||||
_, err = noiseLocalConn.noise.Flush(discard)
|
_, err = noise.Flush(discard)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatalf("Flush failed: %v", err)
|
b.Fatalf("Flush failed: %v", err)
|
||||||
}
|
}
|
||||||
|
127
brontide/conn.go
127
brontide/conn.go
@@ -2,9 +2,11 @@ package brontide
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"errors"
|
||||||
"io"
|
"io"
|
||||||
"math"
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec/v2"
|
"github.com/btcsuite/btcd/btcec/v2"
|
||||||
@@ -13,6 +15,9 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/tor"
|
"github.com/lightningnetwork/lnd/tor"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// ErrConnClosed is returned when operations are attempted on a closed connection.
|
||||||
|
var ErrConnClosed = errors.New("brontide: connection closed")
|
||||||
|
|
||||||
// Conn is an implementation of net.Conn which enforces an authenticated key
|
// Conn is an implementation of net.Conn which enforces an authenticated key
|
||||||
// exchange and message encryption protocol dubbed "Brontide" after initial TCP
|
// exchange and message encryption protocol dubbed "Brontide" after initial TCP
|
||||||
// connection establishment. In the case of a successful handshake, all
|
// connection establishment. In the case of a successful handshake, all
|
||||||
@@ -22,9 +27,15 @@ import (
|
|||||||
type Conn struct {
|
type Conn struct {
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
|
|
||||||
noise *Machine
|
// noise is stored as an atomic pointer to allow safe cleanup on Close()
|
||||||
|
// while preventing nil pointer dereferences from concurrent operations.
|
||||||
|
noise atomic.Pointer[Machine]
|
||||||
|
|
||||||
readBuf bytes.Buffer
|
readBuf bytes.Buffer
|
||||||
|
|
||||||
|
// closed is an atomic flag that tracks whether Close() has been called.
|
||||||
|
// This prevents nil pointer dereferences if methods are called after Close().
|
||||||
|
closed atomic.Uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
// A compile-time assertion to ensure that Conn meets the net.Conn interface.
|
// A compile-time assertion to ensure that Conn meets the net.Conn interface.
|
||||||
@@ -46,12 +57,13 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
|||||||
}
|
}
|
||||||
|
|
||||||
b := &Conn{
|
b := &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
noise: NewBrontideMachine(true, local, netAddr.IdentityKey),
|
|
||||||
}
|
}
|
||||||
|
b.noise.Store(NewBrontideMachine(true, local, netAddr.IdentityKey))
|
||||||
|
|
||||||
// Initiate the handshake by sending the first act to the receiver.
|
// Initiate the handshake by sending the first act to the receiver.
|
||||||
actOne, err := b.noise.GenActOne()
|
noise := b.noise.Load()
|
||||||
|
actOne, err := noise.GenActOne()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.conn.Close()
|
b.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -79,14 +91,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
|||||||
b.conn.Close()
|
b.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if err := b.noise.RecvActTwo(actTwo); err != nil {
|
if err := noise.RecvActTwo(actTwo); err != nil {
|
||||||
b.conn.Close()
|
b.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Finally, complete the handshake by sending over our encrypted static
|
// Finally, complete the handshake by sending over our encrypted static
|
||||||
// key and execute the final ECDH operation.
|
// key and execute the final ECDH operation.
|
||||||
actThree, err := b.noise.GenActThree()
|
actThree, err := noise.GenActThree()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.conn.Close()
|
b.conn.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -116,7 +128,14 @@ func Dial(local keychain.SingleKeyECDH, netAddr *lnwire.NetAddress,
|
|||||||
// appropriately, it is preferred that they use the split ReadNextHeader and
|
// appropriately, it is preferred that they use the split ReadNextHeader and
|
||||||
// ReadNextBody methods so that the deadlines can be set appropriately on each.
|
// ReadNextBody methods so that the deadlines can be set appropriately on each.
|
||||||
func (c *Conn) ReadNextMessage() ([]byte, error) {
|
func (c *Conn) ReadNextMessage() ([]byte, error) {
|
||||||
return c.noise.ReadMessage(c.conn)
|
if c.closed.Load() == 1 {
|
||||||
|
return nil, ErrConnClosed
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return nil, ErrConnClosed
|
||||||
|
}
|
||||||
|
return noise.ReadMessage(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadNextHeader uses the connection to read the next header from the brontide
|
// ReadNextHeader uses the connection to read the next header from the brontide
|
||||||
@@ -124,7 +143,14 @@ func (c *Conn) ReadNextMessage() ([]byte, error) {
|
|||||||
// return the packet length (including MAC overhead) that is expected from the
|
// return the packet length (including MAC overhead) that is expected from the
|
||||||
// subsequent call to ReadNextBody.
|
// subsequent call to ReadNextBody.
|
||||||
func (c *Conn) ReadNextHeader() (uint32, error) {
|
func (c *Conn) ReadNextHeader() (uint32, error) {
|
||||||
return c.noise.ReadHeader(c.conn)
|
if c.closed.Load() == 1 {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
return noise.ReadHeader(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadNextBody uses the connection to read the next message body from the
|
// ReadNextBody uses the connection to read the next message body from the
|
||||||
@@ -132,7 +158,14 @@ func (c *Conn) ReadNextHeader() (uint32, error) {
|
|||||||
// and return the decrypted payload. The provided buffer MUST be the packet
|
// and return the decrypted payload. The provided buffer MUST be the packet
|
||||||
// length returned by the preceding call to ReadNextHeader.
|
// length returned by the preceding call to ReadNextHeader.
|
||||||
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
||||||
return c.noise.ReadBody(c.conn, buf)
|
if c.closed.Load() == 1 {
|
||||||
|
return nil, ErrConnClosed
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return nil, ErrConnClosed
|
||||||
|
}
|
||||||
|
return noise.ReadBody(c.conn, buf)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read reads data from the connection. Read can be made to time out and
|
// Read reads data from the connection. Read can be made to time out and
|
||||||
@@ -141,13 +174,21 @@ func (c *Conn) ReadNextBody(buf []byte) ([]byte, error) {
|
|||||||
//
|
//
|
||||||
// Part of the net.Conn interface.
|
// Part of the net.Conn interface.
|
||||||
func (c *Conn) Read(b []byte) (n int, err error) {
|
func (c *Conn) Read(b []byte) (n int, err error) {
|
||||||
|
if c.closed.Load() == 1 {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
|
||||||
// In order to reconcile the differences between the record abstraction
|
// In order to reconcile the differences between the record abstraction
|
||||||
// of our AEAD connection, and the stream abstraction of TCP, we
|
// of our AEAD connection, and the stream abstraction of TCP, we
|
||||||
// maintain an intermediate read buffer. If this buffer becomes
|
// maintain an intermediate read buffer. If this buffer becomes
|
||||||
// depleted, then we read the next record, and feed it into the
|
// depleted, then we read the next record, and feed it into the
|
||||||
// buffer. Otherwise, we read directly from the buffer.
|
// buffer. Otherwise, we read directly from the buffer.
|
||||||
if c.readBuf.Len() == 0 {
|
if c.readBuf.Len() == 0 {
|
||||||
plaintext, err := c.noise.ReadMessage(c.conn)
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
plaintext, err := noise.ReadMessage(c.conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
@@ -166,14 +207,22 @@ func (c *Conn) Read(b []byte) (n int, err error) {
|
|||||||
//
|
//
|
||||||
// Part of the net.Conn interface.
|
// Part of the net.Conn interface.
|
||||||
func (c *Conn) Write(b []byte) (n int, err error) {
|
func (c *Conn) Write(b []byte) (n int, err error) {
|
||||||
|
if c.closed.Load() == 1 {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
|
||||||
// If the message doesn't require any chunking, then we can go ahead
|
// If the message doesn't require any chunking, then we can go ahead
|
||||||
// with a single write.
|
// with a single write.
|
||||||
if len(b) <= math.MaxUint16 {
|
if len(b) <= math.MaxUint16 {
|
||||||
err = c.noise.WriteMessage(b)
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
err = noise.WriteMessage(b)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
return c.noise.Flush(c.conn)
|
return noise.Flush(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we need to split the message into fragments, then we'll write
|
// If we need to split the message into fragments, then we'll write
|
||||||
@@ -192,11 +241,15 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||||||
// Slice off the next chunk to be written based on our running
|
// Slice off the next chunk to be written based on our running
|
||||||
// counter and next chunk size.
|
// counter and next chunk size.
|
||||||
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
chunk := b[bytesWritten : bytesWritten+chunkSize]
|
||||||
if err := c.noise.WriteMessage(chunk); err != nil {
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return bytesWritten, ErrConnClosed
|
||||||
|
}
|
||||||
|
if err := noise.WriteMessage(chunk); err != nil {
|
||||||
return bytesWritten, err
|
return bytesWritten, err
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := c.noise.Flush(c.conn)
|
n, err := noise.Flush(c.conn)
|
||||||
bytesWritten += n
|
bytesWritten += n
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return bytesWritten, err
|
return bytesWritten, err
|
||||||
@@ -214,7 +267,14 @@ func (c *Conn) Write(b []byte) (n int, err error) {
|
|||||||
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
// NOTE: This DOES NOT write the message to the wire, it should be followed by a
|
||||||
// call to Flush to ensure the message is written.
|
// call to Flush to ensure the message is written.
|
||||||
func (c *Conn) WriteMessage(b []byte) error {
|
func (c *Conn) WriteMessage(b []byte) error {
|
||||||
return c.noise.WriteMessage(b)
|
if c.closed.Load() == 1 {
|
||||||
|
return ErrConnClosed
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return ErrConnClosed
|
||||||
|
}
|
||||||
|
return noise.WriteMessage(b)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush attempts to write a message buffered using WriteMessage to the
|
// Flush attempts to write a message buffered using WriteMessage to the
|
||||||
@@ -226,7 +286,14 @@ func (c *Conn) WriteMessage(b []byte) error {
|
|||||||
//
|
//
|
||||||
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
// NOTE: It is safe to call this method again iff a timeout error is returned.
|
||||||
func (c *Conn) Flush() (int, error) {
|
func (c *Conn) Flush() (int, error) {
|
||||||
return c.noise.Flush(c.conn)
|
if c.closed.Load() == 1 {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return 0, ErrConnClosed
|
||||||
|
}
|
||||||
|
return noise.Flush(c.conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close closes the connection. Any blocked Read or Write operations will be
|
// Close closes the connection. Any blocked Read or Write operations will be
|
||||||
@@ -234,8 +301,16 @@ func (c *Conn) Flush() (int, error) {
|
|||||||
//
|
//
|
||||||
// Part of the net.Conn interface.
|
// Part of the net.Conn interface.
|
||||||
func (c *Conn) Close() error {
|
func (c *Conn) Close() error {
|
||||||
|
// Use compare-and-swap to ensure Close is only executed once.
|
||||||
|
if !c.closed.CompareAndSwap(0, 1) {
|
||||||
|
return ErrConnClosed
|
||||||
|
}
|
||||||
|
|
||||||
// Clear the state we created to be able to handle this connection.
|
// Clear the state we created to be able to handle this connection.
|
||||||
c.noise = nil
|
// We atomically swap the noise pointer to nil, which allows the ~64KB
|
||||||
|
// buffers to be garbage collected immediately while preventing any
|
||||||
|
// nil pointer dereferences from concurrent operations.
|
||||||
|
c.noise.Store(nil)
|
||||||
c.readBuf = bytes.Buffer{}
|
c.readBuf = bytes.Buffer{}
|
||||||
|
|
||||||
return c.conn.Close()
|
return c.conn.Close()
|
||||||
@@ -283,10 +358,24 @@ func (c *Conn) SetWriteDeadline(t time.Time) error {
|
|||||||
|
|
||||||
// RemotePub returns the remote peer's static public key.
|
// RemotePub returns the remote peer's static public key.
|
||||||
func (c *Conn) RemotePub() *btcec.PublicKey {
|
func (c *Conn) RemotePub() *btcec.PublicKey {
|
||||||
return c.noise.remoteStatic
|
if c.closed.Load() == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return noise.remoteStatic
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalPub returns the local peer's static public key.
|
// LocalPub returns the local peer's static public key.
|
||||||
func (c *Conn) LocalPub() *btcec.PublicKey {
|
func (c *Conn) LocalPub() *btcec.PublicKey {
|
||||||
return c.noise.localStatic.PubKey()
|
if c.closed.Load() == 1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
noise := c.noise.Load()
|
||||||
|
if noise == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return noise.localStatic.PubKey()
|
||||||
}
|
}
|
||||||
|
@@ -116,9 +116,9 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
|||||||
remoteAddr := conn.RemoteAddr().String()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
|
|
||||||
brontideConn := &Conn{
|
brontideConn := &Conn{
|
||||||
conn: conn,
|
conn: conn,
|
||||||
noise: NewBrontideMachine(false, l.localStatic, nil),
|
|
||||||
}
|
}
|
||||||
|
brontideConn.noise.Store(NewBrontideMachine(false, l.localStatic, nil))
|
||||||
|
|
||||||
// We'll ensure that we get ActOne from the remote peer in a timely
|
// We'll ensure that we get ActOne from the remote peer in a timely
|
||||||
// manner. If they don't respond within handshakeReadTimeout, then
|
// manner. If they don't respond within handshakeReadTimeout, then
|
||||||
@@ -139,7 +139,8 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
|||||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := brontideConn.noise.RecvActOne(actOne); err != nil {
|
noise := brontideConn.noise.Load()
|
||||||
|
if err := noise.RecvActOne(actOne); err != nil {
|
||||||
brontideConn.conn.Close()
|
brontideConn.conn.Close()
|
||||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||||
return
|
return
|
||||||
@@ -147,7 +148,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
|||||||
|
|
||||||
// Next, progress the handshake processes by sending over our ephemeral
|
// Next, progress the handshake processes by sending over our ephemeral
|
||||||
// key for the session along with an authenticating tag.
|
// key for the session along with an authenticating tag.
|
||||||
actTwo, err := brontideConn.noise.GenActTwo()
|
actTwo, err := noise.GenActTwo()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
brontideConn.conn.Close()
|
brontideConn.conn.Close()
|
||||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||||
@@ -184,7 +185,7 @@ func (l *Listener) doHandshake(conn net.Conn) {
|
|||||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if err := brontideConn.noise.RecvActThree(actThree); err != nil {
|
if err := noise.RecvActThree(actThree); err != nil {
|
||||||
brontideConn.conn.Close()
|
brontideConn.conn.Close()
|
||||||
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
l.rejectConn(rejectedConnErr(err, remoteAddr))
|
||||||
return
|
return
|
||||||
|
Reference in New Issue
Block a user