peer: replace the old cond-based msgStream w/ BackpressureQueue[T]

In this commit, we replace the old condition-variable-based msgStream
with the new backpressure queue. This greatly simplifies the
implementation at this abstraction level. For now, we pass a predicate
that never drops incoming packets.
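
For illustration, here is a minimal sketch of the queue with a
never-drop predicate. The import path and exact signatures (notably
Enqueue returning a result that unpacks to (dropped, err)) are
inferred from the diff below, so treat this as an assumption-laden
sketch rather than the package's authoritative API:

package main

import (
	"context"
	"fmt"

	// Import path assumed; the API mirrors what the diff below uses.
	"github.com/lightningnetwork/lnd/queue"
)

func main() {
	// A predicate that never drops: once the buffer is at capacity,
	// Enqueue exerts pure backpressure (blocks) instead of shedding
	// load.
	neverDrop := func(queueLen int, msg string) bool { return false }

	q := queue.NewBackpressureQueue[string](10, neverDrop)
	ctx := context.Background()

	// Enqueue reports whether the message was dropped (never, here)
	// and whether the context was canceled while we were blocked.
	dropped, err := q.Enqueue(ctx, "hello").Unpack()
	if err != nil || dropped {
		return
	}

	// Dequeue blocks until an item is available or ctx is canceled.
	msg, err := q.Dequeue(ctx)
	if err != nil {
		return
	}
	fmt.Println(msg)
}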
Olaoluwa Osuntokun
2025-05-19 17:07:53 -07:00
parent 9748f98110
commit 3b27b69989

@@ -1730,12 +1730,9 @@ type msgStream struct {
 	startMsg string
 	stopMsg  string
 
-	msgCond *sync.Cond
-	msgs    []lnwire.Message
-
-	mtx sync.Mutex
-
-	producerSema chan struct{}
+	// queue is the underlying backpressure-aware queue that manages
+	// messages.
+	queue *queue.BackpressureQueue[lnwire.Message]
 
 	wg   sync.WaitGroup
 	quit chan struct{}
@@ -1744,28 +1741,28 @@ type msgStream struct {
 // newMsgStream creates a new instance of a chanMsgStream for a particular
 // channel identified by its channel ID. bufSize is the max number of messages
 // that should be buffered in the internal queue. Callers should set this to a
-// sane value that avoids blocking unnecessarily, but doesn't allow an
-// unbounded amount of memory to be allocated to buffer incoming messages.
-func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize uint32,
+// sane value that avoids blocking unnecessarily, but doesn't allow an unbounded
+// amount of memory to be allocated to buffer incoming messages.
+func newMsgStream(p *Brontide, startMsg, stopMsg string, bufSize int,
 	apply func(lnwire.Message)) *msgStream {
 
 	stream := &msgStream{
-		peer:         p,
-		apply:        apply,
-		startMsg:     startMsg,
-		stopMsg:      stopMsg,
-		producerSema: make(chan struct{}, bufSize),
-		quit:         make(chan struct{}),
+		peer:     p,
+		apply:    apply,
+		startMsg: startMsg,
+		stopMsg:  stopMsg,
+		quit:     make(chan struct{}),
 	}
-	stream.msgCond = sync.NewCond(&stream.mtx)
 
-	// Before we return the active stream, we'll populate the producer's
-	// semaphore channel. We'll use this to ensure that the producer won't
-	// attempt to allocate memory in the queue for an item until it has
-	// sufficient extra space.
-	for i := uint32(0); i < bufSize; i++ {
-		stream.producerSema <- struct{}{}
-	}
+	// Initialize the backpressure queue. The predicate always returns
+	// false, meaning we don't proactively drop messages based on queue
+	// length here.
+	alwaysFalsePredicate := func(int, lnwire.Message) bool {
+		return false
+	}
+	stream.queue = queue.NewBackpressureQueue[lnwire.Message](
+		bufSize, alwaysFalsePredicate,
+	)
 
 	return stream
 }
@@ -1778,17 +1775,8 @@ func (ms *msgStream) Start() {
 }
 
 // Stop stops the chanMsgStream.
 func (ms *msgStream) Stop() {
-	// TODO(roasbeef): signal too?
-
 	close(ms.quit)
 
-	// Now that we've closed the channel, we'll repeatedly signal the msg
-	// consumer until we've detected that it has exited.
-	for atomic.LoadInt32(&ms.streamShutdown) == 0 {
-		ms.msgCond.Signal()
-		time.Sleep(time.Millisecond * 100)
-	}
-
 	ms.wg.Wait()
 }
@@ -1796,82 +1784,57 @@ func (ms *msgStream) Stop() {
 // readHandler directly to the target channel.
 func (ms *msgStream) msgConsumer() {
 	defer ms.wg.Done()
-	defer peerLog.Tracef(ms.stopMsg)
+	defer ms.peer.log.Tracef(ms.stopMsg)
 	defer atomic.StoreInt32(&ms.streamShutdown, 1)
 
-	peerLog.Tracef(ms.startMsg)
+	ms.peer.log.Tracef(ms.startMsg)
+
+	ctx, _ := ms.peer.cg.Create(context.Background())
 
 	for {
-		// First, we'll check our condition. If the queue of messages
-		// is empty, then we'll wait until a new item is added.
-		ms.msgCond.L.Lock()
-		for len(ms.msgs) == 0 {
-			ms.msgCond.Wait()
-
-			// If we woke up in order to exit, then we'll do so.
-			// Otherwise, we'll check the message queue for any new
-			// items.
-			select {
-			case <-ms.peer.cg.Done():
-				ms.msgCond.L.Unlock()
-				return
-			case <-ms.quit:
-				ms.msgCond.L.Unlock()
-				return
-			default:
-			}
-		}
-
-		// Grab the message off the front of the queue, shifting the
-		// slice's reference down one in order to remove the message
-		// from the queue.
-		msg := ms.msgs[0]
-		ms.msgs[0] = nil // Set to nil to prevent GC leak.
-		ms.msgs = ms.msgs[1:]
-
-		ms.msgCond.L.Unlock()
+		// Dequeue the next message. This will block until a message is
+		// available or the context is canceled.
+		msg, err := ms.queue.Dequeue(ctx)
+		if err != nil {
+			ms.peer.log.Warnf("unable to dequeue message: %v", err)
+			return
+		}
 
+		// Apply the dequeued message.
 		ms.apply(msg)
 
-		// We've just successfully processed an item, so we'll signal
-		// to the producer that a new slot in the buffer. We'll use
-		// this to bound the size of the buffer to avoid allowing it to
-		// grow indefinitely.
+		// As a precaution, we'll check to see if we're already shutting
+		// down before adding a new message to the queue.
 		select {
-		case ms.producerSema <- struct{}{}:
 		case <-ms.peer.cg.Done():
 			return
 		case <-ms.quit:
 			return
+		default:
 		}
 	}
 }
 
 // AddMsg adds a new message to the msgStream. This function is safe for
 // concurrent access.
-func (ms *msgStream) AddMsg(msg lnwire.Message) {
-	// First, we'll attempt to receive from the producerSema struct. This
-	// acts as a semaphore to prevent us from indefinitely buffering
-	// incoming items from the wire. Either the msg queue isn't full, and
-	// we'll not block, or the queue is full, and we'll block until either
-	// we're signalled to quit, or a slot is freed up.
-	select {
-	case <-ms.producerSema:
-	case <-ms.peer.cg.Done():
-		return
-	case <-ms.quit:
-		return
-	}
-
-	// Next, we'll lock the condition, and add the message to the end of
-	// the message queue.
-	ms.msgCond.L.Lock()
-	ms.msgs = append(ms.msgs, msg)
-	ms.msgCond.L.Unlock()
-
-	// With the message added, we signal to the msgConsumer that there are
-	// additional messages to consume.
-	ms.msgCond.Signal()
+func (ms *msgStream) AddMsg(ctx context.Context, msg lnwire.Message) {
+	dropped, err := ms.queue.Enqueue(ctx, msg).Unpack()
+	if err != nil {
+		if !errors.Is(err, context.Canceled) &&
+			!errors.Is(err, context.DeadlineExceeded) {
+
+			ms.peer.log.Warnf("msgStream.AddMsg: failed to "+
+				"enqueue message: %v", err)
+		} else {
+			ms.peer.log.Tracef("msgStream.AddMsg: context "+
+				"canceled during enqueue: %v", err)
+		}
+	}
+
+	if dropped {
+		ms.peer.log.Debugf("msgStream.AddMsg: message %T "+
+			"dropped from queue", msg)
+	}
 }
 
 // waitUntilLinkActive waits until the target link is active and returns a
@@ -2026,6 +1989,8 @@ func (p *Brontide) readHandler() {
 	// gossiper?
 	p.initGossipSync()
 
+	ctx, _ := p.cg.Create(context.Background())
+
 	discStream := newDiscMsgStream(p)
 	discStream.Start()
 	defer discStream.Stop()
@@ -2141,11 +2106,15 @@ out:
 		case *lnwire.Warning:
 			targetChan = msg.ChanID
-			isLinkUpdate = p.handleWarningOrError(targetChan, msg)
+			isLinkUpdate = p.handleWarningOrError(
+				ctx, targetChan, msg,
+			)
 
 		case *lnwire.Error:
 			targetChan = msg.ChanID
-			isLinkUpdate = p.handleWarningOrError(targetChan, msg)
+			isLinkUpdate = p.handleWarningOrError(
+				ctx, targetChan, msg,
+			)
 
 		case *lnwire.ChannelReestablish:
 			targetChan = msg.ChanID
@@ -2193,7 +2162,7 @@ out:
 			*lnwire.ReplyChannelRange,
 			*lnwire.ReplyShortChanIDsEnd:
 
-			discStream.AddMsg(msg)
+			discStream.AddMsg(ctx, msg)
 
 		case *lnwire.Custom:
 			err := p.handleCustomMessage(msg)
@@ -2215,7 +2184,7 @@ out:
 		if isLinkUpdate {
 			// If this is a channel update, then we need to feed it
 			// into the channel's in-order message stream.
-			p.sendLinkUpdateMsg(targetChan, nextMsg)
+			p.sendLinkUpdateMsg(ctx, targetChan, nextMsg)
 		}
 
 		idleTimer.Reset(idleTimeout)
@@ -2330,8 +2299,8 @@ func (p *Brontide) storeError(err error) {
 // an error from a peer with an active channel, we'll store it in memory.
 //
 // NOTE: This method should only be called from within the readHandler.
-func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID,
-	msg lnwire.Message) bool {
+func (p *Brontide) handleWarningOrError(ctx context.Context,
+	chanID lnwire.ChannelID, msg lnwire.Message) bool {
 
 	if errMsg, ok := msg.(*lnwire.Error); ok {
 		p.storeError(errMsg)
@@ -2342,7 +2311,7 @@ func (p *Brontide) handleWarningOrError(chanID lnwire.ChannelID,
 	// with this peer.
 	case chanID == lnwire.ConnectionWideID:
 		for _, chanStream := range p.activeMsgStreams {
-			chanStream.AddMsg(msg)
+			chanStream.AddMsg(ctx, msg)
 		}
 
 		return false
@@ -5297,7 +5266,9 @@ func (p *Brontide) handleRemovePendingChannel(req *newChannelMsg) {
 // sendLinkUpdateMsg sends a message that updates the channel to the
 // channel's message stream.
-func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {
+func (p *Brontide) sendLinkUpdateMsg(ctx context.Context,
+	cid lnwire.ChannelID, msg lnwire.Message) {
+
 	p.log.Tracef("Sending link update msg=%v", msg.MsgType())
 
 	chanStream, ok := p.activeMsgStreams[cid]
@@ -5317,7 +5288,7 @@ func (p *Brontide) sendLinkUpdateMsg(cid lnwire.ChannelID, msg lnwire.Message) {
 	// With the stream obtained, add the message to the stream so we can
 	// continue processing message.
-	chanStream.AddMsg(msg)
+	chanStream.AddMsg(ctx, msg)
 }
 
 // scaleTimeout multiplies the argument duration by a constant factor depending