htlcswitch: define state machine for quiescence

htlcswitch: add sendOwedStfu method to quiescer
This commit is contained in:
Keagan McClelland
2023-12-08 12:50:43 -08:00
parent fbeab726e1
commit f5b7866287
4 changed files with 589 additions and 3 deletions

2
go.mod
View File

@@ -36,7 +36,7 @@ require (
github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb
github.com/lightningnetwork/lnd/cert v1.2.2
github.com/lightningnetwork/lnd/clock v1.1.1
github.com/lightningnetwork/lnd/fn v1.2.3
github.com/lightningnetwork/lnd/fn v1.2.5
github.com/lightningnetwork/lnd/healthcheck v1.2.6
github.com/lightningnetwork/lnd/kvdb v1.4.11
github.com/lightningnetwork/lnd/queue v1.1.1

4
go.sum
View File

@@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf
github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U=
github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0=
github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ=
github.com/lightningnetwork/lnd/fn v1.2.3 h1:Q1OrgNSgQynVheBNa16CsKVov1JI5N2AR6G07x9Mles=
github.com/lightningnetwork/lnd/fn v1.2.3/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0=
github.com/lightningnetwork/lnd/fn v1.2.5 h1:pGMz0BDUxrhvOtShD4FIysdVy+ulfFAnFvTKjZO5Pp8=
github.com/lightningnetwork/lnd/fn v1.2.5/go.mod h1:SyFohpVrARPKH3XVAJZlXdVe+IwMYc4OMAvrDY32kw0=
github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI=
github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ=
github.com/lightningnetwork/lnd/kvdb v1.4.11 h1:fk1HMVFrsVK3xqU7q+JWHRgBltw/a2qIg1E3zazMb/8=

339
htlcswitch/quiescer.go Normal file
View File

@@ -0,0 +1,339 @@
package htlcswitch
import (
"fmt"
"sync"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
)
var (
// ErrInvalidStfu indicates that the Stfu we have received is invalid.
// This can happen in instances where we have not sent Stfu but we have
// received one with the initiator field set to false.
ErrInvalidStfu = fmt.Errorf("stfu received is invalid")
// ErrStfuAlreadySent indicates that this channel has already sent an
// Stfu message for this negotiation.
ErrStfuAlreadySent = fmt.Errorf("stfu already sent")
// ErrStfuAlreadyRcvd indicates that this channel has already received
// an Stfu message for this negotiation.
ErrStfuAlreadyRcvd = fmt.Errorf("stfu already received")
// ErrNoQuiescenceInitiator indicates that the caller has requested the
// quiescence initiator for a channel that is not yet quiescent.
ErrNoQuiescenceInitiator = fmt.Errorf(
"indeterminate quiescence initiator: channel is not quiescent",
)
// ErrPendingRemoteUpdates indicates that we have received an Stfu while
// the remote party has issued updates that are not yet bilaterally
// committed.
ErrPendingRemoteUpdates = fmt.Errorf(
"stfu received with pending remote updates",
)
// ErrPendingLocalUpdates indicates that we are attempting to send an
// Stfu while we have issued updates that are not yet bilaterally
// committed.
ErrPendingLocalUpdates = fmt.Errorf(
"stfu send attempted with pending local updates",
)
)
// QuiescerCfg is a config structure used to initialize a quiescer giving it the
// appropriate functionality to interact with the channel state that the
// quiescer must syncrhonize with.
type QuiescerCfg struct {
// chanID marks what channel we are managing the state machine for. This
// is important because the quiescer needs to know the ChannelID to
// construct the Stfu message.
chanID lnwire.ChannelID
// channelInitiator indicates which ChannelParty originally opened the
// channel. This is used to break ties when both sides of the channel
// send Stfu claiming to be the initiator.
channelInitiator lntypes.ChannelParty
// sendMsg is a function that can be used to send an Stfu message over
// the wire.
sendMsg func(lnwire.Stfu) error
}
// Quiescer is a state machine that tracks progression through the quiescence
// protocol.
type Quiescer struct {
cfg QuiescerCfg
// localInit indicates whether our path through this state machine was
// initiated by our node. This can be true or false independently of
// remoteInit.
localInit bool
// remoteInit indicates whether we received Stfu from our peer where the
// message indicated that the remote node believes it was the initiator.
// This can be true or false independently of localInit.
remoteInit bool
// sent tracks whether or not we have emitted Stfu for sending.
sent bool
// received tracks whether or not we have received Stfu from our peer.
received bool
sync.RWMutex
}
// NewQuiescer creates a new quiescer for the given channel.
func NewQuiescer(cfg QuiescerCfg) Quiescer {
return Quiescer{
cfg: cfg,
}
}
// RecvStfu is called when we receive an Stfu message from the remote.
func (q *Quiescer) RecvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
q.Lock()
defer q.Unlock()
return q.recvStfu(msg, numPendingRemoteUpdates)
}
// recvStfu is called when we receive an Stfu message from the remote.
func (q *Quiescer) recvStfu(msg lnwire.Stfu,
numPendingRemoteUpdates uint64) error {
// At the time of this writing, this check that we have already received
// an Stfu is not strictly necessary, according to the specification.
// However, it is fishy if we do and it is unclear how we should handle
// such a case so we will err on the side of caution.
if q.received {
return fmt.Errorf("%w for channel %v", ErrStfuAlreadyRcvd,
q.cfg.chanID)
}
// We need to check that the Stfu we are receiving is valid.
if !q.sent && !msg.Initiator {
return fmt.Errorf("%w for channel %v", ErrInvalidStfu,
q.cfg.chanID)
}
if !q.canRecvStfu(numPendingRemoteUpdates) {
return fmt.Errorf("%w for channel %v", ErrPendingRemoteUpdates,
q.cfg.chanID)
}
q.received = true
// If the remote party sets the initiator bit to true then we will
// remember that they are making a claim to the initiator role. This
// does not necessarily mean they will get it, though.
q.remoteInit = msg.Initiator
return nil
}
// MakeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *Quiescer) MakeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
q.RLock()
defer q.RUnlock()
return q.makeStfu(numPendingLocalUpdates)
}
// makeStfu is called when we are ready to send an Stfu message. It returns the
// Stfu message to be sent.
func (q *Quiescer) makeStfu(
numPendingLocalUpdates uint64) fn.Result[lnwire.Stfu] {
if q.sent {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrStfuAlreadySent, q.cfg.chanID)
}
if !q.canSendStfu(numPendingLocalUpdates) {
return fn.Errf[lnwire.Stfu]("%w for channel %v",
ErrPendingLocalUpdates, q.cfg.chanID)
}
stfu := lnwire.Stfu{
ChanID: q.cfg.chanID,
Initiator: q.localInit,
}
return fn.Ok(stfu)
}
// OweStfu returns true if we owe the other party an Stfu. We owe the remote an
// Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu.
func (q *Quiescer) OweStfu() bool {
q.RLock()
defer q.RUnlock()
return q.oweStfu()
}
// oweStfu returns true if we owe the other party an Stfu. We owe the remote an
// Stfu when we have received but not yet sent an Stfu, or we are the initiator
// but have not yet sent an Stfu.
func (q *Quiescer) oweStfu() bool {
return q.received && !q.sent
}
// NeedStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
// we have sent but not yet received an Stfu.
func (q *Quiescer) NeedStfu() bool {
q.RLock()
defer q.RUnlock()
return q.needStfu()
}
// needStfu returns true if the remote owes us an Stfu. They owe us an Stfu when
// we have sent but not yet received an Stfu.
func (q *Quiescer) needStfu() bool {
q.RLock()
defer q.RUnlock()
return q.sent && !q.received
}
// IsQuiescent returns true if the state machine has been driven all the way to
// completion. If this returns true, processes that depend on channel quiescence
// may proceed.
func (q *Quiescer) IsQuiescent() bool {
q.RLock()
defer q.RUnlock()
return q.isQuiescent()
}
// isQuiescent returns true if the state machine has been driven all the way to
// completion. If this returns true, processes that depend on channel quiescence
// may proceed.
func (q *Quiescer) isQuiescent() bool {
return q.sent && q.received
}
// QuiescenceInitiator determines which ChannelParty is the initiator of
// quiescence for the purposes of downstream protocols. If the channel is not
// currently quiescent, this method will return ErrNoQuiescenceInitiator.
func (q *Quiescer) QuiescenceInitiator() fn.Result[lntypes.ChannelParty] {
q.RLock()
defer q.RUnlock()
return q.quiescenceInitiator()
}
// quiescenceInitiator determines which ChannelParty is the initiator of
// quiescence for the purposes of downstream protocols. If the channel is not
// currently quiescent, this method will return ErrNoQuiescenceInitiator.
func (q *Quiescer) quiescenceInitiator() fn.Result[lntypes.ChannelParty] {
switch {
case !q.isQuiescent():
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
case q.localInit && q.remoteInit:
// In the case of a tie, the channel initiator wins.
return fn.Ok(q.cfg.channelInitiator)
case q.localInit:
return fn.Ok(lntypes.Local)
case q.remoteInit:
return fn.Ok(lntypes.Remote)
}
// unreachable
return fn.Err[lntypes.ChannelParty](ErrNoQuiescenceInitiator)
}
// CanSendUpdates returns true if we haven't yet sent an Stfu which would mark
// the end of our ability to send updates.
func (q *Quiescer) CanSendUpdates() bool {
q.RLock()
defer q.RUnlock()
return q.canSendUpdates()
}
// canSendUpdates returns true if we haven't yet sent an Stfu which would mark
// the end of our ability to send updates.
func (q *Quiescer) canSendUpdates() bool {
return !q.sent && !q.localInit
}
// CanRecvUpdates returns true if we haven't yet received an Stfu which would
// mark the end of the remote's ability to send updates.
func (q *Quiescer) CanRecvUpdates() bool {
q.RLock()
defer q.RUnlock()
return q.canRecvUpdates()
}
// canRecvUpdates returns true if we haven't yet received an Stfu which would
// mark the end of the remote's ability to send updates.
func (q *Quiescer) canRecvUpdates() bool {
return !q.received
}
// CanSendStfu returns true if we can send an Stfu.
func (q *Quiescer) CanSendStfu(numPendingLocalUpdates uint64) bool {
q.RLock()
defer q.RUnlock()
return q.canSendStfu(numPendingLocalUpdates)
}
// canSendStfu returns true if we can send an Stfu.
func (q *Quiescer) canSendStfu(numPendingLocalUpdates uint64) bool {
return numPendingLocalUpdates == 0 && !q.sent
}
// CanRecvStfu returns true if we can receive an Stfu.
func (q *Quiescer) CanRecvStfu(numPendingRemoteUpdates uint64) bool {
q.RLock()
defer q.RUnlock()
return q.canRecvStfu(numPendingRemoteUpdates)
}
// canRecvStfu returns true if we can receive an Stfu.
func (q *Quiescer) canRecvStfu(numPendingRemoteUpdates uint64) bool {
return numPendingRemoteUpdates == 0 && !q.received
}
// SendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *Quiescer) SendOwedStfu(numPendingLocalUpdates uint64) error {
q.Lock()
defer q.Unlock()
return q.sendOwedStfu(numPendingLocalUpdates)
}
// sendOwedStfu sends Stfu if it owes one. It returns an error if the state
// machine is in an invalid state.
func (q *Quiescer) sendOwedStfu(numPendingLocalUpdates uint64) error {
if !q.oweStfu() || !q.canSendStfu(numPendingLocalUpdates) {
return nil
}
err := q.makeStfu(numPendingLocalUpdates).Sink(q.cfg.sendMsg)
if err == nil {
q.sent = true
}
return err
}

