From 405d4e5f73cc9879c69266669294d1da659cd2f2 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 12 Jul 2023 16:26:12 -0600 Subject: [PATCH 1/8] chainntnfs: introduce system for chain state tracking and querying This change adds a new subsystem that is responsible for providing an up to date view of some global chainstate parameters. --- chainntnfs/best_block_view.go | 135 +++++++++++++++++++++++++++++ chainntnfs/best_block_view_test.go | 112 ++++++++++++++++++++++++ 2 files changed, 247 insertions(+) create mode 100644 chainntnfs/best_block_view.go create mode 100644 chainntnfs/best_block_view_test.go diff --git a/chainntnfs/best_block_view.go b/chainntnfs/best_block_view.go new file mode 100644 index 000000000..c043e68e7 --- /dev/null +++ b/chainntnfs/best_block_view.go @@ -0,0 +1,135 @@ +package chainntnfs + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + + "github.com/btcsuite/btcd/wire" +) + +// BestBlockView is an interface that allows the querying of the most +// up-to-date blockchain state with low overhead. Valid implementations of this +// interface must track the latest chain state. +type BestBlockView interface { + // BestHeight gets the most recent block height known to the view. + BestHeight() (uint32, error) + + // BestBlockHeader gets the most recent block header known to the view. + BestBlockHeader() (*wire.BlockHeader, error) +} + +// BestBlockTracker is a tiny subsystem that tracks the blockchain tip +// and saves the most recent tip information in memory for querying. It is a +// valid implementation of BestBlockView and additionally includes +// methods for starting and stopping the system. +type BestBlockTracker struct { + notifier ChainNotifier + blockNtfnStream *BlockEpochEvent + current atomic.Pointer[BlockEpoch] + mu sync.Mutex + quit chan struct{} + wg sync.WaitGroup +} + +// This is a compile time check to ensure that BestBlockTracker implements +// BestBlockView. 
+var _ BestBlockView = (*BestBlockTracker)(nil) + +// NewBestBlockTracker creates a new BestBlockTracker that isn't running yet. +// It will not provide up to date information unless it has been started. The +// ChainNotifier parameter must also be started prior to starting the +// BestBlockTracker. +func NewBestBlockTracker(chainNotifier ChainNotifier) *BestBlockTracker { + return &BestBlockTracker{ + notifier: chainNotifier, + blockNtfnStream: nil, + quit: make(chan struct{}), + } +} + +// BestHeight gets the most recent block height known to the +// BestBlockTracker. +func (t *BestBlockTracker) BestHeight() (uint32, error) { + epoch := t.current.Load() + if epoch == nil { + return 0, errors.New("best block height not yet known") + } + + return uint32(epoch.Height), nil +} + +// BestBlockHeader gets the most recent block header known to the +// BestBlockTracker. +func (t *BestBlockTracker) BestBlockHeader() (*wire.BlockHeader, error) { + epoch := t.current.Load() + if epoch == nil { + return nil, errors.New("best block header not yet known") + } + + return epoch.BlockHeader, nil +} + +// updateLoop is a helper that subscribes to the underlying BlockEpochEvent +// stream and updates the internal values to match the new BlockEpochs that +// are discovered. +// +// MUST be run as a goroutine. +func (t *BestBlockTracker) updateLoop() { + defer t.wg.Done() + for { + select { + case epoch, ok := <-t.blockNtfnStream.Epochs: + if !ok { + Log.Error("dead epoch stream in " + + "BestBlockTracker") + + return + } + t.current.Store(epoch) + case <-t.quit: + t.current.Store(nil) + return + } + } +} + +// Start starts the BestBlockTracker. It is an error to start it if it +// is already started. 
+func (t *BestBlockTracker) Start() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.blockNtfnStream != nil { + return fmt.Errorf("BestBlockTracker is already started") + } + + var err error + t.blockNtfnStream, err = t.notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return err + } + + t.wg.Add(1) + go t.updateLoop() + + return nil +} + +// Stop stops the BestBlockTracker. It is an error to stop it if it has +// not been started or if it has already been stopped. +func (t *BestBlockTracker) Stop() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.blockNtfnStream == nil { + return fmt.Errorf("BestBlockTracker is not running") + } + close(t.quit) + t.wg.Wait() + t.blockNtfnStream.Cancel() + t.blockNtfnStream = nil + + return nil +} diff --git a/chainntnfs/best_block_view_test.go b/chainntnfs/best_block_view_test.go new file mode 100644 index 000000000..2a55bc8b0 --- /dev/null +++ b/chainntnfs/best_block_view_test.go @@ -0,0 +1,112 @@ +package chainntnfs_test + +import ( + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" + "github.com/lightningnetwork/lnd/lntest/mock" + "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/stretchr/testify/require" +) + +type blockEpoch chainntnfs.BlockEpoch + +func (blockEpoch) Generate(r *rand.Rand, size int) reflect.Value { + var chainHash, prevBlockHash, merkleRootHash chainhash.Hash + r.Read(chainHash[:]) + r.Read(prevBlockHash[:]) + r.Read(merkleRootHash[:]) + + return reflect.ValueOf(blockEpoch(chainntnfs.BlockEpoch{ + Hash: &chainHash, + Height: r.Int31n(1000000), + BlockHeader: &wire.BlockHeader{ + Version: 2, + PrevBlock: prevBlockHash, + MerkleRoot: merkleRootHash, + Timestamp: time.Now(), + Bits: r.Uint32(), + Nonce: r.Uint32(), + }, + })) +} + +// TestBestBlockTracker ensures that the most recent event pushed on the +// underlying EpochChan is remembered by 
the BestBlockView functions as well +// as testing the idempotence of the BestBlockView interface. +func TestBestBlockTracker(t *testing.T) { + t.Parallel() + + notifier := &mock.ChainNotifier{ + SpendChan: nil, + EpochChan: make(chan *chainntnfs.BlockEpoch), + ConfChan: nil, + } + + chainNotifierI := chainntnfs.ChainNotifier(notifier) + + tracker := chainntnfs.NewBestBlockTracker(chainNotifierI) + require.Nil(t, tracker.Start(), + "BestBlockTracker could not be started") + + // we have to limit test cases because the poll interval of + // wait.Predicate isn't tight enough to support the usual 100 + cfg := quick.Config{MaxCount: 50} + correctness := func(epochRand blockEpoch) bool { + epoch := chainntnfs.BlockEpoch(epochRand) + notifier.EpochChan <- &epoch + + // wait for new block to propagate + err := wait.Predicate( + func() bool { + hdr, err := tracker.BestBlockHeader() + return err == nil && hdr == epoch.BlockHeader + }, + 1*time.Second, + ) + require.Nil(t, err, + "BestBlockTracker: block propagation timeout") + + height, _ := tracker.BestHeight() + header, _ := tracker.BestBlockHeader() + + return height == uint32(epoch.Height) && + header == epoch.BlockHeader + } + idempotence := func(epochRand blockEpoch) bool { + epoch := chainntnfs.BlockEpoch(epochRand) + notifier.EpochChan <- &epoch + + // wait for new block to propagate + err := wait.Predicate( + func() bool { + hdr, err := tracker.BestBlockHeader() + return err == nil && hdr == epoch.BlockHeader + }, + 1*time.Second, + ) + require.Nil(t, err, + "BestBlockTracker: block propagation timeout") + + height0, _ := tracker.BestHeight() + height1, _ := tracker.BestHeight() + header0, _ := tracker.BestBlockHeader() + header1, _ := tracker.BestBlockHeader() + + return height0 == height1 && header0 == header1 + } + err := quick.Check(correctness, &cfg) + require.Nil(t, err, + "BestBlockTracker does not give up to date info: %v", err) + + require.Nil(t, quick.Check(idempotence, &cfg), + "BestBlockTracker is not idempotent") + + require.Nil(t, tracker.Stop(), "BestBlockTracker 
could not be stopped") +} From 7c403b439c5de2714517c7a1820941677b305e11 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 12 Jul 2023 16:28:26 -0600 Subject: [PATCH 2/8] chainreg: Add BestBlockTracker to ChainControl This commit takes the best block tracker and adds it to the ChainControl objects and appropriately initializes it on ChainControl creation --- chainreg/chainregistry.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/chainreg/chainregistry.go b/chainreg/chainregistry.go index 3c54a0aef..93d5c1dd1 100644 --- a/chainreg/chainregistry.go +++ b/chainreg/chainregistry.go @@ -145,6 +145,10 @@ type PartialChainControl struct { // interested in. ChainNotifier chainntnfs.ChainNotifier + // BestBlockTracker is used to maintain a view of the global + // chain state that changes over time + BestBlockTracker *chainntnfs.BestBlockTracker + // MempoolNotifier is used to watch for spending events happened in // mempool. MempoolNotifier chainntnfs.MempoolWatcher @@ -667,6 +671,9 @@ func NewPartialChainControl(cfg *Config) (*PartialChainControl, func(), error) { cfg.Bitcoin.Node) } + cc.BestBlockTracker = + chainntnfs.NewBestBlockTracker(cc.ChainNotifier) + switch { // If the fee URL isn't set, and the user is running mainnet, then // we'll return an error to instruct them to set a proper fee From 74b30a71cba6462e7d816b9a21c891ce0554a410 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 12 Jul 2023 16:30:34 -0600 Subject: [PATCH 3/8] lnd: start the BestBlockTracker during server startup --- server.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/server.go b/server.go index e09a6f062..0882fa23c 100644 --- a/server.go +++ b/server.go @@ -1901,6 +1901,12 @@ func (s *server) Start() error { } cleanup = cleanup.add(s.cc.ChainNotifier.Stop) + if err := s.cc.BestBlockTracker.Start(); err != nil { + startErr = err + return + } + cleanup = cleanup.add(s.cc.BestBlockTracker.Stop) + if err := s.channelNotifier.Start(); err != nil { 
startErr = err return @@ -2282,6 +2288,10 @@ func (s *server) Stop() error { if err := s.cc.ChainNotifier.Stop(); err != nil { srvrLog.Warnf("Unable to stop ChainNotifier: %v", err) } + if err := s.cc.BestBlockTracker.Stop(); err != nil { + srvrLog.Warnf("Unable to stop BestBlockTracker: %v", + err) + } s.chanEventStore.Stop() s.missionControl.StopStoreTicker() From ac265812fc5de1934038eeb7e7973e02923c4eeb Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 12 Jul 2023 17:26:43 -0600 Subject: [PATCH 4/8] peer+lnd: make BestBlockView available to Brontide --- peer/brontide.go | 4 ++++ server.go | 1 + 2 files changed, 5 insertions(+) diff --git a/peer/brontide.go b/peer/brontide.go index 12ea39cf3..e7ebf974c 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -233,6 +233,10 @@ type Config struct { // transaction. ChainNotifier chainntnfs.ChainNotifier + // BestBlockView is used to efficiently query for up-to-date + // blockchain state information + BestBlockView chainntnfs.BestBlockView + // RoutingPolicy is used to set the forwarding policy for links created by // the Brontide. RoutingPolicy models.ForwardingPolicy diff --git a/server.go b/server.go index 0882fa23c..c47f79757 100644 --- a/server.go +++ b/server.go @@ -3837,6 +3837,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, SigPool: s.sigPool, Wallet: s.cc.Wallet, ChainNotifier: s.cc.ChainNotifier, + BestBlockView: s.cc.BestBlockTracker, RoutingPolicy: s.cc.RoutingPolicy, Sphinx: s.sphinx, WitnessBeacon: s.witnessBeacon, From 1eab7826d2444a77078acc9cdfbbb460d56f6773 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Wed, 12 Jul 2023 17:27:29 -0600 Subject: [PATCH 5/8] peer: abstract out ping payload generation from the pingHandler This change makes the generation of the ping payload a no-arg closure parameter, relieving the pingHandler of having to directly monitor the chain state. This makes use of the BestBlockView that was introduced in earlier commits. 
--- peer/brontide.go | 67 ++++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 36 deletions(-) diff --git a/peer/brontide.go b/peer/brontide.go index e7ebf974c..e151ebeb7 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -649,7 +649,35 @@ func (p *Brontide) Start() error { go p.writeHandler() go p.readHandler() go p.channelManager() - go p.pingHandler() + + var ( + lastBlockHeader *wire.BlockHeader + lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte + ) + newPingPayload := func() []byte { + // We query the BestBlockHeader from our BestBlockView each time + // this is called, and update our serialized block header if + // they differ. Over time, we'll use this to disseminate the + // latest block header between all our peers, which can later be + // used to cross-check our own view of the network to mitigate + // various types of eclipse attacks. + header, err := p.cfg.BestBlockView.BestBlockHeader() + if err == nil && header != lastBlockHeader { + buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0]) + err := header.Serialize(buf) + if err == nil { + lastBlockHeader = header + } else { + p.log.Warn("unable to serialize current block" + + "header for ping payload generation." + + "This should be impossible and means" + + "there is an implementation bug.") + } + } + + return lastSerializedBlockHeader[:] + } + go p.pingHandler(newPingPayload) // Signal to any external processes that the peer is now active. close(p.activeSignal) @@ -2257,7 +2285,7 @@ func (p *Brontide) queueHandler() { // connection is still active. // // NOTE: This method MUST be run as a goroutine. 
-func (p *Brontide) pingHandler() { +func (p *Brontide) pingHandler(newPayload func() []byte) { defer p.wg.Done() pingTicker := time.NewTicker(pingInterval) @@ -2266,47 +2294,14 @@ func (p *Brontide) pingHandler() { // TODO(roasbeef): make dynamic in order to create fake cover traffic const numPongBytes = 16 - blockEpochs, err := p.cfg.ChainNotifier.RegisterBlockEpochNtfn(nil) - if err != nil { - p.log.Errorf("unable to establish block epoch "+ - "subscription: %v", err) - return - } - defer blockEpochs.Cancel() - - var ( - pingPayload [wire.MaxBlockHeaderPayload]byte - blockHeader *wire.BlockHeader - ) out: for { select { - // Each time a new block comes in, we'll copy the raw header - // contents over to our ping payload declared above. Over time, - // we'll use this to disseminate the latest block header - // between all our peers, which can later be used to - // cross-check our own view of the network to mitigate various - // types of eclipse attacks. - case epoch, ok := <-blockEpochs.Epochs: - if !ok { - p.log.Debugf("block notifications " + - "canceled") - return - } - - blockHeader = epoch.BlockHeader - headerBuf := bytes.NewBuffer(pingPayload[0:0]) - err := blockHeader.Serialize(headerBuf) - if err != nil { - p.log.Errorf("unable to encode header: %v", - err) - } - case <-pingTicker.C: pingMsg := &lnwire.Ping{ NumPongBytes: numPongBytes, - PaddingBytes: pingPayload[:], + PaddingBytes: newPayload(), } p.queueMsg(pingMsg, nil) From 99226e37efc7c0b72f53c66d0e35a7799d1f6ab0 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Fri, 14 Jul 2023 11:24:10 -0600 Subject: [PATCH 6/8] peer: Add machinery to track the state and validity of remote pongs This commit refactors some of the bookkeeping around the ping logic inside of the Brontide. If the pong response is noncompliant with the spec or if it times out, we disconnect from the peer. 
--- peer/brontide.go | 172 ++++++++++++------------ peer/ping_manager.go | 266 ++++++++++++++++++++++++++++++++++++++ peer/ping_manager_test.go | 88 +++++++++++++ 3 files changed, 437 insertions(+), 89 deletions(-) create mode 100644 peer/ping_manager.go create mode 100644 peer/ping_manager_test.go diff --git a/peer/brontide.go b/peer/brontide.go index e151ebeb7..eba3b30cc 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -5,6 +5,7 @@ import ( "container/list" "errors" "fmt" + "math/rand" "net" "sync" "sync/atomic" @@ -50,6 +51,12 @@ const ( // pingInterval is the interval at which ping messages are sent. pingInterval = 1 * time.Minute + // pingTimeout is the amount of time we will wait for a pong response + // before considering the peer to be unresponsive. + // + // This MUST be a smaller value than the pingInterval. + pingTimeout = 30 * time.Second + // idleTimeout is the duration of inactivity before we time out a peer. idleTimeout = 5 * time.Minute @@ -379,15 +386,7 @@ type Brontide struct { bytesReceived uint64 bytesSent uint64 - // pingTime is a rough estimate of the RTT (round-trip-time) between us - // and the connected peer. This time is expressed in microseconds. - // To be used atomically. - // TODO(roasbeef): also use a WMA or EMA? - pingTime int64 - - // pingLastSend is the Unix time expressed in nanoseconds when we sent - // our last ping message. To be used atomically. - pingLastSend int64 + pingManager *PingManager // lastPingPayload stores an unsafe pointer wrapped as an atomic // variable which points to the last payload the remote party sent us @@ -525,6 +524,66 @@ func NewBrontide(cfg Config) *Brontide { log: build.NewPrefixLog(logPrefix, peerLog), } + var ( + lastBlockHeader *wire.BlockHeader + lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte + ) + newPingPayload := func() []byte { + // We query the BestBlockHeader from our BestBlockView each time + // this is called, and update our serialized block header if + // they differ. 
Over time, we'll use this to disseminate the + // latest block header between all our peers, which can later be + // used to cross-check our own view of the network to mitigate + // various types of eclipse attacks. + header, err := p.cfg.BestBlockView.BestBlockHeader() + if err != nil || header == lastBlockHeader { + return lastSerializedBlockHeader[:] + } + + buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0]) + err = header.Serialize(buf) + if err == nil { + lastBlockHeader = header + } else { + p.log.Warn("unable to serialize current block " + + "header for ping payload generation. " + + "This should be impossible and means " + + "there is an implementation bug.") + } + + return lastSerializedBlockHeader[:] + } + + // TODO(roasbeef): make dynamic in order to + // create fake cover traffic + // NOTE(proofofkeags): this was changed to be + // dynamic to allow better pong identification, + // however, more thought is needed to make this + // actually usable as a traffic decoy + randPongSize := func() uint16 { + return uint16( + // We don't need cryptographic randomness here. 
+ /* #nosec */ + rand.Intn(lnwire.MaxPongBytes + 1), + ) + } + + p.pingManager = NewPingManager(&PingManagerConfig{ + NewPingPayload: newPingPayload, + NewPongSize: randPongSize, + IntervalDuration: pingInterval, + TimeoutDuration: pingTimeout, + SendPing: func(ping *lnwire.Ping) { + p.queueMsg(ping, nil) + }, + OnPongFailure: func(err error) { + eStr := "pong response failure for %s: %v " + + "-- disconnecting" + p.log.Warnf(eStr, p, err) + go p.Disconnect(fmt.Errorf(eStr, p, err)) + }, + }) + return p } @@ -644,40 +703,16 @@ func (p *Brontide) Start() error { p.startTime = time.Now() - p.wg.Add(5) + err = p.pingManager.Start() + if err != nil { + return fmt.Errorf("could not start ping manager %w", err) + } + + p.wg.Add(4) go p.queueHandler() go p.writeHandler() - go p.readHandler() go p.channelManager() - - var ( - lastBlockHeader *wire.BlockHeader - lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte - ) - newPingPayload := func() []byte { - // We query the BestBlockHeader from our BestBlockView each time - // this is called, and update our serialized block header if - // they differ. Over time, we'll use this to disseminate the - // latest block header between all our peers, which can later be - // used to cross-check our own view of the network to mitigate - // various types of eclipse attacks. - header, err := p.cfg.BestBlockView.BestBlockHeader() - if err == nil && header != lastBlockHeader { - buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0]) - err := header.Serialize(buf) - if err == nil { - lastBlockHeader = header - } else { - p.log.Warn("unable to serialize current block" + - "header for ping payload generation." + - "This should be impossible and means" + - "there is an implementation bug.") - } - } - - return lastSerializedBlockHeader[:] - } - go p.pingHandler(newPingPayload) + go p.readHandler() // Signal to any external processes that the peer is now active. 
close(p.activeSignal) @@ -1159,6 +1194,11 @@ func (p *Brontide) Disconnect(reason error) { p.cfg.Conn.Close() close(p.quit) + + if err := p.pingManager.Stop(); err != nil { + p.log.Errorf("couldn't stop pingManager during disconnect: %v", + err) + } } // String returns the string representation of this peer. @@ -1606,12 +1646,8 @@ out: switch msg := nextMsg.(type) { case *lnwire.Pong: // When we receive a Pong message in response to our - // last ping message, we'll use the time in which we - // sent the ping message to measure a rough estimate of - // round trip time. - pingSendTime := atomic.LoadInt64(&p.pingLastSend) - delay := (time.Now().UnixNano() - pingSendTime) / 1000 - atomic.StoreInt64(&p.pingTime, delay) + // last ping message, we send it to the pingManager + p.pingManager.ReceivedPong(msg) case *lnwire.Ping: // First, we'll store their latest ping payload within @@ -2137,17 +2173,6 @@ out: for { select { case outMsg := <-p.sendQueue: - // If we're about to send a ping message, then log the - // exact time in which we send the message so we can - // use the delay as a rough estimate of latency to the - // remote peer. - if _, ok := outMsg.msg.(*lnwire.Ping); ok { - // TODO(roasbeef): do this before the write? - // possibly account for processing within func? - now := time.Now().UnixNano() - atomic.StoreInt64(&p.pingLastSend, now) - } - // Record the time at which we first attempt to send the // message. startTime := time.Now() @@ -2280,40 +2305,9 @@ func (p *Brontide) queueHandler() { } } -// pingHandler is responsible for periodically sending ping messages to the -// remote peer in order to keep the connection alive and/or determine if the -// connection is still active. -// -// NOTE: This method MUST be run as a goroutine. 
-func (p *Brontide) pingHandler(newPayload func() []byte) { - defer p.wg.Done() - - pingTicker := time.NewTicker(pingInterval) - defer pingTicker.Stop() - - // TODO(roasbeef): make dynamic in order to create fake cover traffic - const numPongBytes = 16 - -out: - for { - select { - case <-pingTicker.C: - - pingMsg := &lnwire.Ping{ - NumPongBytes: numPongBytes, - PaddingBytes: newPayload(), - } - - p.queueMsg(pingMsg, nil) - case <-p.quit: - break out - } - } -} - // PingTime returns the estimated ping time to the peer in microseconds. func (p *Brontide) PingTime() int64 { - return atomic.LoadInt64(&p.pingTime) + return p.pingManager.GetPingTimeMicroSeconds() } // queueMsg adds the lnwire.Message to the back of the high priority send queue. diff --git a/peer/ping_manager.go b/peer/ping_manager.go new file mode 100644 index 000000000..3e62cb097 --- /dev/null +++ b/peer/ping_manager.go @@ -0,0 +1,266 @@ +package peer + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// PingManagerConfig is a structure containing various parameters that govern +// how the PingManager behaves. +type PingManagerConfig struct { + + // NewPingPayload is a closure that returns the payload to be packaged + // in the Ping message. + NewPingPayload func() []byte + + // NewPongSize is a closure that returns a random value between + // [0, lnwire.MaxPongBytes]. This random value helps to more effectively + // pair Pong messages with Ping. + NewPongSize func() uint16 + + // IntervalDuration is the Duration between attempted pings. + IntervalDuration time.Duration + + // TimeoutDuration is the Duration we wait before declaring a ping + // attempt failed. 
+ TimeoutDuration time.Duration + + // SendPing is a closure that is responsible for sending the Ping + // message out to our peer + SendPing func(ping *lnwire.Ping) + + // OnPongFailure is invoked when a Pong is late or does not match our + // expectations. It runs on the PingManager's goroutine, so it must + // not call Stop synchronously (Stop waits for that goroutine). + OnPongFailure func(error) +} + +// PingManager is a structure that is designed to manage the internal state +// of the ping pong lifecycle with the remote peer. We assume there is only one +// ping outstanding at once. +// +// NOTE: This structure MUST be initialized with NewPingManager. +type PingManager struct { + cfg *PingManagerConfig + + // pingTime is a rough estimate of the RTT (round-trip-time) between us + // and the connected peer. + // To be used atomically. + // TODO(roasbeef): also use a WMA or EMA? + pingTime atomic.Pointer[time.Duration] + + // pingLastSend is the time when we sent our last ping message. + // Only accessed from the PingManager's own goroutine. + pingLastSend *time.Time + + // outstandingPongSize is the current size of the requested pong + // payload. This value can only validly range from [0,65531]. Any + // value < 0 is interpreted as if there is no outstanding ping message. + outstandingPongSize int32 + + // pingTicker is a pointer to a Ticker that fires on every ping + // interval. + pingTicker *time.Ticker + + // pingTimeout is a Timer that will fire when we want to time out a + // ping + pingTimeout *time.Timer + + // pongChan is the channel on which the pingManager will write Pong + // messages it is evaluating + pongChan chan *lnwire.Pong + + started sync.Once + stopped sync.Once + + quit chan struct{} + wg sync.WaitGroup +} + +// NewPingManager constructs a pingManager in a valid state. It must be started +// before it does anything useful, though. 
+func NewPingManager(cfg *PingManagerConfig) *PingManager { + m := PingManager{ + cfg: cfg, + outstandingPongSize: -1, + pongChan: make(chan *lnwire.Pong, 1), + quit: make(chan struct{}), + } + + return &m +} + +// Start launches the primary goroutine that is owned by the pingManager. +func (m *PingManager) Start() error { + var err error + m.started.Do(func() { + err = m.start() + }) + + return err +} + +func (m *PingManager) start() error { + m.pingTicker = time.NewTicker(m.cfg.IntervalDuration) + + m.pingTimeout = time.NewTimer(0) + defer m.pingTimeout.Stop() + + // Ensure that the pingTimeout channel is empty. + if !m.pingTimeout.Stop() { + <-m.pingTimeout.C + } + + m.wg.Add(1) + go func() { + defer m.wg.Done() + for { + select { + case <-m.pingTicker.C: + // If this occurs it means that the new ping + // cycle has begun while there is still an + // outstanding ping awaiting a pong response. + // This should never occur, but if it does, it + // implies a timeout. + if m.outstandingPongSize >= 0 { + e := errors.New("impossible: new ping " + + "in unclean state", + ) + m.cfg.OnPongFailure(e) + + return + } + + pongSize := m.cfg.NewPongSize() + ping := &lnwire.Ping{ + NumPongBytes: pongSize, + PaddingBytes: m.cfg.NewPingPayload(), + } + + // Set up our bookkeeping for the new Ping. + if err := m.setPingState(pongSize); err != nil { + m.cfg.OnPongFailure(err) + + return + } + + m.cfg.SendPing(ping) + + case <-m.pingTimeout.C: + m.resetPingState() + + e := errors.New("timeout while waiting for " + + "pong response", + ) + m.cfg.OnPongFailure(e) + + return + + case pong := <-m.pongChan: + pongSize := int32(len(pong.PongBytes)) + + // Save off values we are about to override + // when we call resetPingState. + expected := m.outstandingPongSize + lastPing := m.pingLastSend + + m.resetPingState() + + // If the pong we receive doesn't match the + // ping we sent out, then we fail out. 
+ if pongSize != expected { + e := errors.New("pong response does " + + "not match expected size", + ) + m.cfg.OnPongFailure(e) + + return + } + + // Compute RTT of ping and save that for future + // querying. + if lastPing != nil { + rtt := time.Since(*lastPing) + m.pingTime.Store(&rtt) + } + case <-m.quit: + return + } + } + }() + + return nil +} + +// Stop interrupts the goroutines that the PingManager owns. Can only be called +// when the PingManager is running. +func (m *PingManager) Stop() error { + if m.pingTicker == nil { + return errors.New("PingManager cannot be stopped because it " + + "isn't running") + } + + m.stopped.Do(func() { + close(m.quit) + m.wg.Wait() + + m.pingTicker.Stop() + m.pingTimeout.Stop() + }) + + return nil +} + +// setPingState is a private method to keep track of all of the fields we need +// to set when we send out a Ping. +func (m *PingManager) setPingState(pongSize uint16) error { + t := time.Now() + m.pingLastSend = &t + m.outstandingPongSize = int32(pongSize) + if m.pingTimeout.Reset(m.cfg.TimeoutDuration) { + return fmt.Errorf( + "impossible: ping timeout reset when already active", + ) + } + + return nil +} + +// resetPingState is a private method that resets all of the bookkeeping that +// is tracking a currently outstanding Ping. +func (m *PingManager) resetPingState() { + m.pingLastSend = nil + m.outstandingPongSize = -1 + if !m.pingTimeout.Stop() { + select { + case <-m.pingTimeout.C: + default: + } + } +} + +// GetPingTimeMicroSeconds reports back the RTT calculated by the pingManager. +func (m *PingManager) GetPingTimeMicroSeconds() int64 { + rtt := m.pingTime.Load() + + if rtt == nil { + return -1 + } + + return rtt.Microseconds() +} + +// ReceivedPong is called to evaluate a Pong message against the expectations +// we have for it. It will cause the PingManager to invoke the supplied +// OnPongFailure function if the Pong argument supplied violates expectations. 
+func (m *PingManager) ReceivedPong(msg *lnwire.Pong) { + select { + case m.pongChan <- msg: + case <-m.quit: + } +} diff --git a/peer/ping_manager_test.go b/peer/ping_manager_test.go new file mode 100644 index 000000000..bdfeeb6af --- /dev/null +++ b/peer/ping_manager_test.go @@ -0,0 +1,88 @@ +package peer + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// TestPingManager tests three main properties about the ping manager. It +// ensures that if the pong response exceeds the timeout, that a failure is +// emitted on the failure channel. It ensures that if the Pong response is +// not congruent with the outstanding ping then a failure is emitted on the +// failure channel, and otherwise the failure channel remains empty. +func TestPingManager(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + delay int + pongSize uint16 + result bool + }{ + { + name: "Happy Path", + delay: 0, + pongSize: 4, + result: true, + }, + { + name: "Bad Pong", + delay: 0, + pongSize: 3, + result: false, + }, + { + name: "Timeout", + delay: 2, + pongSize: 4, + result: false, + }, + } + + payload := make([]byte, 4) + for _, test := range testCases { + // Set up PingManager. + pingSent := make(chan struct{}) + disconnected := make(chan struct{}) + mgr := NewPingManager(&PingManagerConfig{ + NewPingPayload: func() []byte { + return payload + }, + NewPongSize: func() uint16 { + return 4 + }, + IntervalDuration: time.Second * 2, + TimeoutDuration: time.Second, + SendPing: func(ping *lnwire.Ping) { + close(pingSent) + }, + OnPongFailure: func(err error) { + close(disconnected) + }, + }) + require.NoError(t, mgr.Start(), "Could not start pingManager") + + // Wait for initial Ping. + <-pingSent + + // Wait for pre-determined time before sending Pong response. + time.Sleep(time.Duration(test.delay) * time.Second) + + // Send Pong back. 
+ res := lnwire.Pong{PongBytes: make([]byte, test.pongSize)} + mgr.ReceivedPong(&res) + + // Evaluate result + select { + case <-time.NewTimer(time.Second / 2).C: + require.True(t, test.result) + case <-disconnected: + require.False(t, test.result) + } + + require.NoError(t, mgr.Stop(), "Could not stop pingManager") + } +} From 5762061487f8f7828eeae8c4869d386e19c84b83 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 18 Jul 2023 17:48:26 -0600 Subject: [PATCH 7/8] docs: Update release notes to include pong enforcement --- docs/release-notes/release-notes-0.18.0.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/docs/release-notes/release-notes-0.18.0.md b/docs/release-notes/release-notes-0.18.0.md index 0b0f2d508..e80c4f422 100644 --- a/docs/release-notes/release-notes-0.18.0.md +++ b/docs/release-notes/release-notes-0.18.0.md @@ -25,6 +25,9 @@ that when sweeping inputs with locktime, an unexpected lower fee rate is applied. +* LND will now [enforce pong responses + ](https://github.com/lightningnetwork/lnd/pull/7828) from its peers + # New Features ## Functional Enhancements @@ -94,5 +97,6 @@ * Andras Banki-Horvath * Carla Kirk-Cohen * Elle Mouton -* Yong Yu +* Keagan McClelland * Ononiwu Maureen Chiamaka +* Yong Yu From 012cc6af8cd72c971b593a8ad772201f3a0de97b Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Thu, 14 Sep 2023 15:30:49 -0700 Subject: [PATCH 8/8] peer: eliminate unnecessary log spam from received pong msgs --- peer/brontide.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/peer/brontide.go b/peer/brontide.go index eba3b30cc..86aa99aae 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1996,7 +1996,7 @@ func messageSummary(msg lnwire.Message) string { return fmt.Sprintf("ping_bytes=%x", msg.PaddingBytes[:]) case *lnwire.Pong: - return fmt.Sprintf("pong_bytes=%x", msg.PongBytes[:]) + return fmt.Sprintf("len(pong_bytes)=%d", len(msg.PongBytes[:])) case *lnwire.UpdateFee: return 
fmt.Sprintf("chan_id=%v, fee_update_sat=%v",