diff --git a/peer/brontide.go b/peer/brontide.go index e151ebeb7..eba3b30cc 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -5,6 +5,7 @@ import ( "container/list" "errors" "fmt" + "math/rand" "net" "sync" "sync/atomic" @@ -50,6 +51,12 @@ const ( // pingInterval is the interval at which ping messages are sent. pingInterval = 1 * time.Minute + // pingTimeout is the amount of time we will wait for a pong response + // before considering the peer to be unresponsive. + // + // This MUST be a smaller value than the pingInterval. + pingTimeout = 30 * time.Second + // idleTimeout is the duration of inactivity before we time out a peer. idleTimeout = 5 * time.Minute @@ -379,15 +386,7 @@ type Brontide struct { bytesReceived uint64 bytesSent uint64 - // pingTime is a rough estimate of the RTT (round-trip-time) between us - // and the connected peer. This time is expressed in microseconds. - // To be used atomically. - // TODO(roasbeef): also use a WMA or EMA? - pingTime int64 - - // pingLastSend is the Unix time expressed in nanoseconds when we sent - // our last ping message. To be used atomically. - pingLastSend int64 + pingManager *PingManager // lastPingPayload stores an unsafe pointer wrapped as an atomic // variable which points to the last payload the remote party sent us @@ -525,6 +524,66 @@ func NewBrontide(cfg Config) *Brontide { log: build.NewPrefixLog(logPrefix, peerLog), } + var ( + lastBlockHeader *wire.BlockHeader + lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte + ) + newPingPayload := func() []byte { + // We query the BestBlockHeader from our BestBlockView each time + // this is called, and update our serialized block header if + // they differ. Over time, we'll use this to disseminate the + // latest block header between all our peers, which can later be + // used to cross-check our own view of the network to mitigate + // various types of eclipse attacks. + header, err := p.cfg.BestBlockView.BestBlockHeader() + if err != nil && header == lastBlockHeader { + return lastSerializedBlockHeader[:] + } + + buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0]) + err = header.Serialize(buf) + if err == nil { + lastBlockHeader = header + } else { + p.log.Warn("unable to serialize current block" + + "header for ping payload generation." + + "This should be impossible and means" + + "there is an implementation bug.") + } + + return lastSerializedBlockHeader[:] + } + + // TODO(roasbeef): make dynamic in order to + // create fake cover traffic + // NOTE(proofofkeags): this was changed to be + // dynamic to allow better pong identification, + // however, more thought is needed to make this + // actually usable as a traffic decoy + randPongSize := func() uint16 { + return uint16( + // We don't need cryptographic randomness here. + /* #nosec */ + rand.Intn(lnwire.MaxPongBytes + 1), + ) + } + + p.pingManager = NewPingManager(&PingManagerConfig{ + NewPingPayload: newPingPayload, + NewPongSize: randPongSize, + IntervalDuration: pingInterval, + TimeoutDuration: pingTimeout, + SendPing: func(ping *lnwire.Ping) { + p.queueMsg(ping, nil) + }, + OnPongFailure: func(err error) { + eStr := "pong response failure for %s: %v " + + "-- disconnecting" + p.log.Warnf(eStr, p, err) + p.Disconnect(fmt.Errorf(eStr, p, err)) + }, + }) + return p } @@ -644,40 +703,16 @@ func (p *Brontide) Start() error { p.startTime = time.Now() - p.wg.Add(5) + err = p.pingManager.Start() + if err != nil { + return fmt.Errorf("could not start ping manager %w", err) + } + + p.wg.Add(4) go p.queueHandler() go p.writeHandler() - go p.readHandler() go p.channelManager() - - var ( - lastBlockHeader *wire.BlockHeader - lastSerializedBlockHeader [wire.MaxBlockHeaderPayload]byte - ) - newPingPayload := func() []byte { - // We query the BestBlockHeader from our BestBlockView each time - // this is called, and update our serialized block header if - // they differ. Over time, we'll use this to disseminate the - // latest block header between all our peers, which can later be - // used to cross-check our own view of the network to mitigate - // various types of eclipse attacks. - header, err := p.cfg.BestBlockView.BestBlockHeader() - if err == nil && header != lastBlockHeader { - buf := bytes.NewBuffer(lastSerializedBlockHeader[0:0]) - err := header.Serialize(buf) - if err == nil { - lastBlockHeader = header - } else { - p.log.Warn("unable to serialize current block" + - "header for ping payload generation." + - "This should be impossible and means" + - "there is an implementation bug.") - } - } - - return lastSerializedBlockHeader[:] - } - go p.pingHandler(newPingPayload) + go p.readHandler() // Signal to any external processes that the peer is now active. close(p.activeSignal) @@ -1159,6 +1194,11 @@ func (p *Brontide) Disconnect(reason error) { p.cfg.Conn.Close() close(p.quit) + + if err := p.pingManager.Stop(); err != nil { + p.log.Errorf("couldn't stop pingManager during disconnect: %v", + err) + } } // String returns the string representation of this peer. @@ -1606,12 +1646,8 @@ out: switch msg := nextMsg.(type) { case *lnwire.Pong: // When we receive a Pong message in response to our - // last ping message, we'll use the time in which we - // sent the ping message to measure a rough estimate of - // round trip time. - pingSendTime := atomic.LoadInt64(&p.pingLastSend) - delay := (time.Now().UnixNano() - pingSendTime) / 1000 - atomic.StoreInt64(&p.pingTime, delay) + // last ping message, we send it to the pingManager + p.pingManager.ReceivedPong(msg) case *lnwire.Ping: // First, we'll store their latest ping payload within @@ -2137,17 +2173,6 @@ out: for { select { case outMsg := <-p.sendQueue: - // If we're about to send a ping message, then log the - // exact time in which we send the message so we can - // use the delay as a rough estimate of latency to the - // remote peer. - if _, ok := outMsg.msg.(*lnwire.Ping); ok { - // TODO(roasbeef): do this before the write? - // possibly account for processing within func? - now := time.Now().UnixNano() - atomic.StoreInt64(&p.pingLastSend, now) - } - // Record the time at which we first attempt to send the // message. startTime := time.Now() @@ -2280,40 +2305,9 @@ func (p *Brontide) queueHandler() { } } -// pingHandler is responsible for periodically sending ping messages to the -// remote peer in order to keep the connection alive and/or determine if the -// connection is still active. -// -// NOTE: This method MUST be run as a goroutine. -func (p *Brontide) pingHandler(newPayload func() []byte) { - defer p.wg.Done() - - pingTicker := time.NewTicker(pingInterval) - defer pingTicker.Stop() - - // TODO(roasbeef): make dynamic in order to create fake cover traffic - const numPongBytes = 16 - -out: - for { - select { - case <-pingTicker.C: - - pingMsg := &lnwire.Ping{ - NumPongBytes: numPongBytes, - PaddingBytes: newPayload(), - } - - p.queueMsg(pingMsg, nil) - case <-p.quit: - break out - } - } -} - // PingTime returns the estimated ping time to the peer in microseconds. func (p *Brontide) PingTime() int64 { - return atomic.LoadInt64(&p.pingTime) + return p.pingManager.GetPingTimeMicroSeconds() } // queueMsg adds the lnwire.Message to the back of the high priority send queue. diff --git a/peer/ping_manager.go b/peer/ping_manager.go new file mode 100644 index 000000000..3e62cb097 --- /dev/null +++ b/peer/ping_manager.go @@ -0,0 +1,266 @@ +package peer + +import ( + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/lightningnetwork/lnd/lnwire" +) + +// PingManagerConfig is a structure containing various parameters that govern +// how the PingManager behaves. +type PingManagerConfig struct { + + // NewPingPayload is a closure that returns the payload to be packaged + // in the Ping message. + NewPingPayload func() []byte + + // NewPongSize is a closure that returns a random value between + // [0, lnwire.MaxPongBytes]. This random value helps to more effectively + // pair Pong messages with Ping. + NewPongSize func() uint16 + + // IntervalDuration is the Duration between attempted pings. + IntervalDuration time.Duration + + // TimeoutDuration is the Duration we wait before declaring a ping + // attempt failed. + TimeoutDuration time.Duration + + // SendPing is a closure that is responsible for sending the Ping + // message out to our peer + SendPing func(ping *lnwire.Ping) + + // OnPongFailure is a closure that is responsible for executing the + // logic when a Pong message is either late or does not match our + // expectations for that Pong + OnPongFailure func(error) +} + +// PingManager is a structure that is designed to manage the internal state +// of the ping pong lifecycle with the remote peer. We assume there is only one +// ping outstanding at once. +// +// NOTE: This structure MUST be initialized with NewPingManager. +type PingManager struct { + cfg *PingManagerConfig + + // pingTime is a rough estimate of the RTT (round-trip-time) between us + // and the connected peer. + // To be used atomically. + // TODO(roasbeef): also use a WMA or EMA? + pingTime atomic.Pointer[time.Duration] + + // pingLastSend is the time when we sent our last ping message. + // To be used atomically. + pingLastSend *time.Time + + // outstandingPongSize is the current size of the requested pong + // payload. This value can only validly range from [0,65531]. Any + // value < 0 is interpreted as if there is no outstanding ping message. + outstandingPongSize int32 + + // pingTicker is a pointer to a Ticker that fires on every ping + // interval. + pingTicker *time.Ticker + + // pingTimeout is a Timer that will fire when we want to time out a + // ping + pingTimeout *time.Timer + + // pongChan is the channel on which the pingManager will write Pong + // messages it is evaluating + pongChan chan *lnwire.Pong + + started sync.Once + stopped sync.Once + + quit chan struct{} + wg sync.WaitGroup +} + +// NewPingManager constructs a pingManager in a valid state. It must be started +// before it does anything useful, though. +func NewPingManager(cfg *PingManagerConfig) *PingManager { + m := PingManager{ + cfg: cfg, + outstandingPongSize: -1, + pongChan: make(chan *lnwire.Pong, 1), + quit: make(chan struct{}), + } + + return &m +} + +// Start launches the primary goroutine that is owned by the pingManager. +func (m *PingManager) Start() error { + var err error + m.started.Do(func() { + err = m.start() + }) + + return err +} + +func (m *PingManager) start() error { + m.pingTicker = time.NewTicker(m.cfg.IntervalDuration) + + m.pingTimeout = time.NewTimer(0) + defer m.pingTimeout.Stop() + + // Ensure that the pingTimeout channel is empty. + if !m.pingTimeout.Stop() { + <-m.pingTimeout.C + } + + m.wg.Add(1) + go func() { + defer m.wg.Done() + for { + select { + case <-m.pingTicker.C: + // If this occurs it means that the new ping + // cycle has begun while there is still an + // outstanding ping awaiting a pong response. + // This should never occur, but if it does, it + // implies a timeout. + if m.outstandingPongSize >= 0 { + e := errors.New("impossible: new ping" + + "in unclean state", + ) + m.cfg.OnPongFailure(e) + + return + } + + pongSize := m.cfg.NewPongSize() + ping := &lnwire.Ping{ + NumPongBytes: pongSize, + PaddingBytes: m.cfg.NewPingPayload(), + } + + // Set up our bookkeeping for the new Ping. + if err := m.setPingState(pongSize); err != nil { + m.cfg.OnPongFailure(err) + + return + } + + m.cfg.SendPing(ping) + + case <-m.pingTimeout.C: + m.resetPingState() + + e := errors.New("timeout while waiting for " + + "pong response", + ) + m.cfg.OnPongFailure(e) + + return + + case pong := <-m.pongChan: + pongSize := int32(len(pong.PongBytes)) + + // Save off values we are about to override + // when we call resetPingState. + expected := m.outstandingPongSize + lastPing := m.pingLastSend + + m.resetPingState() + + // If the pong we receive doesn't match the + // ping we sent out, then we fail out. + if pongSize != expected { + e := errors.New("pong response does " + + "not match expected size", + ) + m.cfg.OnPongFailure(e) + + return + } + + // Compute RTT of ping and save that for future + // querying. + if lastPing != nil { + rtt := time.Since(*lastPing) + m.pingTime.Store(&rtt) + } + case <-m.quit: + return + } + } + }() + + return nil +} + +// Stop interrupts the goroutines that the PingManager owns. Can only be called +// when the PingManager is running. +func (m *PingManager) Stop() error { + if m.pingTicker == nil { + return errors.New("PingManager cannot be stopped because it " + + "isn't running") + } + + m.stopped.Do(func() { + close(m.quit) + m.wg.Wait() + + m.pingTicker.Stop() + m.pingTimeout.Stop() + }) + + return nil +} + +// setPingState is a private method to keep track of all of the fields we need +// to set when we send out a Ping. +func (m *PingManager) setPingState(pongSize uint16) error { + t := time.Now() + m.pingLastSend = &t + m.outstandingPongSize = int32(pongSize) + if m.pingTimeout.Reset(m.cfg.TimeoutDuration) { + return fmt.Errorf( + "impossible: ping timeout reset when already active", + ) + } + + return nil +} + +// resetPingState is a private method that resets all of the bookkeeping that +// is tracking a currently outstanding Ping. +func (m *PingManager) resetPingState() { + m.pingLastSend = nil + m.outstandingPongSize = -1 + if !m.pingTimeout.Stop() { + select { + case <-m.pingTimeout.C: + default: + } + } +} + +// GetPingTimeMicroSeconds reports back the RTT calculated by the pingManager. +func (m *PingManager) GetPingTimeMicroSeconds() int64 { + rtt := m.pingTime.Load() + + if rtt == nil { + return -1 + } + + return rtt.Microseconds() +} + +// ReceivedPong is called to evaluate a Pong message against the expectations +// we have for it. It will cause the PingManager to invoke the supplied +// OnPongFailure function if the Pong argument supplied violates expectations. +func (m *PingManager) ReceivedPong(msg *lnwire.Pong) { + select { + case m.pongChan <- msg: + case <-m.quit: + } +} diff --git a/peer/ping_manager_test.go b/peer/ping_manager_test.go new file mode 100644 index 000000000..bdfeeb6af --- /dev/null +++ b/peer/ping_manager_test.go @@ -0,0 +1,88 @@ +package peer + +import ( + "testing" + "time" + + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// TestPingManager tests three main properties about the ping manager. It +// ensures that if the pong response exceeds the timeout, that a failure is +// emitted on the failure channel. It ensures that if the Pong response is +// not congruent with the outstanding ping then a failure is emitted on the +// failure channel, and otherwise the failure channel remains empty. +func TestPingManager(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + delay int + pongSize uint16 + result bool + }{ + { + name: "Happy Path", + delay: 0, + pongSize: 4, + result: true, + }, + { + name: "Bad Pong", + delay: 0, + pongSize: 3, + result: false, + }, + { + name: "Timeout", + delay: 2, + pongSize: 4, + result: false, + }, + } + + payload := make([]byte, 4) + for _, test := range testCases { + // Set up PingManager. + pingSent := make(chan struct{}) + disconnected := make(chan struct{}) + mgr := NewPingManager(&PingManagerConfig{ + NewPingPayload: func() []byte { + return payload + }, + NewPongSize: func() uint16 { + return 4 + }, + IntervalDuration: time.Second * 2, + TimeoutDuration: time.Second, + SendPing: func(ping *lnwire.Ping) { + close(pingSent) + }, + OnPongFailure: func(err error) { + close(disconnected) + }, + }) + require.NoError(t, mgr.Start(), "Could not start pingManager") + + // Wait for initial Ping. + <-pingSent + + // Wait for pre-determined time before sending Pong response. + time.Sleep(time.Duration(test.delay) * time.Second) + + // Send Pong back. + res := lnwire.Pong{PongBytes: make([]byte, test.pongSize)} + mgr.ReceivedPong(&res) + + // Evaluate result + select { + case <-time.NewTimer(time.Second / 2).C: + require.True(t, test.result) + case <-disconnected: + require.False(t, test.result) + } + + require.NoError(t, mgr.Stop(), "Could not stop pingManager") + } +}