diff --git a/peer/brontide.go b/peer/brontide.go index da4aa610a..c3cb70f47 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1730,12 +1730,9 @@ type msgStream struct { startMsg string stopMsg string - msgCond *sync.Cond - msgs []lnwire.Message - - mtx sync.Mutex - - producerSema chan struct{} + // queue is the underlying backpressure-aware queue that manages + // messages. + queue *queue.BackpressureQueue[lnwire.Message] wg sync.WaitGroup quit chan struct{} @@ -1744,28 +1741,28 @@ type msgStream struct { // newMsgStream creates a new instance of a chanMsgStream for a particular // channel identified by its channel ID. bufSize is the max number of messages // that should be buffered in the internal queue. Callers should set this to a -// sane value that avoids blocking unnecessarily, but doesn't allow an -// unbounded amount of memory to be allocated to buffer incoming messages. -func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize uint32, +// sane value that avoids blocking unnecessarily, but doesn't allow an unbounded +// amount of memory to be allocated to buffer incoming messages. +func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize int, apply func(lnwire.Message)) *msgStream { stream := &msgStream{ - peer: p, - apply: apply, - startMsg: startMsg, - stopMsg: stopMsg, - producerSema: make(chan struct{}, bufSize), - quit: make(chan struct{}), + peer: p, + apply: apply, + startMsg: startMsg, + stopMsg: stopMsg, + quit: make(chan struct{}), } - stream.msgCond = sync.NewCond(&stream.mtx) - // Before we return the active stream, we'll populate the producer's - // semaphore channel. We'll use this to ensure that the producer won't - // attempt to allocate memory in the queue for an item until it has - // sufficient extra space. - for i := uint32(0); i < bufSize; i++ { - stream.producerSema <- struct{}{} + // Initialize the backpressure queue. The predicate always returns + // false, meaning we don't proactively drop messages based on queue + // length here. + alwaysFalsePredicate := func(int, lnwire.Message) bool { + return false } + stream.queue = queue.NewBackpressureQueue[lnwire.Message]( + bufSize, alwaysFalsePredicate, + ) return stream } @@ -1778,17 +1775,8 @@ func (ms *msgStream) Start() { // Stop stops the chanMsgStream. func (ms *msgStream) Stop() { - // TODO(roasbeef): signal too? - close(ms.quit) - // Now that we've closed the channel, we'll repeatedly signal the msg - // consumer until we've detected that it has exited. - for atomic.LoadInt32(&ms.streamShutdown) == 0 { - ms.msgCond.Signal() - time.Sleep(time.Millisecond * 100) - } - ms.wg.Wait() } @@ -1796,82 +1784,57 @@ func (ms *msgStream) Stop() { // readHandler directly to the target channel. func (ms *msgStream) msgConsumer() { defer ms.wg.Done() - defer peerLog.Tracef(ms.stopMsg) + defer ms.peer.log.Tracef(ms.stopMsg) defer atomic.StoreInt32(&ms.streamShutdown, 1) - peerLog.Tracef(ms.startMsg) + ms.peer.log.Tracef(ms.startMsg) + + ctx, _ := ms.peer.cg.Create(context.Background()) for { - // First, we'll check our condition. If the queue of messages - // is empty, then we'll wait until a new item is added. - ms.msgCond.L.Lock() - for len(ms.msgs) == 0 { - ms.msgCond.Wait() - - // If we woke up in order to exit, then we'll do so. - // Otherwise, we'll check the message queue for any new - // items. - select { - case <-ms.peer.cg.Done(): - ms.msgCond.L.Unlock() - return - case <-ms.quit: - ms.msgCond.L.Unlock() - return - default: - } + // Dequeue the next message. This will block until a message is + // available or the context is canceled. + msg, err := ms.queue.Dequeue(ctx) + if err != nil { + ms.peer.log.Warnf("unable to dequeue message: %v", err) + return } - // Grab the message off the front of the queue, shifting the - // slice's reference down one in order to remove the message - // from the queue. - msg := ms.msgs[0] - ms.msgs[0] = nil // Set to nil to prevent GC leak. - ms.msgs = ms.msgs[1:] - - ms.msgCond.L.Unlock() - + // Apply the dequeued message. ms.apply(msg) - // We've just successfully processed an item, so we'll signal - // to the producer that a new slot in the buffer. We'll use - // this to bound the size of the buffer to avoid allowing it to - // grow indefinitely. + // As a precaution, we'll check to see if we're already shutting + // down before adding a new message to the queue. select { - case ms.producerSema <- struct{}{}: case <-ms.peer.cg.Done(): return case <-ms.quit: return + default: } } } // AddMsg adds a new message to the msgStream. This function is safe for // concurrent access. -func (ms *msgStream) AddMsg(msg lnwire.Message) { - // First, we'll attempt to receive from the producerSema struct. This - // acts as a semaphore to prevent us from indefinitely buffering - // incoming items from the wire. Either the msg queue isn't full, and - // we'll not block, or the queue is full, and we'll block until either - // we're signalled to quit, or a slot is freed up. - select { - case <-ms.producerSema: - case <-ms.peer.cg.Done(): - return - case <-ms.quit: - return +func (ms *msgStream) AddMsg(ctx context.Context, msg lnwire.Message) { + dropped, err := ms.queue.Enqueue(ctx, msg).Unpack() + if err != nil { + if !errors.Is(err, context.Canceled) && + !errors.Is(err, context.DeadlineExceeded) { + + ms.peer.log.Warnf("msgStream.AddMsg: failed to "+ + "enqueue message: %v", err) + } else { + ms.peer.log.Tracef("msgStream.AddMsg: context "+ + "canceled during enqueue: %v", err) + } } - // Next, we'll lock the condition, and add the message to the end of - // the message queue. - ms.msgCond.L.Lock() - ms.msgs = append(ms.msgs, msg) - ms.msgCond.L.Unlock() - - // With the message added, we signal to the msgConsumer that there are - // additional messages to consume. - ms.msgCond.Signal() + if dropped { + ms.peer.log.Debugf("msgStream.AddMsg: message %T "+ + "dropped from queue", msg) + } } // waitUntilLinkActive waits until the target link is active and returns a @@ -2026,6 +1989,8 @@ func (p *Brontide) readHandler() { // gossiper? p.initGossipSync() + ctx, _ := p.cg.Create(context.Background()) + discStream := newDiscMsgStream(p) discStream.Start() defer discStream.Stop() @@ -2141,11 +2106,15 @@ out: case *lnwire.Warning: targetChan = msg.ChanID - isLinkUpdate = p.handleWarningOrError(targetChan, msg) + isLinkUpdate = p.handleWarningOrError( + ctx, targetChan, msg, + ) case *lnwire.Error: targetChan = msg.ChanID - isLinkUpdate = p.handleWarningOrError(targetChan, msg) + isLinkUpdate = p.handleWarningOrError( + ctx, targetChan, msg, + ) case *lnwire.ChannelReestablish: targetChan = msg.ChanID @@ -2193,7 +2162,7 @@ out: *lnwire.ReplyChannelRange, *lnwire.ReplyShortChanIDsEnd: - discStream.AddMsg(msg) + discStream.AddMsg(ctx, msg) case *lnwire.Custom: err := p.handleCustomMessage(msg) @@ -2215,7 +2184,7 @@ out: if isLinkUpdate { // If this is a channel update, then we need to feed it // into the channel's in-order message stream. - p.sendLinkUpdateMsg(targetChan, nextMsg) + p.sendLinkUpdateMsg(ctx, targetChan, nextMsg) } idleTimer.Reset(idleTimeout) @@ -2330,8 +2299,8 @@ func (p *Brontide) storeError(err error) { // an error from a peer with an active channel, we'll store it in memory. // // NOTE: This method should only be called from within the readHandler. -func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID, - msg lnwire.Message) bool { +func (p *Brontide) handleWarningOrError(ctx context.Context, + chanID lnwire.ChannelID, msg lnwire.Message) bool { if errMsg, ok := msg.(*lnwire.Error); ok { p.storeError(errMsg) @@ -2342,7 +2311,7 @@ func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID, // with this peer. case chanID == lnwire.ConnectionWideID: for _, chanStream := range p.activeMsgStreams { - chanStream.AddMsg(msg) + chanStream.AddMsg(ctx, msg) } return false @@ -5297,7 +5266,9 @@ func (p *Brontide) handleRemovePendingChannel(req *newChannelMsg) { // sendLinkUpdateMsg sends a message that updates the channel to the // channel's message stream. -func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { +func (p *Brontide) sendLinkUpdateMsg(ctx context.Context, + cid lnwire.ChannelID, msg lnwire.Message) { + p.log.Tracef("Sending link update msg=%v", msg.MsgType()) chanStream, ok := p.activeMsgStreams[cid] @@ -5317,7 +5288,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) { // With the stream obtained, add the message to the stream so we can // continue processing message. - chanStream.AddMsg(msg) + chanStream.AddMsg(ctx, msg) } // scaleTimeout multiplies the argument duration by a constant factor depending