From 442f1dd677fe46b968e075a7c88ac125174ec0a2 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Sun, 26 Nov 2023 12:57:11 -0800 Subject: [PATCH] peer: handle close messages using link lifecycle hooks --- peer/brontide.go | 58 +++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/peer/brontide.go b/peer/brontide.go index df571018d..25499c75c 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -3602,6 +3602,8 @@ func (p *Brontide) StartTime() time.Time { // message is received from the remote peer. We'll use this message to advance // the chan closer state machine. func (p *Brontide) handleCloseMsg(msg *closeMsg) { + link := p.fetchLinkFromKeyAndCid(msg.cid) + // We'll now fetch the matching closing state machine in order to continue, // or finalize the channel closure process. chanCloser, err := p.fetchActiveChanCloser(msg.cid) @@ -3641,6 +3643,15 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { // We'll either continue negotiation, or halt. switch typed := msg.msg.(type) { case *lnwire.Shutdown: + // Disable incoming adds immediately. + if link != nil { + err := link.DisableAdds(htlcswitch.Incoming) + if err != nil { + p.log.Warnf("incoming link adds already disabled: %v", + link.ChanID()) + } + } + oShutdown, err := chanCloser.ReceiveShutdown(*typed) if err != nil { handleErr(err) @@ -3648,18 +3659,49 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) { } oShutdown.WhenSome(func(msg lnwire.Shutdown) { - p.queueMsg(&msg, nil) + // if the link is nil it means we can immediately queue + // the Shutdown message since we don't have to wait for + // commitment transaction synchronization + if link == nil { + p.queueMsg(typed, nil) + return + } + // When we have a Shutdown to send, we defer it til the + // next time we send a CommitSig to remain spec + // compliant. + link.OnCommitOnce(htlcswitch.Outgoing, func() { + err := link.DisableAdds(htlcswitch.Outgoing) + if err != nil { + p.log.Warn(err.Error()) + } + p.queueMsg(&msg, nil) + }) }) - oClosingSigned, err := chanCloser.BeginNegotiation() - if err != nil { - handleErr(err) - return + beginNegotiation := func() { + oClosingSigned, err := chanCloser.BeginNegotiation() + if err != nil { + handleErr(err) + return + } + + oClosingSigned.WhenSome(func(msg lnwire.ClosingSigned) { + p.queueMsg(&msg, nil) + }) } - oClosingSigned.WhenSome(func(msg lnwire.ClosingSigned) { - p.queueMsg(&msg, nil) - }) + if link == nil { + beginNegotiation() + } else { + // Now we register a flush hook to advance the + // ChanCloser and possibly send out a ClosingSigned + // when the link finishes draining. + link.OnFlushedOnce(func() { + // Remove link in goroutine to prevent deadlock. + go p.cfg.Switch.RemoveLink(msg.cid) + beginNegotiation() + }) + } case *lnwire.ClosingSigned: oClosingSigned, err := chanCloser.ReceiveClosingSigned(*typed)