From 7255b7357cd66e77324049cf5063e49d0d4a92d8 Mon Sep 17 00:00:00 2001 From: Keagan McClelland Date: Tue, 12 Mar 2024 12:32:43 -0700 Subject: [PATCH] htlcswitch: implement InitStfu link operation --- htlcswitch/link.go | 39 ++++++++- htlcswitch/quiescer.go | 65 ++++++++++++++- htlcswitch/quiescer_test.go | 155 +++++++++++++++++++++++++++++++++--- 3 files changed, 244 insertions(+), 15 deletions(-) diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 442522de1..449c81179 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -396,6 +396,11 @@ type channelLink struct { // respect to the quiescence protocol. quiescer Quiescer + // quiescenceReqs is a queue of requests to quiesce this link. The + // members of the queue are send-only channels we should call back with + // the result. + quiescenceReqs chan StfuReq + // ContextGuard is a helper that encapsulates a wait group and quit // channel and allows contexts that either block or cancel on those // depending on the use case. @@ -481,6 +486,10 @@ func NewChannelLink(cfg ChannelLinkConfig, }, } + quiescenceReqs := make( + chan fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]], 1, + ) + return &channelLink{ cfg: cfg, channel: channel, @@ -491,6 +500,7 @@ func NewChannelLink(cfg ChannelLinkConfig, outgoingCommitHooks: newHookMap(), incomingCommitHooks: newHookMap(), quiescer: NewQuiescer(quiescerCfg), + quiescenceReqs: quiescenceReqs, ContextGuard: fn.NewContextGuard(), } } @@ -745,12 +755,17 @@ func (l *channelLink) OnCommitOnce(direction LinkDirection, hook func()) { // may be removed or reworked in the future as RPC initiated quiescence is a // holdover until we have downstream protocols that use it. func (l *channelLink) InitStfu() <-chan fn.Result[lntypes.ChannelParty] { - // TODO(proofofkeags): Implement - c := make(chan fn.Result[lntypes.ChannelParty], 1) + req, out := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( + fn.Unit{}, + ) - c <- fn.Errf[lntypes.ChannelParty]("InitStfu not yet implemented") + select { + case l.quiescenceReqs <- req: + case <-l.Quit: + req.Resolve(fn.Err[lntypes.ChannelParty](ErrLinkShuttingDown)) + } - return c + return out } // isReestablished returns true if the link has successfully completed the @@ -1498,6 +1513,22 @@ func (l *channelLink) htlcManager() { ) } + case qReq := <-l.quiescenceReqs: + l.quiescer.InitStfu(qReq) + + pendingOnLocal := l.channel.NumPendingUpdates( + lntypes.Local, lntypes.Local, + ) + pendingOnRemote := l.channel.NumPendingUpdates( + lntypes.Local, lntypes.Remote, + ) + if err := l.quiescer.SendOwedStfu( + pendingOnLocal + pendingOnRemote, + ); err != nil { + l.stfuFailf("%s", err.Error()) + qReq.Resolve(fn.Err[lntypes.ChannelParty](err)) + } + case <-l.Quit: return } diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go index 9bde04c32..300988c7a 100644 --- a/htlcswitch/quiescer.go +++ b/htlcswitch/quiescer.go @@ -44,6 +44,8 @@ var ( ) ) +type StfuReq = fn.Req[fn.Unit, fn.Result[lntypes.ChannelParty]] + // QuiescerCfg is a config structure used to initialize a quiescer giving it the // appropriate functionality to interact with the channel state that the // quiescer must syncrhonize with. @@ -84,6 +86,10 @@ type Quiescer struct { // received tracks whether or not we have received Stfu from our peer. received bool + // activeQuiescenceRequest is a possibly None Request that we should + // resolve when we complete quiescence. + activeQuiescenceReq fn.Option[StfuReq] + sync.RWMutex } @@ -135,6 +141,10 @@ func (q *Quiescer) recvStfu(msg lnwire.Stfu, // does not necessarily mean they will get it, though. q.remoteInit = msg.Initiator + // Since we just received an Stfu, we may have a newly quiesced state. + // If so, we will try to resolve any outstanding StfuReqs. + q.tryResolveStfuReq() + return nil } @@ -186,7 +196,7 @@ func (q *Quiescer) OweStfu() bool { // Stfu when we have received but not yet sent an Stfu, or we are the initiator // but have not yet sent an Stfu. func (q *Quiescer) oweStfu() bool { - return q.received && !q.sent + return (q.received || q.localInit) && !q.sent } // NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when @@ -333,7 +343,60 @@ func (q *Quiescer) sendOwedStfu(numPendingLocalUpdates uint64) error { if err == nil { q.sent = true + + // Since we just sent an Stfu, we may have a newly quiesced + // state. If so, we will try to resolve any outstanding + // StfuReqs. + q.tryResolveStfuReq() } return err } + +// TryResolveStfuReq attempts to resolve the active quiescence request if the +// state machine has reached a quiescent state. +func (q *Quiescer) TryResolveStfuReq() { + q.Lock() + defer q.Unlock() + + q.tryResolveStfuReq() +} + +// tryResolveStfuReq attempts to resolve the active quiescence request if the +// state machine has reached a quiescent state. +func (q *Quiescer) tryResolveStfuReq() { + q.activeQuiescenceReq.WhenSome( + func(req StfuReq) { + if q.isQuiescent() { + req.Resolve(q.quiescenceInitiator()) + q.activeQuiescenceReq = fn.None[StfuReq]() + } + }, + ) +} + +// InitStfu instructs the quiescer that we intend to begin a quiescence +// negotiation where we are the initiator. We don't yet send stfu yet because +// we need to wait for the link to give us a valid opportunity to do so. +func (q *Quiescer) InitStfu(req StfuReq) { + q.Lock() + defer q.Unlock() + + q.initStfu(req) +} + +// initStfu instructs the quiescer that we intend to begin a quiescence +// negotiation where we are the initiator. We don't yet send stfu yet because +// we need to wait for the link to give us a valid opportunity to do so. +func (q *Quiescer) initStfu(req StfuReq) { + if q.localInit { + req.Resolve(fn.Errf[lntypes.ChannelParty]( + "quiescence already requested", + )) + + return + } + + q.localInit = true + q.activeQuiescenceReq = fn.Some(req) +} diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go index 914f64b3c..5c6e2fa4c 100644 --- a/htlcswitch/quiescer_test.go +++ b/htlcswitch/quiescer_test.go @@ -18,7 +18,9 @@ type quiescerTestHarness struct { conn <-chan lnwire.Stfu } -func initQuiescerTestHarness() *quiescerTestHarness { +func initQuiescerTestHarness( + channelInitiator lntypes.ChannelParty) *quiescerTestHarness { + conn := make(chan lnwire.Stfu, 1) harness := &quiescerTestHarness{ pendingUpdates: lntypes.Dual[uint64]{}, @@ -26,7 +28,8 @@ func initQuiescerTestHarness() *quiescerTestHarness { } harness.quiescer = NewQuiescer(QuiescerCfg{ - chanID: cid, + chanID: cid, + channelInitiator: channelInitiator, sendMsg: func(msg lnwire.Stfu) error { conn <- msg return nil @@ -41,7 +44,7 @@ func initQuiescerTestHarness() *quiescerTestHarness { func TestQuiescerDoubleRecvInvalid(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) msg := lnwire.Stfu{ ChanID: cid, @@ -60,7 +63,7 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) { func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) msg := lnwire.Stfu{ ChanID: cid, @@ -77,7 +80,7 @@ func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) { func TestQuiescenceRemoteInit(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) msg := lnwire.Stfu{ ChanID: cid, @@ -110,12 +113,61 @@ func TestQuiescenceRemoteInit(t *testing.T) { } } +func TestQuiescenceLocalInit(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness(lntypes.Local) + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + harness.pendingUpdates.SetForParty(lntypes.Local, 1) + + stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( + fn.Unit{}, + ) + harness.quiescer.InitStfu(stfuReq) + + harness.pendingUpdates.SetForParty(lntypes.Local, 1) + err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + require.NoError(t, err) + + select { + case <-harness.conn: + t.Fatalf("stfu sent when not expected") + default: + } + + harness.pendingUpdates.SetForParty(lntypes.Local, 0) + err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + require.NoError(t, err) + + select { + case msg := <-harness.conn: + require.True(t, msg.Initiator) + default: + t.Fatalf("stfu not sent when expected") + } + + err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.NoError(t, err) + + select { + case party := <-stfuRes: + require.Equal(t, fn.Ok(lntypes.Local), party) + default: + t.Fatalf("quiescence request not resolved") + } +} + // TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote // party when we have a receive first traversal of the quiescer's state graph. func TestQuiescenceInitiator(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + // Remote Initiated + harness := initQuiescerTestHarness(lntypes.Local) require.True(t, harness.quiescer.QuiescenceInitiator().IsErr()) // Receive @@ -138,6 +190,48 @@ func TestQuiescenceInitiator(t *testing.T) { t, harness.quiescer.QuiescenceInitiator(), fn.Ok(lntypes.Remote), ) + + // Local Initiated + harness = initQuiescerTestHarness(lntypes.Local) + require.True(t, harness.quiescer.quiescenceInitiator().IsErr()) + + req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( + fn.Unit{}, + ) + harness.quiescer.initStfu(req) + req2, res2 := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( + fn.Unit{}, + ) + harness.quiescer.initStfu(req2) + select { + case initiator := <-res2: + require.True(t, initiator.IsErr()) + default: + t.Fatal("quiescence request not resolved") + } + + require.NoError( + t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local), + ) + require.True(t, harness.quiescer.quiescenceInitiator().IsErr()) + + msg = lnwire.Stfu{ + ChanID: cid, + Initiator: false, + } + require.NoError( + t, harness.quiescer.recvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + require.True(t, harness.quiescer.quiescenceInitiator().IsOk()) + + select { + case initiator := <-res: + require.Equal(t, fn.Ok(lntypes.Local), initiator) + default: + t.Fatal("quiescence request not resolved") + } } // TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel @@ -145,7 +239,7 @@ func TestQuiescenceInitiator(t *testing.T) { func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) require.True(t, harness.quiescer.CanRecvUpdates()) msg := lnwire.Stfu{ @@ -165,7 +259,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) { func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) require.True(t, harness.quiescer.CanSendUpdates()) msg := lnwire.Stfu{ @@ -188,7 +282,7 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) { func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) msg := lnwire.Stfu{ ChanID: cid, @@ -210,7 +304,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) { func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) { t.Parallel() - harness := initQuiescerTestHarness() + harness := initQuiescerTestHarness(lntypes.Local) harness.pendingUpdates.SetForParty(lntypes.Local, 1) @@ -245,3 +339,44 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) { ).IsErr(), ) } + +// TestQuiescerTieBreaker ensures that if both parties attempt to claim the +// initiator role that the result of the negotiation breaks the tie using the +// channel initiator. +func TestQuiescerTieBreaker(t *testing.T) { + t.Parallel() + + for _, initiator := range []lntypes.ChannelParty{ + lntypes.Local, lntypes.Remote, + } { + harness := initQuiescerTestHarness(initiator) + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + + req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]]( + fn.Unit{}, + ) + + harness.quiescer.InitStfu(req) + require.NoError( + t, harness.quiescer.RecvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + require.NoError( + t, harness.quiescer.SendOwedStfu( + harness.pendingUpdates.Local, + ), + ) + + select { + case party := <-res: + require.Equal(t, fn.Ok(initiator), party) + default: + t.Fatal("quiescence party unavailable") + } + } +}