247
htlcswitch/quiescer_test.go Normal file
View File

@@ -0,0 +1,247 @@
package htlcswitch
import (
"bytes"
"testing"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/require"
)
var cid = lnwire.ChannelID(bytes.Repeat([]byte{0x00}, 32))
type quiescerTestHarness struct {
pendingUpdates lntypes.Dual[uint64]
quiescer Quiescer
conn <-chan lnwire.Stfu
}
func initQuiescerTestHarness() *quiescerTestHarness {
conn := make(chan lnwire.Stfu, 1)
harness := &quiescerTestHarness{
pendingUpdates: lntypes.Dual[uint64]{},
conn: conn,
}
harness.quiescer = NewQuiescer(QuiescerCfg{
chanID: cid,
sendMsg: func(msg lnwire.Stfu) error {
conn <- msg
return nil
},
})
return harness
}
// TestQuiescerDoubleRecvInvalid ensures that we get an error response when we
// receive the Stfu message twice during the lifecycle of the quiescer.
func TestQuiescerDoubleRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.Error(t, err, ErrStfuAlreadyRcvd)
}
// TestQuiescerPendingUpdatesRecvInvalid ensures that we get an error if we
// receive the Stfu message while the Remote party has panding updates on the
// channel.
func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Remote, 1)
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.ErrorIs(t, err, ErrPendingRemoteUpdates)
}
// TestQuiescenceRemoteInit ensures that we can successfully traverse the state
// graph of quiescence beginning with the Remote party initiating quiescence.
func TestQuiescenceRemoteInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case msg := <-harness.conn:
require.False(t, msg.Initiator)
default:
t.Fatalf("stfu not sent when expected")
}
}
// TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote
// party when we have a receive first traversal of the quiescer's state graph.
func TestQuiescenceInitiator(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Receive
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Send
require.NoError(
t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local),
)
require.Equal(
t, harness.quiescer.QuiescenceInitiator(),
fn.Ok(lntypes.Remote),
)
}
// TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel
// updates prior to but not after we receive Stfu.
func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
require.True(t, harness.quiescer.CanRecvUpdates())
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.False(t, harness.quiescer.CanRecvUpdates())
}
// TestQuiescenceCantSendUpdatesAfterStfu tests that we can send channel updates
// prior to but not after we send Stfu.
func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
require.True(t, harness.quiescer.CanSendUpdates())
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
err := harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
require.False(t, harness.quiescer.CanSendUpdates())
}
// TestQuiescenceStfuNotNeededAfterRecv tests that after we receive an Stfu we
// do not needStfu either before or after receiving it if we do not initiate
// quiescence.
func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
require.False(t, harness.quiescer.NeedStfu())
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.False(t, harness.quiescer.NeedStfu())
}
// TestQuiescenceInappropriateMakeStfuReturnsErr ensures that we cannot call
// makeStfu at times when it would be a protocol violation to send it.
func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsErr(),
)
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsOk(),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local),
)
require.True(
t, harness.quiescer.MakeStfu(
harness.pendingUpdates.Local,
).IsErr(),
)
}