htlcswitch: implement InitStfu link operation

This commit is contained in:
Keagan McClelland
2024-03-12 12:32:43 -07:00
parent bca1516429
commit 7255b7357c
3 changed files with 244 additions and 15 deletions

View File

@@ -18,7 +18,9 @@ type quiescerTestHarness struct {
conn <-chan lnwire.Stfu
}
func initQuiescerTestHarness() *quiescerTestHarness {
func initQuiescerTestHarness(
channelInitiator lntypes.ChannelParty) *quiescerTestHarness {
conn := make(chan lnwire.Stfu, 1)
harness := &quiescerTestHarness{
pendingUpdates: lntypes.Dual[uint64]{},
@@ -26,7 +28,8 @@ func initQuiescerTestHarness() *quiescerTestHarness {
}
harness.quiescer = NewQuiescer(QuiescerCfg{
chanID: cid,
chanID: cid,
channelInitiator: channelInitiator,
sendMsg: func(msg lnwire.Stfu) error {
conn <- msg
return nil
@@ -41,7 +44,7 @@ func initQuiescerTestHarness() *quiescerTestHarness {
func TestQuiescerDoubleRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@@ -60,7 +63,7 @@ func TestQuiescerDoubleRecvInvalid(t *testing.T) {
func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@@ -77,7 +80,7 @@ func TestQuiescerPendingUpdatesRecvInvalid(t *testing.T) {
func TestQuiescenceRemoteInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@@ -110,12 +113,61 @@ func TestQuiescenceRemoteInit(t *testing.T) {
}
}
func TestQuiescenceLocalInit(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
stfuReq, stfuRes := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(stfuReq)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
err := harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case <-harness.conn:
t.Fatalf("stfu sent when not expected")
default:
}
harness.pendingUpdates.SetForParty(lntypes.Local, 0)
err = harness.quiescer.SendOwedStfu(harness.pendingUpdates.Local)
require.NoError(t, err)
select {
case msg := <-harness.conn:
require.True(t, msg.Initiator)
default:
t.Fatalf("stfu not sent when expected")
}
err = harness.quiescer.RecvStfu(msg, harness.pendingUpdates.Remote)
require.NoError(t, err)
select {
case party := <-stfuRes:
require.Equal(t, fn.Ok(lntypes.Local), party)
default:
t.Fatalf("quiescence request not resolved")
}
}
// TestQuiescenceInitiator ensures that the quiescenceInitiator is the Remote
// party when we have a receive first traversal of the quiescer's state graph.
func TestQuiescenceInitiator(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
// Remote Initiated
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.QuiescenceInitiator().IsErr())
// Receive
@@ -138,6 +190,48 @@ func TestQuiescenceInitiator(t *testing.T) {
t, harness.quiescer.QuiescenceInitiator(),
fn.Ok(lntypes.Remote),
)
// Local Initiated
harness = initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req)
req2, res2 := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.initStfu(req2)
select {
case initiator := <-res2:
require.True(t, initiator.IsErr())
default:
t.Fatal("quiescence request not resolved")
}
require.NoError(
t, harness.quiescer.sendOwedStfu(harness.pendingUpdates.Local),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsErr())
msg = lnwire.Stfu{
ChanID: cid,
Initiator: false,
}
require.NoError(
t, harness.quiescer.recvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.True(t, harness.quiescer.quiescenceInitiator().IsOk())
select {
case initiator := <-res:
require.Equal(t, fn.Ok(lntypes.Local), initiator)
default:
t.Fatal("quiescence request not resolved")
}
}
// TestQuiescenceCantReceiveUpdatesAfterStfu tests that we can receive channel
@@ -145,7 +239,7 @@ func TestQuiescenceInitiator(t *testing.T) {
func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanRecvUpdates())
msg := lnwire.Stfu{
@@ -165,7 +259,7 @@ func TestQuiescenceCantReceiveUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
require.True(t, harness.quiescer.CanSendUpdates())
msg := lnwire.Stfu{
@@ -188,7 +282,7 @@ func TestQuiescenceCantSendUpdatesAfterStfu(t *testing.T) {
func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
msg := lnwire.Stfu{
ChanID: cid,
@@ -210,7 +304,7 @@ func TestQuiescenceStfuNotNeededAfterRecv(t *testing.T) {
func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
t.Parallel()
harness := initQuiescerTestHarness()
harness := initQuiescerTestHarness(lntypes.Local)
harness.pendingUpdates.SetForParty(lntypes.Local, 1)
@@ -245,3 +339,44 @@ func TestQuiescenceInappropriateMakeStfuReturnsErr(t *testing.T) {
).IsErr(),
)
}
// TestQuiescerTieBreaker ensures that if both parties attempt to claim the
// initiator role that the result of the negotiation breaks the tie using the
// channel initiator.
func TestQuiescerTieBreaker(t *testing.T) {
t.Parallel()
for _, initiator := range []lntypes.ChannelParty{
lntypes.Local, lntypes.Remote,
} {
harness := initQuiescerTestHarness(initiator)
msg := lnwire.Stfu{
ChanID: cid,
Initiator: true,
}
req, res := fn.NewReq[fn.Unit, fn.Result[lntypes.ChannelParty]](
fn.Unit{},
)
harness.quiescer.InitStfu(req)
require.NoError(
t, harness.quiescer.RecvStfu(
msg, harness.pendingUpdates.Remote,
),
)
require.NoError(
t, harness.quiescer.SendOwedStfu(
harness.pendingUpdates.Local,
),
)
select {
case party := <-res:
require.Equal(t, fn.Ok(initiator), party)
default:
t.Fatal("quiescence party unavailable")
}
}
}