mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 17:52:25 +01:00
Merge pull request #7828 from ProofOfKeags/pong-enforcement
multi: pong enforcement
This commit is contained in:
commit
15f4213793
135
chainntnfs/best_block_view.go
Normal file
135
chainntnfs/best_block_view.go
Normal file
@ -0,0 +1,135 @@
|
||||
package chainntnfs
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
)
|
||||
|
||||
// BestBlockView is an interface that allows the querying of the most
|
||||
// up-to-date blockchain state with low overhead. Valid implementations of this
|
||||
// interface must track the latest chain state.
|
||||
type BestBlockView interface {
|
||||
// BestHeight gets the most recent block height known to the view.
|
||||
BestHeight() (uint32, error)
|
||||
|
||||
// BestBlockHeader gets the most recent block header known to the view.
|
||||
BestBlockHeader() (*wire.BlockHeader, error)
|
||||
}
|
||||
|
||||
// BestBlockTracker is a tiny subsystem that tracks the blockchain tip
|
||||
// and saves the most recent tip information in memory for querying. It is a
|
||||
// valid implementation of BestBlockView and additionally includes
|
||||
// methods for starting and stopping the system.
|
||||
type BestBlockTracker struct {
|
||||
notifier ChainNotifier
|
||||
blockNtfnStream *BlockEpochEvent
|
||||
current atomic.Pointer[BlockEpoch]
|
||||
mu sync.Mutex
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// This is a compile time check to ensure that BestBlockTracker implements
|
||||
// BestBlockView.
|
||||
var _ BestBlockView = (*BestBlockTracker)(nil)
|
||||
|
||||
// NewBestBlockTracker creates a new BestBlockTracker that isn't running yet.
|
||||
// It will not provide up to date information unless it has been started. The
|
||||
// ChainNotifier parameter must also be started prior to starting the
|
||||
// BestBlockTracker.
|
||||
func NewBestBlockTracker(chainNotifier ChainNotifier) *BestBlockTracker {
|
||||
return &BestBlockTracker{
|
||||
notifier: chainNotifier,
|
||||
blockNtfnStream: nil,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// BestHeight gets the most recent block height known to the
|
||||
// BestBlockTracker.
|
||||
func (t *BestBlockTracker) BestHeight() (uint32, error) {
|
||||
epoch := t.current.Load()
|
||||
if epoch == nil {
|
||||
return 0, errors.New("best block height not yet known")
|
||||
}
|
||||
|
||||
return uint32(epoch.Height), nil
|
||||
}
|
||||
|
||||
// BestBlockHeader gets the most recent block header known to the
|
||||
// BestBlockTracker.
|
||||
func (t *BestBlockTracker) BestBlockHeader() (*wire.BlockHeader, error) {
|
||||
epoch := t.current.Load()
|
||||
if epoch == nil {
|
||||
return nil, errors.New("best block header not yet known")
|
||||
}
|
||||
|
||||
return epoch.BlockHeader, nil
|
||||
}
|
||||
|
||||
// updateLoop is a helper that subscribes to the underlying BlockEpochEvent
|
||||
// stream and updates the internal values to match the new BlockEpochs that
|
||||
// are discovered.
|
||||
//
|
||||
// MUST be run as a goroutine.
|
||||
func (t *BestBlockTracker) updateLoop() {
|
||||
defer t.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case epoch, ok := <-t.blockNtfnStream.Epochs:
|
||||
if !ok {
|
||||
Log.Error("dead epoch stream in " +
|
||||
"BestBlockTracker")
|
||||
|
||||
return
|
||||
}
|
||||
t.current.Store(epoch)
|
||||
case <-t.quit:
|
||||
t.current.Store(nil)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the BestBlockTracker. It is an error to start it if it
|
||||
// is already started.
|
||||
func (t *BestBlockTracker) Start() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.blockNtfnStream != nil {
|
||||
return fmt.Errorf("BestBlockTracker is already started")
|
||||
}
|
||||
|
||||
var err error
|
||||
t.blockNtfnStream, err = t.notifier.RegisterBlockEpochNtfn(nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t.wg.Add(1)
|
||||
go t.updateLoop()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the BestBlockTracker. It is an error to stop it if it has
|
||||
// not been started or if it has already been stopped.
|
||||
func (t *BestBlockTracker) Stop() error {
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.blockNtfnStream == nil {
|
||||
return fmt.Errorf("BestBlockTracker is not running")
|
||||
}
|
||||
close(t.quit)
|
||||
t.wg.Wait()
|
||||
t.blockNtfnStream.Cancel()
|
||||
t.blockNtfnStream = nil
|
||||
|
||||
return nil
|
||||
}
|
112
chainntnfs/best_block_view_test.go
Normal file
112
chainntnfs/best_block_view_test.go
Normal file
@ -0,0 +1,112 @@
|
||||
package chainntnfs_test
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/lntest/mock"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type blockEpoch chainntnfs.BlockEpoch
|
||||
|
||||
func (blockEpoch) Generate(r *rand.Rand, size int) reflect.Value {
|
||||
var chainHash, prevBlockHash, merkleRootHash chainhash.Hash
|
||||
r.Read(chainHash[:])
|
||||
r.Read(prevBlockHash[:])
|
||||
r.Read(merkleRootHash[:])
|
||||
|
||||
return reflect.ValueOf(blockEpoch(chainntnfs.BlockEpoch{
|
||||
Hash: &chainHash,
|
||||
Height: r.Int31n(1000000),
|
||||
BlockHeader: &wire.BlockHeader{
|
||||
Version: 2,
|
||||
PrevBlock: prevBlockHash,
|
||||
MerkleRoot: merkleRootHash,
|
||||
Timestamp: time.Now(),
|
||||
Bits: r.Uint32(),
|
||||
Nonce: r.Uint32(),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
// TestBestBlockTracker ensures that the most recent event pushed on the
|
||||
// underlying EpochChan is remembered by the BestBlockView functions as well
|
||||
// as testing the idempotence of the BestBlockView interface.
|
||||
func TestBestBlockTracker(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
notifier := &mock.ChainNotifier{
|
||||
SpendChan: nil,
|
||||
EpochChan: make(chan *chainntnfs.BlockEpoch),
|
||||
ConfChan: nil,
|
||||
}
|
||||
|
||||
chainNotifierI := chainntnfs.ChainNotifier(notifier)
|
||||
|
||||
tracker := chainntnfs.NewBestBlockTracker(chainNotifierI)
|
||||
require.Nil(t, tracker.Start(),
|
||||
"BestBlockTacker could not be started")
|
||||
|
||||
// we have to limit test cases because the poll interval of
|
||||
// wait.Predicate isn't tight enough to support the usual 100
|
||||
cfg := quick.Config{MaxCount: 50}
|
||||
correctness := func(epochRand blockEpoch) bool {
|
||||
epoch := chainntnfs.BlockEpoch(epochRand)
|
||||
notifier.EpochChan <- &epoch
|
||||
|
||||
// wait for new block to propagate
|
||||
err := wait.Predicate(
|
||||
func() bool {
|
||||
_, err := tracker.BestHeight()
|
||||
return err == nil
|
||||
},
|
||||
1*time.Second,
|
||||
)
|
||||
require.Nil(t, err,
|
||||
"BestBlockTracker: block propagation timeout")
|
||||
|
||||
height, _ := tracker.BestHeight()
|
||||
header, _ := tracker.BestBlockHeader()
|
||||
|
||||
return height == uint32(epoch.Height) &&
|
||||
header == epoch.BlockHeader
|
||||
}
|
||||
idempotence := func(epochRand blockEpoch) bool {
|
||||
epoch := chainntnfs.BlockEpoch(epochRand)
|
||||
notifier.EpochChan <- &epoch
|
||||
|
||||
// wait for new block to propagate
|
||||
err := wait.Predicate(
|
||||
func() bool {
|
||||
_, err := tracker.BestHeight()
|
||||
return err == nil
|
||||
},
|
||||
1*time.Second,
|
||||
)
|
||||
require.Nil(t, err,
|
||||
"ChainStateTracker: block propagation timeout")
|
||||
|
||||
height0, _ := tracker.BestHeight()
|
||||
height1, _ := tracker.BestHeight()
|
||||
header0, _ := tracker.BestBlockHeader()
|
||||
header1, _ := tracker.BestBlockHeader()
|
||||
|
||||
return height0 == height1 && header0 == header1
|
||||
}
|
||||
err := quick.Check(correctness, &cfg)
|
||||
require.Nil(t, err,
|
||||
"ChainStateTracker does not give up to date info: %v", err)
|
||||
|
||||
require.Nil(t, quick.Check(idempotence, &cfg),
|
||||
"ChainStateTracker is not idempotent")
|
||||
|
||||
require.Nil(t, tracker.Stop(), "ChainStateTracker could not be stopped")
|
||||
}
|
@ -145,6 +145,10 @@ type PartialChainControl struct {
|
||||
// interested in.
|
||||
ChainNotifier chainntnfs.ChainNotifier
|
||||
|
||||
// BestBlockTracker is used to maintain a view of the global
|
||||
// chain state that changes over time
|
||||
BestBlockTracker *chainntnfs.BestBlockTracker
|
||||
|
||||
// MempoolNotifier is used to watch for spending events happened in
|
||||
// mempool.
|
||||
MempoolNotifier chainntnfs.MempoolWatcher
|
||||
@ -667,6 +671,9 @@ func NewPartialChainControl(cfg *Config) (*PartialChainControl, func(), error) {
|
||||
cfg.Bitcoin.Node)
|
||||
}
|
||||
|
||||
cc.BestBlockTracker =
|
||||
chainntnfs.NewBestBlockTracker(cc.ChainNotifier)
|
||||
|
||||
switch {
|
||||
// If the fee URL isn't set, and the user is running mainnet, then
|
||||
// we'll return an error to instruct them to set a proper fee
|
||||
|
@ -25,6 +25,9 @@
|
||||
that when sweeping inputs with locktime, an unexpected lower fee rate is
|
||||
applied.
|
||||
|
||||
* LND will now [enforce pong responses
|
||||
](https://github.com/lightningnetwork/lnd/pull/7828) from its peers
|
||||
|
||||
# New Features
|
||||
## Functional Enhancements
|
||||
|
||||
@ -94,5 +97,6 @@
|
||||
* Andras Banki-Horvath
|
||||
* Carla Kirk-Cohen
|
||||
* Elle Mouton
|
||||
* Yong Yu
|
||||
* Keagan McClelland
|
||||
* Ononiwu Maureen Chiamaka
|
||||
* Yong Yu
|
||||
|
183
peer/brontide.go
183
peer/brontide.go
@ -5,6 +5,7 @@ import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@ -50,6 +51,12 @@ const (
|
||||
// pingInterval is the interval at which ping messages are sent.
|
||||
pingInterval = 1 * time.Minute
|
||||
|
||||
// pingTimeout is the amount of time we will wait for a pong response
|
||||
// before considering the peer to be unresponsive.
|
||||
//
|
||||
// This MUST be a smaller value than the pingInterval.
|
||||
pingTimeout = 30 * time.Second
|
||||
|
||||
// idleTimeout is the duration of inactivity before we time out a peer.
|
||||
idleTimeout = 5 * time.Minute
|
||||
|
||||
@ -233,6 +240,10 @@ type Config struct {
|
||||
// transaction.
|
||||
ChainNotifier chainntnfs.ChainNotifier
|
||||
|
||||
// BestBlockView is used to efficiently query for up-to-date
|
||||
// blockchain state information
|
||||
BestBlockView chainntnfs.BestBlockView
|
||||
|
||||
// RoutingPolicy is used to set the forwarding policy for links created by
|
||||
// the Brontide.
|
||||
RoutingPolicy models.ForwardingPolicy
|
||||
@ -375,15 +386,7 @@ type Brontide struct {
|
||||
bytesReceived uint64
|
||||
bytesSent uint64
|
||||
|
||||
// pingTime is a rough estimate of the RTT (round-trip-time) between us
|
||||
// and the connected peer. This time is expressed in microseconds.
|
||||
// To be used atomically.
|
||||
// TODO(roasbeef): also use a WMA or EMA?
|
||||
pingTime int64
|
||||
|
||||
// pingLastSend is the Unix time expressed in nanoseconds when we sent
|
||||
// our last ping message. To be used atomically.
|
||||
pingLastSend int64
|
||||
pingManager *PingManager
|
||||
|
||||
// lastPingPayload stores an unsafe pointer wrapped as an atomic
|
||||
// variable which points to the last payload the remote party sent us
|
||||
@ -521,6 +524,66 @@ func NewBrontide(cfg Config) *Brontide {
|
||||
log: build.NewPrefixLog(logPrefix, peerLog),
|
||||
}
|
||||
|
||||
var (
|
||||
lastBlockHeader *wire.BlockHeader
|
||||
lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte
|
||||
)
|
||||
newPingPayload := func() []byte {
|
||||
// We query the BestBlockHeader from our BestBlockView each time
|
||||
// this is called, and update our serialized block header if
|
||||
// they differ. Over time, we'll use this to disseminate the
|
||||
// latest block header between all our peers, which can later be
|
||||
// used to cross-check our own view of the network to mitigate
|
||||
// various types of eclipse attacks.
|
||||
header, err := p.cfg.BestBlockView.BestBlockHeader()
|
||||
if err != nil && header == lastBlockHeader {
|
||||
return lastSerializedBlockHeader[:]
|
||||
}
|
||||
|
||||
buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0])
|
||||
err = header.Serialize(buf)
|
||||
if err == nil {
|
||||
lastBlockHeader = header
|
||||
} else {
|
||||
p.log.Warn("unable to serialize current block" +
|
||||
"header for ping payload generation." +
|
||||
"This should be impossible and means" +
|
||||
"there is an implementation bug.")
|
||||
}
|
||||
|
||||
return lastSerializedBlockHeader[:]
|
||||
}
|
||||
|
||||
// TODO(roasbeef): make dynamic in order to
|
||||
// create fake cover traffic
|
||||
// NOTE(proofofkeags): this was changed to be
|
||||
// dynamic to allow better pong identification,
|
||||
// however, more thought is needed to make this
|
||||
// actually usable as a traffic decoy
|
||||
randPongSize := func() uint16 {
|
||||
return uint16(
|
||||
// We don't need cryptographic randomness here.
|
||||
/* #nosec */
|
||||
rand.Intn(lnwire.MaxPongBytes + 1),
|
||||
)
|
||||
}
|
||||
|
||||
p.pingManager = NewPingManager(&PingManagerConfig{
|
||||
NewPingPayload: newPingPayload,
|
||||
NewPongSize: randPongSize,
|
||||
IntervalDuration: pingInterval,
|
||||
TimeoutDuration: pingTimeout,
|
||||
SendPing: func(ping *lnwire.Ping) {
|
||||
p.queueMsg(ping, nil)
|
||||
},
|
||||
OnPongFailure: func(err error) {
|
||||
eStr := "pong response failure for %s: %v " +
|
||||
"-- disconnecting"
|
||||
p.log.Warnf(eStr, p, err)
|
||||
p.Disconnect(fmt.Errorf(eStr, p, err))
|
||||
},
|
||||
})
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
@ -640,12 +703,16 @@ func (p *Brontide) Start() error {
|
||||
|
||||
p.startTime = time.Now()
|
||||
|
||||
p.wg.Add(5)
|
||||
err = p.pingManager.Start()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not start ping manager %w", err)
|
||||
}
|
||||
|
||||
p.wg.Add(4)
|
||||
go p.queueHandler()
|
||||
go p.writeHandler()
|
||||
go p.readHandler()
|
||||
go p.channelManager()
|
||||
go p.pingHandler()
|
||||
go p.readHandler()
|
||||
|
||||
// Signal to any external processes that the peer is now active.
|
||||
close(p.activeSignal)
|
||||
@ -1127,6 +1194,11 @@ func (p *Brontide) Disconnect(reason error) {
|
||||
p.cfg.Conn.Close()
|
||||
|
||||
close(p.quit)
|
||||
|
||||
if err := p.pingManager.Stop(); err != nil {
|
||||
p.log.Errorf("couldn't stop pingManager during disconnect: %v",
|
||||
err)
|
||||
}
|
||||
}
|
||||
|
||||
// String returns the string representation of this peer.
|
||||
@ -1574,12 +1646,8 @@ out:
|
||||
switch msg := nextMsg.(type) {
|
||||
case *lnwire.Pong:
|
||||
// When we receive a Pong message in response to our
|
||||
// last ping message, we'll use the time in which we
|
||||
// sent the ping message to measure a rough estimate of
|
||||
// round trip time.
|
||||
pingSendTime := atomic.LoadInt64(&p.pingLastSend)
|
||||
delay := (time.Now().UnixNano() - pingSendTime) / 1000
|
||||
atomic.StoreInt64(&p.pingTime, delay)
|
||||
// last ping message, we send it to the pingManager
|
||||
p.pingManager.ReceivedPong(msg)
|
||||
|
||||
case *lnwire.Ping:
|
||||
// First, we'll store their latest ping payload within
|
||||
@ -1928,7 +1996,7 @@ func messageSummary(msg lnwire.Message) string {
|
||||
return fmt.Sprintf("ping_bytes=%x", msg.PaddingBytes[:])
|
||||
|
||||
case *lnwire.Pong:
|
||||
return fmt.Sprintf("pong_bytes=%x", msg.PongBytes[:])
|
||||
return fmt.Sprintf("len(pong_bytes)=%d", len(msg.PongBytes[:]))
|
||||
|
||||
case *lnwire.UpdateFee:
|
||||
return fmt.Sprintf("chan_id=%v, fee_update_sat=%v",
|
||||
@ -2105,17 +2173,6 @@ out:
|
||||
for {
|
||||
select {
|
||||
case outMsg := <-p.sendQueue:
|
||||
// If we're about to send a ping message, then log the
|
||||
// exact time in which we send the message so we can
|
||||
// use the delay as a rough estimate of latency to the
|
||||
// remote peer.
|
||||
if _, ok := outMsg.msg.(*lnwire.Ping); ok {
|
||||
// TODO(roasbeef): do this before the write?
|
||||
// possibly account for processing within func?
|
||||
now := time.Now().UnixNano()
|
||||
atomic.StoreInt64(&p.pingLastSend, now)
|
||||
}
|
||||
|
||||
// Record the time at which we first attempt to send the
|
||||
// message.
|
||||
startTime := time.Now()
|
||||
@ -2248,73 +2305,9 @@ func (p *Brontide) queueHandler() {
|
||||
}
|
||||
}
|
||||
|
||||
// pingHandler is responsible for periodically sending ping messages to the
|
||||
// remote peer in order to keep the connection alive and/or determine if the
|
||||
// connection is still active.
|
||||
//
|
||||
// NOTE: This method MUST be run as a goroutine.
|
||||
func (p *Brontide) pingHandler() {
|
||||
defer p.wg.Done()
|
||||
|
||||
pingTicker := time.NewTicker(pingInterval)
|
||||
defer pingTicker.Stop()
|
||||
|
||||
// TODO(roasbeef): make dynamic in order to create fake cover traffic
|
||||
const numPongBytes = 16
|
||||
|
||||
blockEpochs, err := p.cfg.ChainNotifier.RegisterBlockEpochNtfn(nil)
|
||||
if err != nil {
|
||||
p.log.Errorf("unable to establish block epoch "+
|
||||
"subscription: %v", err)
|
||||
return
|
||||
}
|
||||
defer blockEpochs.Cancel()
|
||||
|
||||
var (
|
||||
pingPayload [wire.MaxBlockHeaderPayload]byte
|
||||
blockHeader *wire.BlockHeader
|
||||
)
|
||||
out:
|
||||
for {
|
||||
select {
|
||||
// Each time a new block comes in, we'll copy the raw header
|
||||
// contents over to our ping payload declared above. Over time,
|
||||
// we'll use this to disseminate the latest block header
|
||||
// between all our peers, which can later be used to
|
||||
// cross-check our own view of the network to mitigate various
|
||||
// types of eclipse attacks.
|
||||
case epoch, ok := <-blockEpochs.Epochs:
|
||||
if !ok {
|
||||
p.log.Debugf("block notifications " +
|
||||
"canceled")
|
||||
return
|
||||
}
|
||||
|
||||
blockHeader = epoch.BlockHeader
|
||||
headerBuf := bytes.NewBuffer(pingPayload[0:0])
|
||||
err := blockHeader.Serialize(headerBuf)
|
||||
if err != nil {
|
||||
p.log.Errorf("unable to encode header: %v",
|
||||
err)
|
||||
}
|
||||
|
||||
case <-pingTicker.C:
|
||||
|
||||
pingMsg := &lnwire.Ping{
|
||||
NumPongBytes: numPongBytes,
|
||||
PaddingBytes: pingPayload[:],
|
||||
}
|
||||
|
||||
p.queueMsg(pingMsg, nil)
|
||||
case <-p.quit:
|
||||
break out
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// PingTime returns the estimated ping time to the peer in microseconds.
|
||||
func (p *Brontide) PingTime() int64 {
|
||||
return atomic.LoadInt64(&p.pingTime)
|
||||
return p.pingManager.GetPingTimeMicroSeconds()
|
||||
}
|
||||
|
||||
// queueMsg adds the lnwire.Message to the back of the high priority send queue.
|
||||
|
266
peer/ping_manager.go
Normal file
266
peer/ping_manager.go
Normal file
@ -0,0 +1,266 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// PingManagerConfig is a structure containing various parameters that govern
|
||||
// how the PingManager behaves.
|
||||
type PingManagerConfig struct {
|
||||
|
||||
// NewPingPayload is a closure that returns the payload to be packaged
|
||||
// in the Ping message.
|
||||
NewPingPayload func() []byte
|
||||
|
||||
// NewPongSize is a closure that returns a random value between
|
||||
// [0, lnwire.MaxPongBytes]. This random value helps to more effectively
|
||||
// pair Pong messages with Ping.
|
||||
NewPongSize func() uint16
|
||||
|
||||
// IntervalDuration is the Duration between attempted pings.
|
||||
IntervalDuration time.Duration
|
||||
|
||||
// TimeoutDuration is the Duration we wait before declaring a ping
|
||||
// attempt failed.
|
||||
TimeoutDuration time.Duration
|
||||
|
||||
// SendPing is a closure that is responsible for sending the Ping
|
||||
// message out to our peer
|
||||
SendPing func(ping *lnwire.Ping)
|
||||
|
||||
// OnPongFailure is a closure that is responsible for executing the
|
||||
// logic when a Pong message is either late or does not match our
|
||||
// expectations for that Pong
|
||||
OnPongFailure func(error)
|
||||
}
|
||||
|
||||
// PingManager is a structure that is designed to manage the internal state
|
||||
// of the ping pong lifecycle with the remote peer. We assume there is only one
|
||||
// ping outstanding at once.
|
||||
//
|
||||
// NOTE: This structure MUST be initialized with NewPingManager.
|
||||
type PingManager struct {
|
||||
cfg *PingManagerConfig
|
||||
|
||||
// pingTime is a rough estimate of the RTT (round-trip-time) between us
|
||||
// and the connected peer.
|
||||
// To be used atomically.
|
||||
// TODO(roasbeef): also use a WMA or EMA?
|
||||
pingTime atomic.Pointer[time.Duration]
|
||||
|
||||
// pingLastSend is the time when we sent our last ping message.
|
||||
// To be used atomically.
|
||||
pingLastSend *time.Time
|
||||
|
||||
// outstandingPongSize is the current size of the requested pong
|
||||
// payload. This value can only validly range from [0,65531]. Any
|
||||
// value < 0 is interpreted as if there is no outstanding ping message.
|
||||
outstandingPongSize int32
|
||||
|
||||
// pingTicker is a pointer to a Ticker that fires on every ping
|
||||
// interval.
|
||||
pingTicker *time.Ticker
|
||||
|
||||
// pingTimeout is a Timer that will fire when we want to time out a
|
||||
// ping
|
||||
pingTimeout *time.Timer
|
||||
|
||||
// pongChan is the channel on which the pingManager will write Pong
|
||||
// messages it is evaluating
|
||||
pongChan chan *lnwire.Pong
|
||||
|
||||
started sync.Once
|
||||
stopped sync.Once
|
||||
|
||||
quit chan struct{}
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// NewPingManager constructs a pingManager in a valid state. It must be started
|
||||
// before it does anything useful, though.
|
||||
func NewPingManager(cfg *PingManagerConfig) *PingManager {
|
||||
m := PingManager{
|
||||
cfg: cfg,
|
||||
outstandingPongSize: -1,
|
||||
pongChan: make(chan *lnwire.Pong, 1),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
return &m
|
||||
}
|
||||
|
||||
// Start launches the primary goroutine that is owned by the pingManager.
|
||||
func (m *PingManager) Start() error {
|
||||
var err error
|
||||
m.started.Do(func() {
|
||||
err = m.start()
|
||||
})
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (m *PingManager) start() error {
|
||||
m.pingTicker = time.NewTicker(m.cfg.IntervalDuration)
|
||||
|
||||
m.pingTimeout = time.NewTimer(0)
|
||||
defer m.pingTimeout.Stop()
|
||||
|
||||
// Ensure that the pingTimeout channel is empty.
|
||||
if !m.pingTimeout.Stop() {
|
||||
<-m.pingTimeout.C
|
||||
}
|
||||
|
||||
m.wg.Add(1)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-m.pingTicker.C:
|
||||
// If this occurs it means that the new ping
|
||||
// cycle has begun while there is still an
|
||||
// outstanding ping awaiting a pong response.
|
||||
// This should never occur, but if it does, it
|
||||
// implies a timeout.
|
||||
if m.outstandingPongSize >= 0 {
|
||||
e := errors.New("impossible: new ping" +
|
||||
"in unclean state",
|
||||
)
|
||||
m.cfg.OnPongFailure(e)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
pongSize := m.cfg.NewPongSize()
|
||||
ping := &lnwire.Ping{
|
||||
NumPongBytes: pongSize,
|
||||
PaddingBytes: m.cfg.NewPingPayload(),
|
||||
}
|
||||
|
||||
// Set up our bookkeeping for the new Ping.
|
||||
if err := m.setPingState(pongSize); err != nil {
|
||||
m.cfg.OnPongFailure(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
m.cfg.SendPing(ping)
|
||||
|
||||
case <-m.pingTimeout.C:
|
||||
m.resetPingState()
|
||||
|
||||
e := errors.New("timeout while waiting for " +
|
||||
"pong response",
|
||||
)
|
||||
m.cfg.OnPongFailure(e)
|
||||
|
||||
return
|
||||
|
||||
case pong := <-m.pongChan:
|
||||
pongSize := int32(len(pong.PongBytes))
|
||||
|
||||
// Save off values we are about to override
|
||||
// when we call resetPingState.
|
||||
expected := m.outstandingPongSize
|
||||
lastPing := m.pingLastSend
|
||||
|
||||
m.resetPingState()
|
||||
|
||||
// If the pong we receive doesn't match the
|
||||
// ping we sent out, then we fail out.
|
||||
if pongSize != expected {
|
||||
e := errors.New("pong response does " +
|
||||
"not match expected size",
|
||||
)
|
||||
m.cfg.OnPongFailure(e)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Compute RTT of ping and save that for future
|
||||
// querying.
|
||||
if lastPing != nil {
|
||||
rtt := time.Since(*lastPing)
|
||||
m.pingTime.Store(&rtt)
|
||||
}
|
||||
case <-m.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop interrupts the goroutines that the PingManager owns. Can only be called
|
||||
// when the PingManager is running.
|
||||
func (m *PingManager) Stop() error {
|
||||
if m.pingTicker == nil {
|
||||
return errors.New("PingManager cannot be stopped because it " +
|
||||
"isn't running")
|
||||
}
|
||||
|
||||
m.stopped.Do(func() {
|
||||
close(m.quit)
|
||||
m.wg.Wait()
|
||||
|
||||
m.pingTicker.Stop()
|
||||
m.pingTimeout.Stop()
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// setPingState is a private method to keep track of all of the fields we need
|
||||
// to set when we send out a Ping.
|
||||
func (m *PingManager) setPingState(pongSize uint16) error {
|
||||
t := time.Now()
|
||||
m.pingLastSend = &t
|
||||
m.outstandingPongSize = int32(pongSize)
|
||||
if m.pingTimeout.Reset(m.cfg.TimeoutDuration) {
|
||||
return fmt.Errorf(
|
||||
"impossible: ping timeout reset when already active",
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// resetPingState is a private method that resets all of the bookkeeping that
|
||||
// is tracking a currently outstanding Ping.
|
||||
func (m *PingManager) resetPingState() {
|
||||
m.pingLastSend = nil
|
||||
m.outstandingPongSize = -1
|
||||
if !m.pingTimeout.Stop() {
|
||||
select {
|
||||
case <-m.pingTimeout.C:
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetPingTimeMicroSeconds reports back the RTT calculated by the pingManager.
|
||||
func (m *PingManager) GetPingTimeMicroSeconds() int64 {
|
||||
rtt := m.pingTime.Load()
|
||||
|
||||
if rtt == nil {
|
||||
return -1
|
||||
}
|
||||
|
||||
return rtt.Microseconds()
|
||||
}
|
||||
|
||||
// ReceivedPong is called to evaluate a Pong message against the expectations
|
||||
// we have for it. It will cause the PingManager to invoke the supplied
|
||||
// OnPongFailure function if the Pong argument supplied violates expectations.
|
||||
func (m *PingManager) ReceivedPong(msg *lnwire.Pong) {
|
||||
select {
|
||||
case m.pongChan <- msg:
|
||||
case <-m.quit:
|
||||
}
|
||||
}
|
88
peer/ping_manager_test.go
Normal file
88
peer/ping_manager_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package peer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestPingManager tests three main properties about the ping manager. It
|
||||
// ensures that if the pong response exceeds the timeout, that a failure is
|
||||
// emitted on the failure channel. It ensures that if the Pong response is
|
||||
// not congruent with the outstanding ping then a failure is emitted on the
|
||||
// failure channel, and otherwise the failure channel remains empty.
|
||||
func TestPingManager(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
delay int
|
||||
pongSize uint16
|
||||
result bool
|
||||
}{
|
||||
{
|
||||
name: "Happy Path",
|
||||
delay: 0,
|
||||
pongSize: 4,
|
||||
result: true,
|
||||
},
|
||||
{
|
||||
name: "Bad Pong",
|
||||
delay: 0,
|
||||
pongSize: 3,
|
||||
result: false,
|
||||
},
|
||||
{
|
||||
name: "Timeout",
|
||||
delay: 2,
|
||||
pongSize: 4,
|
||||
result: false,
|
||||
},
|
||||
}
|
||||
|
||||
payload := make([]byte, 4)
|
||||
for _, test := range testCases {
|
||||
// Set up PingManager.
|
||||
pingSent := make(chan struct{})
|
||||
disconnected := make(chan struct{})
|
||||
mgr := NewPingManager(&PingManagerConfig{
|
||||
NewPingPayload: func() []byte {
|
||||
return payload
|
||||
},
|
||||
NewPongSize: func() uint16 {
|
||||
return 4
|
||||
},
|
||||
IntervalDuration: time.Second * 2,
|
||||
TimeoutDuration: time.Second,
|
||||
SendPing: func(ping *lnwire.Ping) {
|
||||
close(pingSent)
|
||||
},
|
||||
OnPongFailure: func(err error) {
|
||||
close(disconnected)
|
||||
},
|
||||
})
|
||||
require.NoError(t, mgr.Start(), "Could not start pingManager")
|
||||
|
||||
// Wait for initial Ping.
|
||||
<-pingSent
|
||||
|
||||
// Wait for pre-determined time before sending Pong response.
|
||||
time.Sleep(time.Duration(test.delay) * time.Second)
|
||||
|
||||
// Send Pong back.
|
||||
res := lnwire.Pong{PongBytes: make([]byte, test.pongSize)}
|
||||
mgr.ReceivedPong(&res)
|
||||
|
||||
// Evaluate result
|
||||
select {
|
||||
case <-time.NewTimer(time.Second / 2).C:
|
||||
require.True(t, test.result)
|
||||
case <-disconnected:
|
||||
require.False(t, test.result)
|
||||
}
|
||||
|
||||
require.NoError(t, mgr.Stop(), "Could not stop pingManager")
|
||||
}
|
||||
}
|
11
server.go
11
server.go
@ -1901,6 +1901,12 @@ func (s *server) Start() error {
|
||||
}
|
||||
cleanup = cleanup.add(s.cc.ChainNotifier.Stop)
|
||||
|
||||
if err := s.cc.BestBlockTracker.Start(); err != nil {
|
||||
startErr = err
|
||||
return
|
||||
}
|
||||
cleanup = cleanup.add(s.cc.BestBlockTracker.Stop)
|
||||
|
||||
if err := s.channelNotifier.Start(); err != nil {
|
||||
startErr = err
|
||||
return
|
||||
@ -2282,6 +2288,10 @@ func (s *server) Stop() error {
|
||||
if err := s.cc.ChainNotifier.Stop(); err != nil {
|
||||
srvrLog.Warnf("Unable to stop ChainNotifier: %v", err)
|
||||
}
|
||||
if err := s.cc.BestBlockTracker.Stop(); err != nil {
|
||||
srvrLog.Warnf("Unable to stop BestBlockTracker: %v",
|
||||
err)
|
||||
}
|
||||
s.chanEventStore.Stop()
|
||||
s.missionControl.StopStoreTicker()
|
||||
|
||||
@ -3827,6 +3837,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq,
|
||||
SigPool: s.sigPool,
|
||||
Wallet: s.cc.Wallet,
|
||||
ChainNotifier: s.cc.ChainNotifier,
|
||||
BestBlockView: s.cc.BestBlockTracker,
|
||||
RoutingPolicy: s.cc.RoutingPolicy,
|
||||
Sphinx: s.sphinx,
|
||||
WitnessBeacon: s.witnessBeacon,
|
||||
|
Loading…
x
Reference in New Issue
Block a user