diff --git a/go.mod b/go.mod index 6e2bd9f77..d0ff97ac7 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.2.3 + github.com/lightningnetwork/lnd/fn v1.2.5 github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/kvdb v1.4.11 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index 86c1c8a21..2ea42fd8c 100644 --- a/go.sum +++ b/go.sum @@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn v1.2.3 h1:Q1OrgNSgQynVheBNa16CsKVov1JI5N2AR6G07x9Mles= -github.com/lightningnetwork/lnd/fn v1.2.3/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= +github.com/lightningnetwork/lnd/fn v1.2.5 h1:pGMz0BDUxrhvOtShD4FIysdVy+ulfFAnFvTKjZO5Pp8= +github.com/lightningnetwork/lnd/fn v1.2.5/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8= diff --git a/htlcswitch/quiescer.go b/htlcswitch/quiescer.go new file mode 100644 index 000000000..9bde04c32 --- /dev/null +++ b/htlcswitch/quiescer.go @@ -0,0 +1,339 @@ +package htlcswitch + +import ( + "fmt" + "sync" + + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lntypes" + 
"github.com/lightningnetwork/lnd/lnwire" +) + +var ( + // ErrInvalidStfu indicates that the Stfu we have received is invalid. + // This can happen in instances where we have not sent Stfu but we have + // received one with the initiator field set to false. + ErrInvalidStfu = fmt.Errorf("stfu received is invalid") + + // ErrStfuAlreadySent indicates that this channel has already sent an + // Stfu message for this negotiation. + ErrStfuAlreadySent = fmt.Errorf("stfu already sent") + + // ErrStfuAlreadyRcvd indicates that this channel has already received + // an Stfu message for this negotiation. + ErrStfuAlreadyRcvd = fmt.Errorf("stfu already received") + + // ErrNoQuiescenceInitiator indicates that the caller has requested the + // quiescence initiator for a channel that is not yet quiescent. + ErrNoQuiescenceInitiator = fmt.Errorf( + "indeterminate quiescence initiator: channel is not quiescent", + ) + + // ErrPendingRemoteUpdates indicates that we have received an Stfu while + // the remote party has issued updates that are not yet bilaterally + // committed. + ErrPendingRemoteUpdates = fmt.Errorf( + "stfu received with pending remote updates", + ) + + // ErrPendingLocalUpdates indicates that we are attempting to send an + // Stfu while we have issued updates that are not yet bilaterally + // committed. + ErrPendingLocalUpdates = fmt.Errorf( + "stfu send attempted with pending local updates", + ) +) + +// QuiescerCfg is a config structure used to initialize a quiescer giving it the +// appropriate functionality to interact with the channel state that the +// quiescer must synchronize with. +type QuiescerCfg struct { + // chanID marks what channel we are managing the state machine for. This + // is important because the quiescer needs to know the ChannelID to + // construct the Stfu message. + chanID lnwire.ChannelID + + // channelInitiator indicates which ChannelParty originally opened the + // channel. 
This is used to break ties when both sides of the channel + // send Stfu claiming to be the initiator. + channelInitiator lntypes.ChannelParty + + // sendMsg is a function that can be used to send an Stfu message over + // the wire. + sendMsg func(lnwire.Stfu) error +} + +// Quiescer is a state machine that tracks progression through the quiescence +// protocol. +type Quiescer struct { + cfg QuiescerCfg + + // localInit indicates whether our path through this state machine was + // initiated by our node. This can be true or false independently of + // remoteInit. + localInit bool + + // remoteInit indicates whether we received Stfu from our peer where the + // message indicated that the remote node believes it was the initiator. + // This can be true or false independently of localInit. + remoteInit bool + + // sent tracks whether or not we have emitted Stfu for sending. + sent bool + + // received tracks whether or not we have received Stfu from our peer. + received bool + + sync.RWMutex +} + +// NewQuiescer creates a new quiescer for the given channel. +func NewQuiescer(cfg QuiescerCfg) Quiescer { + return Quiescer{ + cfg: cfg, + } +} + +// RecvStfu is called when we receive an Stfu message from the remote. +func (q *Quiescer) RecvStfu(msg lnwire.Stfu, + numPendingRemoteUpdates uint64) error { + + q.Lock() + defer q.Unlock() + + return q.recvStfu(msg, numPendingRemoteUpdates) +} + +// recvStfu is called when we receive an Stfu message from the remote. +func (q *Quiescer) recvStfu(msg lnwire.Stfu, + numPendingRemoteUpdates uint64) error { + + // At the time of this writing, this check that we have already received + // an Stfu is not strictly necessary, according to the specification. + // However, it is fishy if we do and it is unclear how we should handle + // such a case so we will err on the side of caution. 
+ if q.received { + return fmt.Errorf("%w for channel %v", ErrStfuAlreadyRcvd, + q.cfg.chanID) + } + + // We need to check that the Stfu we are receiving is valid. + if !q.sent && !msg.Initiator { + return fmt.Errorf("%w for channel %v", ErrInvalidStfu, + q.cfg.chanID) + } + + if !q.canRecvStfu(numPendingRemoteUpdates) { + return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates, + q.cfg.chanID) + } + + q.received = true + + // If the remote party sets the initiator bit to true then we will + // remember that they are making a claim to the initiator role. This + // does not necessarily mean they will get it, though. + q.remoteInit = msg.Initiator + + return nil +} + +// MakeStfu is called when we are ready to send an Stfu message. It returns the +// Stfu message to be sent. +func (q *Quiescer) MakeStfu( + numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] { + + q.RLock() + defer q.RUnlock() + + return q.makeStfu(numPendingLocalUpdates) +} + +// makeStfu is called when we are ready to send an Stfu message. It returns the +// Stfu message to be sent. +func (q *Quiescer) makeStfu( + numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] { + + if q.sent { + return fn.Errf[lnwire.Stfu]("%w for channel %v", + ErrStfuAlreadySent, q.cfg.chanID) + } + + if !q.canSendStfu(numPendingLocalUpdates) { + return fn.Errf[lnwire.Stfu]("%w for channel %v", + ErrPendingLocalUpdates, q.cfg.chanID) + } + + stfu := lnwire.Stfu{ + ChanID: q.cfg.chanID, + Initiator: q.localInit, + } + + return fn.Ok(stfu) +} + +// OweStfu returns true if we owe the other party an Stfu. We owe the remote an +// Stfu when we have received but not yet sent an Stfu, or we are the initiator +// but have not yet sent an Stfu. +func (q *Quiescer) OweStfu() bool { + q.RLock() + defer q.RUnlock() + + return q.oweStfu() +} + +// oweStfu returns true if we owe the other party an Stfu. 
We owe the remote an +// Stfu when we have received but not yet sent an Stfu, or we are the initiator +// but have not yet sent an Stfu. +func (q *Quiescer) oweStfu() bool { + return q.received && !q.sent +} + +// NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when +// we have sent but not yet received an Stfu. +func (q *Quiescer) NeedStfu() bool { + q.RLock() + defer q.RUnlock() + + return q.needStfu() +} + +// needStfu returns true if the remote owes us an Stfu. They owe us an Stfu when +// we have sent but not yet received an Stfu. +func (q *Quiescer) needStfu() bool { + q.RLock() + defer q.RUnlock() + + return q.sent && !q.received +} + +// IsQuiescent returns true if the state machine has been driven all the way to +// completion. If this returns true, processes that depend on channel quiescence +// may proceed. +func (q *Quiescer) IsQuiescent() bool { + q.RLock() + defer q.RUnlock() + + return q.isQuiescent() +} + +// isQuiescent returns true if the state machine has been driven all the way to +// completion. If this returns true, processes that depend on channel quiescence +// may proceed. +func (q *Quiescer) isQuiescent() bool { + return q.sent && q.received +} + +// QuiescenceInitiator determines which ChannelParty is the initiator of +// quiescence for the purposes of downstream protocols. If the channel is not +// currently quiescent, this method will return ErrNoQuiescenceInitiator. +func (q *Quiescer) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] { + q.RLock() + defer q.RUnlock() + + return q.quiescenceInitiator() +} + +// quiescenceInitiator determines which ChannelParty is the initiator of +// quiescence for the purposes of downstream protocols. If the channel is not +// currently quiescent, this method will return ErrNoQuiescenceInitiator. 
+func (q *Quiescer) quiescenceInitiator() fn.Result[lntypes.ChannelParty] { + switch { + case !q.isQuiescent(): + return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator) + + case q.localInit && q.remoteInit: + // In the case of a tie, the channel initiator wins. + return fn.Ok(q.cfg.channelInitiator) + + case q.localInit: + return fn.Ok(lntypes.Local) + + case q.remoteInit: + return fn.Ok(lntypes.Remote) + } + + // unreachable + return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator) +} + +// CanSendUpdates returns true if we haven't yet sent an Stfu which would mark +// the end of our ability to send updates. +func (q *Quiescer) CanSendUpdates() bool { + q.RLock() + defer q.RUnlock() + + return q.canSendUpdates() +} + +// canSendUpdates returns true if we haven't yet sent an Stfu which would mark +// the end of our ability to send updates. +func (q *Quiescer) canSendUpdates() bool { + return !q.sent && !q.localInit +} + +// CanRecvUpdates returns true if we haven't yet received an Stfu which would +// mark the end of the remote's ability to send updates. +func (q *Quiescer) CanRecvUpdates() bool { + q.RLock() + defer q.RUnlock() + + return q.canRecvUpdates() +} + +// canRecvUpdates returns true if we haven't yet received an Stfu which would +// mark the end of the remote's ability to send updates. +func (q *Quiescer) canRecvUpdates() bool { + return !q.received +} + +// CanSendStfu returns true if we can send an Stfu. +func (q *Quiescer) CanSendStfu(numPendingLocalUpdates uint64) bool { + q.RLock() + defer q.RUnlock() + + return q.canSendStfu(numPendingLocalUpdates) +} + +// canSendStfu returns true if we can send an Stfu. +func (q *Quiescer) canSendStfu(numPendingLocalUpdates uint64) bool { + return numPendingLocalUpdates == 0 && !q.sent +} + +// CanRecvStfu returns true if we can receive an Stfu. 
+func (q *Quiescer) CanRecvStfu(numPendingRemoteUpdates uint64) bool { + q.RLock() + defer q.RUnlock() + + return q.canRecvStfu(numPendingRemoteUpdates) +} + +// canRecvStfu returns true if we can receive an Stfu. +func (q *Quiescer) canRecvStfu(numPendingRemoteUpdates uint64) bool { + return numPendingRemoteUpdates == 0 && !q.received +} + +// SendOwedStfu sends Stfu if it owes one. It returns an error if the state +// machine is in an invalid state. +func (q *Quiescer) SendOwedStfu(numPendingLocalUpdates uint64) error { + q.Lock() + defer q.Unlock() + + return q.sendOwedStfu(numPendingLocalUpdates) +} + +// sendOwedStfu sends Stfu if it owes one. It returns an error if the state +// machine is in an invalid state. +func (q *Quiescer) sendOwedStfu(numPendingLocalUpdates uint64) error { + if !q.oweStfu() || !q.canSendStfu(numPendingLocalUpdates) { + return nil + } + + err := q.makeStfu(numPendingLocalUpdates).Sink(q.cfg.sendMsg) + + if err == nil { + q.sent = true + } + + return err +} diff --git a/htlcswitch/quiescer_test.go b/htlcswitch/quiescer_test.go new file mode 100644 index 000000000..914f64b3c --- /dev/null +++ b/htlcswitch/quiescer_test.go @@ -0,0 +1,247 @@ +package htlcswitch + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +var cid = lnwire.ChannelID(bytes.Repeat([]byte{0x00}, 32)) + +type quiescerTestHarness struct { + pendingUpdates lntypes.Dual[uint64] + quiescer Quiescer + conn <-chan lnwire.Stfu +} + +func initQuiescerTestHarness() *quiescerTestHarness { + conn := make(chan lnwire.Stfu, 1) + harness := &quiescerTestHarness{ + pendingUpdates: lntypes.Dual[uint64]{}, + conn: conn, + } + + harness.quiescer = NewQuiescer(QuiescerCfg{ + chanID: cid, + sendMsg: func(msg lnwire.Stfu) error { + conn <- msg + return nil + }, + }) + + return harness +} + +// TestQuiescerDoubleRecvInvalid ensures that we 
get an error response when we +// receive the Stfu message twice during the lifecycle of the quiescer. +func TestQuiescerDoubleRecvInvalid(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + + err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.NoError(t, err) + err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.Error(t, err, ErrStfuAlreadyRcvd) +} + +// TestQuiescerPendingUpdatesRecvInvalid ensures that we get an error if we +// receive the Stfu message while the Remote party has panding updates on the +// channel. +func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + + harness.pendingUpdates.SetForParty(lntypes.Remote, 1) + err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.ErrorIs(t, err, ErrPendingRemoteUpdates) +} + +// TestQuiescenceRemoteInit ensures that we can successfully traverse the state +// graph of quiescence beginning with the Remote party initiating quiescence. 
+func TestQuiescenceRemoteInit(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + + harness.pendingUpdates.SetForParty(lntypes.Local, 1) + + err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.NoError(t, err) + + err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + require.NoError(t, err) + + select { + case <-harness.conn: + t.Fatalf("stfu sent when not expected") + default: + } + + harness.pendingUpdates.SetForParty(lntypes.Local, 0) + err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + require.NoError(t, err) + + select { + case msg := <-harness.conn: + require.False(t, msg.Initiator) + default: + t.Fatalf("stfu not sent when expected") + } +} + +// TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote +// party when we have a receive first traversal of the quiescer's state graph. +func TestQuiescenceInitiator(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + require.True(t, harness.quiescer.QuiescenceInitiator().IsErr()) + + // Receive + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + require.NoError( + t, harness.quiescer.RecvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + require.True(t, harness.quiescer.QuiescenceInitiator().IsErr()) + + // Send + require.NoError( + t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local), + ) + require.Equal( + t, harness.quiescer.QuiescenceInitiator(), + fn.Ok(lntypes.Remote), + ) +} + +// TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel +// updates prior to but not after we receive Stfu. 
+func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + require.True(t, harness.quiescer.CanRecvUpdates()) + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + require.NoError( + t, harness.quiescer.RecvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + require.False(t, harness.quiescer.CanRecvUpdates()) +} + +// TestQuiescenceCantSendUpdatesAfterStfu tests that we can send channel updates +// prior to but not after we send Stfu. +func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + require.True(t, harness.quiescer.CanSendUpdates()) + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + + err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote) + require.NoError(t, err) + + err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local) + require.NoError(t, err) + + require.False(t, harness.quiescer.CanSendUpdates()) +} + +// TestQuiescenceStfuNotNeededAfterRecv tests that after we receive an Stfu we +// do not needStfu either before or after receiving it if we do not initiate +// quiescence. +func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + require.False(t, harness.quiescer.NeedStfu()) + + require.NoError( + t, harness.quiescer.RecvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + + require.False(t, harness.quiescer.NeedStfu()) +} + +// TestQuiescenceInappropriateMakeStfuReturnsErr ensures that we cannot call +// makeStfu at times when it would be a protocol violation to send it. 
+func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) { + t.Parallel() + + harness := initQuiescerTestHarness() + + harness.pendingUpdates.SetForParty(lntypes.Local, 1) + + require.True( + t, harness.quiescer.MakeStfu( + harness.pendingUpdates.Local, + ).IsErr(), + ) + + harness.pendingUpdates.SetForParty(lntypes.Local, 0) + msg := lnwire.Stfu{ + ChanID: cid, + Initiator: true, + } + require.NoError( + t, harness.quiescer.RecvStfu( + msg, harness.pendingUpdates.Remote, + ), + ) + require.True( + t, harness.quiescer.MakeStfu( + harness.pendingUpdates.Local, + ).IsOk(), + ) + + require.NoError( + t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local), + ) + require.True( + t, harness.quiescer.MakeStfu( + harness.pendingUpdates.Local, + ).IsErr(), + ) +}