protofsm: use updated GoroutineManager API

Update to use the latest version of the GoroutineManager which takes a
context via the `Go` method instead of the constructor.
This commit is contained in:
Elle Mouton 2024-12-09 20:41:43 +02:00
parent 4e0498faa4
commit b887c1cc5d
No known key found for this signature in database
GPG Key ID: D7D916376026F177
5 changed files with 99 additions and 67 deletions

2
go.mod
View File

@ -36,7 +36,7 @@ require (
github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb
github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/cert v1.2.2
github.com/lightningnetwork/lnd/clock v1.1.1 github.com/lightningnetwork/lnd/clock v1.1.1
github.com/lightningnetwork/lnd/fn/v2 v2.0.2 github.com/lightningnetwork/lnd/fn/v2 v2.0.4
github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/healthcheck v1.2.6
github.com/lightningnetwork/lnd/kvdb v1.4.12 github.com/lightningnetwork/lnd/kvdb v1.4.12
github.com/lightningnetwork/lnd/queue v1.1.1 github.com/lightningnetwork/lnd/queue v1.1.1

4
go.sum
View File

@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf
github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U=
github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0=
github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ=
github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs= github.com/lightningnetwork/lnd/fn/v2 v2.0.4 h1:DiC/AEa7DhnY4qOEQBISu1cp+1+51LjbVDzNLVBwNjI=
github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= github.com/lightningnetwork/lnd/fn/v2 v2.0.4/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s=
github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI=
github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ=
github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJlfxc8w5SlJeQ= github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJlfxc8w5SlJeQ=

View File

@ -2,6 +2,7 @@ package chancloser
import ( import (
"bytes" "bytes"
"context"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
@ -144,7 +145,9 @@ func assertUnknownEventFail(t *testing.T, startingState ProtocolState) {
closeHarness.expectFailure(ErrInvalidStateTransition) closeHarness.expectFailure(ErrInvalidStateTransition)
closeHarness.chanCloser.SendEvent(&unknownEvent{}) closeHarness.chanCloser.SendEvent(
context.Background(), &unknownEvent{},
)
// There should be no further state transitions. // There should be no further state transitions.
closeHarness.assertNoStateTransitions() closeHarness.assertNoStateTransitions()
@ -481,6 +484,7 @@ func (r *rbfCloserTestHarness) expectHalfSignerIteration(
initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount,
dustExpect dustExpectation) { dustExpect dustExpectation) {
ctx := context.Background()
numFeeCalls := 2 numFeeCalls := 2
// If we're using the SendOfferEvent as a trigger, we only need to call // If we're using the SendOfferEvent as a trigger, we only need to call
@ -527,7 +531,7 @@ func (r *rbfCloserTestHarness) expectHalfSignerIteration(
}) })
r.expectMsgSent(msgExpect) r.expectMsgSent(msgExpect)
r.chanCloser.SendEvent(initEvent) r.chanCloser.SendEvent(ctx, initEvent)
// Based on the init event, we'll either just go to the closing // Based on the init event, we'll either just go to the closing
// negotiation state, or go through the channel flushing state first. // negotiation state, or go through the channel flushing state first.
@ -582,6 +586,8 @@ func (r *rbfCloserTestHarness) assertSingleRbfIteration(
initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount,
dustExpect dustExpectation) { dustExpect dustExpectation) {
ctx := context.Background()
// We'll now send in the send offer event, which should trigger 1/2 of // We'll now send in the send offer event, which should trigger 1/2 of
// the RBF loop, ending us in the LocalOfferSent state. // the RBF loop, ending us in the LocalOfferSent state.
r.expectHalfSignerIteration( r.expectHalfSignerIteration(
@ -607,7 +613,7 @@ func (r *rbfCloserTestHarness) assertSingleRbfIteration(
balanceAfterClose, true, balanceAfterClose, true,
) )
r.chanCloser.SendEvent(localSigEvent) r.chanCloser.SendEvent(ctx, localSigEvent)
// We should transition to the pending closing state now. // We should transition to the pending closing state now.
r.assertLocalClosePending() r.assertLocalClosePending()
@ -617,6 +623,8 @@ func (r *rbfCloserTestHarness) assertSingleRemoteRbfIteration(
initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount,
sequence uint32, iteration bool) { sequence uint32, iteration bool) {
ctx := context.Background()
// If this is an iteration, then we expect some intermediate states, // If this is an iteration, then we expect some intermediate states,
// before we enter the main RBF/sign loop. // before we enter the main RBF/sign loop.
if iteration { if iteration {
@ -635,7 +643,7 @@ func (r *rbfCloserTestHarness) assertSingleRemoteRbfIteration(
absoluteFee, balanceAfterClose, false, absoluteFee, balanceAfterClose, false,
) )
r.chanCloser.SendEvent(initEvent) r.chanCloser.SendEvent(ctx, initEvent)
// Our outer state should transition to ClosingNegotiation state. // Our outer state should transition to ClosingNegotiation state.
r.assertStateTransitions(&ClosingNegotiation{}) r.assertStateTransitions(&ClosingNegotiation{})
@ -668,6 +676,8 @@ func assertStateT[T ProtocolState](h *rbfCloserTestHarness) T {
func newRbfCloserTestHarness(t *testing.T, func newRbfCloserTestHarness(t *testing.T,
cfg *harnessCfg) *rbfCloserTestHarness { cfg *harnessCfg) *rbfCloserTestHarness {
ctx := context.Background()
startingHeight := 200 startingHeight := 200
chanPoint := randOutPoint(t) chanPoint := randOutPoint(t)
@ -747,7 +757,7 @@ func newRbfCloserTestHarness(t *testing.T,
).Return(nil) ).Return(nil)
chanCloser := protofsm.NewStateMachine(protoCfg) chanCloser := protofsm.NewStateMachine(protoCfg)
chanCloser.Start() chanCloser.Start(ctx)
harness.stateSub = chanCloser.RegisterStateEvents() harness.stateSub = chanCloser.RegisterStateEvents()
@ -769,6 +779,7 @@ func newCloser(t *testing.T, cfg *harnessCfg) *rbfCloserTestHarness {
// TestRbfChannelActiveTransitions tests the transitions of from the // TestRbfChannelActiveTransitions tests the transitions of from the
// ChannelActive state. // ChannelActive state.
func TestRbfChannelActiveTransitions(t *testing.T) { func TestRbfChannelActiveTransitions(t *testing.T) {
ctx := context.Background()
localAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x01}, 20)) localAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x01}, 20))
remoteAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x02}, 20)) remoteAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x02}, 20))
@ -782,7 +793,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) {
}) })
defer closeHarness.stopAndAssert() defer closeHarness.stopAndAssert()
closeHarness.chanCloser.SendEvent(&SpendEvent{}) closeHarness.chanCloser.SendEvent(ctx, &SpendEvent{})
closeHarness.assertStateTransitions(&CloseFin{}) closeHarness.assertStateTransitions(&CloseFin{})
}) })
@ -799,7 +810,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) {
// We don't specify an upfront shutdown addr, and don't specify // We don't specify an upfront shutdown addr, and don't specify
// on here in the vent, so we should call new addr, but then // on here in the vent, so we should call new addr, but then
// fail. // fail.
closeHarness.chanCloser.SendEvent(&SendShutdown{}) closeHarness.chanCloser.SendEvent(ctx, &SendShutdown{})
// We shouldn't have transitioned to a new state. // We shouldn't have transitioned to a new state.
closeHarness.assertNoStateTransitions() closeHarness.assertNoStateTransitions()
@ -824,9 +835,9 @@ func TestRbfChannelActiveTransitions(t *testing.T) {
// If we send the shutdown event, we should transition to the // If we send the shutdown event, we should transition to the
// shutdown pending state. // shutdown pending state.
closeHarness.chanCloser.SendEvent(&SendShutdown{ closeHarness.chanCloser.SendEvent(
IdealFeeRate: feeRate, ctx, &SendShutdown{IdealFeeRate: feeRate},
}) )
closeHarness.assertStateTransitions(&ShutdownPending{}) closeHarness.assertStateTransitions(&ShutdownPending{})
// If we examine the internal state, it should be consistent // If we examine the internal state, it should be consistent
@ -869,9 +880,9 @@ func TestRbfChannelActiveTransitions(t *testing.T) {
// Next, we'll emit the recv event, with the addr of the remote // Next, we'll emit the recv event, with the addr of the remote
// party. // party.
closeHarness.chanCloser.SendEvent(&ShutdownReceived{ closeHarness.chanCloser.SendEvent(
ShutdownScript: remoteAddr, ctx, &ShutdownReceived{ShutdownScript: remoteAddr},
}) )
// We should transition to the shutdown pending state. // We should transition to the shutdown pending state.
closeHarness.assertStateTransitions(&ShutdownPending{}) closeHarness.assertStateTransitions(&ShutdownPending{})
@ -899,6 +910,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) {
// shutdown ourselves. // shutdown ourselves.
func TestRbfShutdownPendingTransitions(t *testing.T) { func TestRbfShutdownPendingTransitions(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
startingState := &ShutdownPending{} startingState := &ShutdownPending{}
@ -913,7 +925,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) {
}) })
defer closeHarness.stopAndAssert() defer closeHarness.stopAndAssert()
closeHarness.chanCloser.SendEvent(&SpendEvent{}) closeHarness.chanCloser.SendEvent(ctx, &SpendEvent{})
closeHarness.assertStateTransitions(&CloseFin{}) closeHarness.assertStateTransitions(&CloseFin{})
}) })
@ -936,7 +948,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) {
// We'll now send in a ShutdownReceived event, but with a // We'll now send in a ShutdownReceived event, but with a
// different address provided in the shutdown message. This // different address provided in the shutdown message. This
// should result in an error. // should result in an error.
closeHarness.chanCloser.SendEvent(&ShutdownReceived{ closeHarness.chanCloser.SendEvent(ctx, &ShutdownReceived{
ShutdownScript: localAddr, ShutdownScript: localAddr,
}) })
@ -972,9 +984,9 @@ func TestRbfShutdownPendingTransitions(t *testing.T) {
// We'll send in a shutdown received event, with the expected // We'll send in a shutdown received event, with the expected
// co-op close addr. // co-op close addr.
closeHarness.chanCloser.SendEvent(&ShutdownReceived{ closeHarness.chanCloser.SendEvent(
ShutdownScript: remoteAddr, ctx, &ShutdownReceived{ShutdownScript: remoteAddr},
}) )
// We should transition to the channel flushing state. // We should transition to the channel flushing state.
closeHarness.assertStateTransitions(&ChannelFlushing{}) closeHarness.assertStateTransitions(&ChannelFlushing{})
@ -1015,7 +1027,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) {
closeHarness.expectFinalBalances(fn.None[ShutdownBalances]()) closeHarness.expectFinalBalances(fn.None[ShutdownBalances]())
// We'll send in a shutdown received event. // We'll send in a shutdown received event.
closeHarness.chanCloser.SendEvent(&ShutdownComplete{}) closeHarness.chanCloser.SendEvent(ctx, &ShutdownComplete{})
// We should transition to the channel flushing state. // We should transition to the channel flushing state.
closeHarness.assertStateTransitions(&ChannelFlushing{}) closeHarness.assertStateTransitions(&ChannelFlushing{})
@ -1030,6 +1042,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) {
// transition to the negotiation state. // transition to the negotiation state.
func TestRbfChannelFlushingTransitions(t *testing.T) { func TestRbfChannelFlushingTransitions(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
localBalance := lnwire.NewMSatFromSatoshis(10_000) localBalance := lnwire.NewMSatFromSatoshis(10_000)
remoteBalance := lnwire.NewMSatFromSatoshis(50_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000)
@ -1082,7 +1095,9 @@ func TestRbfChannelFlushingTransitions(t *testing.T) {
// We'll now send in the event which should trigger // We'll now send in the event which should trigger
// this code path. // this code path.
closeHarness.chanCloser.SendEvent(&chanFlushedEvent) closeHarness.chanCloser.SendEvent(
ctx, &chanFlushedEvent,
)
// With the event sent, we should now transition // With the event sent, we should now transition
// straight to the ClosingNegotiation state, with no // straight to the ClosingNegotiation state, with no
@ -1149,6 +1164,7 @@ func TestRbfChannelFlushingTransitions(t *testing.T) {
// rate. // rate.
func TestRbfCloseClosingNegotiationLocal(t *testing.T) { func TestRbfCloseClosingNegotiationLocal(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
localBalance := lnwire.NewMSatFromSatoshis(40_000) localBalance := lnwire.NewMSatFromSatoshis(40_000)
remoteBalance := lnwire.NewMSatFromSatoshis(50_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000)
@ -1232,7 +1248,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) {
// We should fail as the remote party sent us more than one // We should fail as the remote party sent us more than one
// signature. // signature.
closeHarness.chanCloser.SendEvent(localSigEvent) closeHarness.chanCloser.SendEvent(ctx, localSigEvent)
}) })
// Next, we'll verify that if the balance of the remote party is dust, // Next, we'll verify that if the balance of the remote party is dust,
@ -1333,7 +1349,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) {
singleMsgMatcher[*lnwire.Shutdown](nil), singleMsgMatcher[*lnwire.Shutdown](nil),
) )
closeHarness.chanCloser.SendEvent(sendShutdown) closeHarness.chanCloser.SendEvent(ctx, sendShutdown)
// We should first transition to the Channel Active state // We should first transition to the Channel Active state
// momentarily, before transitioning to the shutdown pending // momentarily, before transitioning to the shutdown pending
@ -1367,6 +1383,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) {
// party. // party.
func TestRbfCloseClosingNegotiationRemote(t *testing.T) { func TestRbfCloseClosingNegotiationRemote(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
localBalance := lnwire.NewMSatFromSatoshis(40_000) localBalance := lnwire.NewMSatFromSatoshis(40_000)
remoteBalance := lnwire.NewMSatFromSatoshis(50_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000)
@ -1416,7 +1433,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) {
FeeSatoshis: absoluteFee * 10, FeeSatoshis: absoluteFee * 10,
}, },
} }
closeHarness.chanCloser.SendEvent(feeOffer) closeHarness.chanCloser.SendEvent(ctx, feeOffer)
// We shouldn't have transitioned to a new state. // We shouldn't have transitioned to a new state.
closeHarness.assertNoStateTransitions() closeHarness.assertNoStateTransitions()
@ -1460,7 +1477,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) {
}, },
}, },
} }
closeHarness.chanCloser.SendEvent(feeOffer) closeHarness.chanCloser.SendEvent(ctx, feeOffer)
// We shouldn't have transitioned to a new state. // We shouldn't have transitioned to a new state.
closeHarness.assertNoStateTransitions() closeHarness.assertNoStateTransitions()
@ -1489,7 +1506,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) {
}, },
}, },
} }
closeHarness.chanCloser.SendEvent(feeOffer) closeHarness.chanCloser.SendEvent(ctx, feeOffer)
// We shouldn't have transitioned to a new state. // We shouldn't have transitioned to a new state.
closeHarness.assertNoStateTransitions() closeHarness.assertNoStateTransitions()
@ -1561,9 +1578,9 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) {
// We'll now simulate the start of the RBF loop, by receiving a // We'll now simulate the start of the RBF loop, by receiving a
// new Shutdown message from the remote party. This signals // new Shutdown message from the remote party. This signals
// that they want to obtain a new commit sig. // that they want to obtain a new commit sig.
closeHarness.chanCloser.SendEvent(&ShutdownReceived{ closeHarness.chanCloser.SendEvent(
ShutdownScript: remoteAddr, ctx, &ShutdownReceived{ShutdownScript: remoteAddr},
}) )
// Next, we'll receive an offer from the remote party, and // Next, we'll receive an offer from the remote party, and
// drive another RBF iteration. This time, we'll increase the // drive another RBF iteration. This time, we'll increase the

View File

@ -193,14 +193,14 @@ type StateMachineCfg[Event any, Env Environment] struct {
// an initial state, an environment, and an event to process as if emitted at // an initial state, an environment, and an event to process as if emitted at
// the onset of the state machine. Such an event can be used to set up tracking // the onset of the state machine. Such an event can be used to set up tracking
// state such as a txid confirmation event. // state such as a txid confirmation event.
func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:ll func NewStateMachine[Event any, Env Environment](
) StateMachine[Event, Env] { cfg StateMachineCfg[Event, Env]) StateMachine[Event, Env] {
return StateMachine[Event, Env]{ return StateMachine[Event, Env]{
cfg: cfg, cfg: cfg,
events: make(chan Event, 1), events: make(chan Event, 1),
stateQuery: make(chan stateQuery[Event, Env]), stateQuery: make(chan stateQuery[Event, Env]),
wg: *fn.NewGoroutineManager(context.Background()), wg: *fn.NewGoroutineManager(),
newStateEvents: fn.NewEventDistributor[State[Event, Env]](), newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
quit: make(chan struct{}), quit: make(chan struct{}),
} }
@ -208,10 +208,10 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
// Start starts the state machine. This will spawn a goroutine that will drive // Start starts the state machine. This will spawn a goroutine that will drive
// the state machine to completion. // the state machine to completion.
func (s *StateMachine[Event, Env]) Start() { func (s *StateMachine[Event, Env]) Start(ctx context.Context) {
s.startOnce.Do(func() { s.startOnce.Do(func() {
_ = s.wg.Go(func(ctx context.Context) { _ = s.wg.Go(ctx, func(ctx context.Context) {
s.driveMachine() s.driveMachine(ctx)
}) })
}) })
} }
@ -228,13 +228,15 @@ func (s *StateMachine[Event, Env]) Stop() {
// SendEvent sends a new event to the state machine. // SendEvent sends a new event to the state machine.
// //
// TODO(roasbeef): bool if processed? // TODO(roasbeef): bool if processed?
func (s *StateMachine[Event, Env]) SendEvent(event Event) { func (s *StateMachine[Event, Env]) SendEvent(ctx context.Context, event Event) {
log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(), log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(),
lnutils.SpewLogClosure(event), lnutils.SpewLogClosure(event),
) )
select { select {
case s.events <- event: case s.events <- event:
case <-ctx.Done():
return
case <-s.quit: case <-s.quit:
return return
} }
@ -258,7 +260,9 @@ func (s *StateMachine[Event, Env]) Name() string {
// message can be mapped using the default message mapper, then true is // message can be mapped using the default message mapper, then true is
// returned indicating that the message was processed. Otherwise, false is // returned indicating that the message was processed. Otherwise, false is
// returned. // returned.
func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { func (s *StateMachine[Event, Env]) SendMessage(ctx context.Context,
msg lnwire.Message) bool {
// If we have no message mapper, then return false as we can't process // If we have no message mapper, then return false as we can't process
// this message. // this message.
if !s.cfg.MsgMapper.IsSome() { if !s.cfg.MsgMapper.IsSome() {
@ -277,7 +281,7 @@ func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool {
event := mapper.MapMsg(msg) event := mapper.MapMsg(msg)
event.WhenSome(func(event Event) { event.WhenSome(func(event Event) {
s.SendEvent(event) s.SendEvent(ctx, event)
processed = true processed = true
}) })
@ -330,7 +334,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
// machine. An error is returned if the type of event is unknown. // machine. An error is returned if the type of event is unknown.
// //
//nolint:funlen //nolint:funlen
func (s *StateMachine[Event, Env]) executeDaemonEvent( func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context,
event DaemonEvent) error { event DaemonEvent) error {
switch daemonEvent := event.(type) { switch daemonEvent := event.(type) {
@ -355,15 +359,16 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// If a post-send event was specified, then we'll funnel // If a post-send event was specified, then we'll funnel
// that back into the main state machine now as well. // that back into the main state machine now as well.
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll
launched := s.wg.Go(func(ctx context.Context) { launched := s.wg.Go(
log.Debugf("FSM(%v): sending "+ ctx, func(ctx context.Context) {
"post-send event: %v", log.Debugf("FSM(%v): sending "+
s.cfg.Env.Name(), "post-send event: %v",
lnutils.SpewLogClosure(event), s.cfg.Env.Name(),
) lnutils.SpewLogClosure(event))
s.SendEvent(event) s.SendEvent(ctx, event)
}) },
)
if !launched { if !launched {
return ErrStateMachineShutdown return ErrStateMachineShutdown
@ -382,7 +387,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// Otherwise, this has a SendWhen predicate, so we'll need // Otherwise, this has a SendWhen predicate, so we'll need
// launch a goroutine to poll the SendWhen, then send only once // launch a goroutine to poll the SendWhen, then send only once
// the predicate is true. // the predicate is true.
launched := s.wg.Go(func(ctx context.Context) { launched := s.wg.Go(ctx, func(ctx context.Context) {
predicateTicker := time.NewTicker( predicateTicker := time.NewTicker(
s.cfg.CustomPollInterval.UnwrapOr(pollInterval), s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
) )
@ -456,7 +461,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
return fmt.Errorf("unable to register spend: %w", err) return fmt.Errorf("unable to register spend: %w", err)
} }
launched := s.wg.Go(func(ctx context.Context) { launched := s.wg.Go(ctx, func(ctx context.Context) {
for { for {
select { select {
case spend, ok := <-spendEvent.Spend: case spend, ok := <-spendEvent.Spend:
@ -470,7 +475,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
postSpend := daemonEvent.PostSpendEvent postSpend := daemonEvent.PostSpendEvent
postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:ll postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:ll
customEvent := f(spend) customEvent := f(spend)
s.SendEvent(customEvent) s.SendEvent(ctx, customEvent)
}) })
return return
@ -502,7 +507,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
return fmt.Errorf("unable to register conf: %w", err) return fmt.Errorf("unable to register conf: %w", err)
} }
launched := s.wg.Go(func(ctx context.Context) { launched := s.wg.Go(ctx, func(ctx context.Context) {
for { for {
select { select {
case <-confEvent.Confirmed: case <-confEvent.Confirmed:
@ -514,7 +519,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// dispatchAfterRecv w/ above // dispatchAfterRecv w/ above
postConf := daemonEvent.PostConfEvent postConf := daemonEvent.PostConfEvent
postConf.WhenSome(func(e Event) { postConf.WhenSome(func(e Event) {
s.SendEvent(e) s.SendEvent(ctx, e)
}) })
return return
@ -538,8 +543,9 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
// applyEvents applies a new event to the state machine. This will continue // applyEvents applies a new event to the state machine. This will continue
// until no further events are emitted by the state machine. Along the way, // until no further events are emitted by the state machine. Along the way,
// we'll also ensure to execute any daemon events that are emitted. // we'll also ensure to execute any daemon events that are emitted.
func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context,
newEvent Event) (State[Event, Env], error) { currentState State[Event, Env], newEvent Event) (State[Event, Env],
error) {
log.Debugf("FSM(%v): applying new event", s.cfg.Env.Name(), log.Debugf("FSM(%v): applying new event", s.cfg.Env.Name(),
lnutils.SpewLogClosure(newEvent), lnutils.SpewLogClosure(newEvent),
@ -575,7 +581,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
// of this new state transition. // of this new state transition.
for _, dEvent := range events.ExternalEvents { for _, dEvent := range events.ExternalEvents {
err := s.executeDaemonEvent( err := s.executeDaemonEvent(
dEvent, ctx, dEvent,
) )
if err != nil { if err != nil {
return err return err
@ -633,7 +639,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
// driveMachine is the main event loop of the state machine. It accepts any new // driveMachine is the main event loop of the state machine. It accepts any new
// incoming events, and then drives the state machine forward until it reaches // incoming events, and then drives the state machine forward until it reaches
// a terminal state. // a terminal state.
func (s *StateMachine[Event, Env]) driveMachine() { func (s *StateMachine[Event, Env]) driveMachine(ctx context.Context) {
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name()) log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
currentState := s.cfg.InitialState currentState := s.cfg.InitialState
@ -641,7 +647,7 @@ func (s *StateMachine[Event, Env]) driveMachine() {
// Before we start, if we have an init daemon event specified, then // Before we start, if we have an init daemon event specified, then
// we'll handle that now. // we'll handle that now.
err := fn.MapOptionZ(s.cfg.InitEvent, func(event DaemonEvent) error { err := fn.MapOptionZ(s.cfg.InitEvent, func(event DaemonEvent) error {
return s.executeDaemonEvent(event) return s.executeDaemonEvent(ctx, event)
}) })
if err != nil { if err != nil {
log.Errorf("unable to execute init event: %w", err) log.Errorf("unable to execute init event: %w", err)
@ -658,7 +664,9 @@ func (s *StateMachine[Event, Env]) driveMachine() {
// machine forward until we either run out of internal events, // machine forward until we either run out of internal events,
// or we reach a terminal state. // or we reach a terminal state.
case newEvent := <-s.events: case newEvent := <-s.events:
newState, err := s.applyEvents(currentState, newEvent) newState, err := s.applyEvents(
ctx, currentState, newEvent,
)
if err != nil { if err != nil {
s.cfg.ErrorReporter.ReportError(err) s.cfg.ErrorReporter.ReportError(err)

View File

@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/net/context"
) )
type dummyEvents interface { type dummyEvents interface {
@ -223,6 +224,8 @@ func (d *dummyAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint,
// TestStateMachineOnInitDaemonEvent tests that the state machine will properly // TestStateMachineOnInitDaemonEvent tests that the state machine will properly
// execute any init-level daemon events passed into it. // execute any init-level daemon events passed into it.
func TestStateMachineOnInitDaemonEvent(t *testing.T) { func TestStateMachineOnInitDaemonEvent(t *testing.T) {
ctx := context.Background()
// First, we'll create our state machine given the env, and our // First, we'll create our state machine given the env, and our
// starting state. // starting state.
env := &dummyEnv{} env := &dummyEnv{}
@ -254,7 +257,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) {
stateSub := stateMachine.RegisterStateEvents() stateSub := stateMachine.RegisterStateEvents()
defer stateMachine.RemoveStateSub(stateSub) defer stateMachine.RemoveStateSub(stateSub)
stateMachine.Start() stateMachine.Start(ctx)
defer stateMachine.Stop() defer stateMachine.Stop()
// Assert that we go from the starting state to the final state. The // Assert that we go from the starting state to the final state. The
@ -275,6 +278,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) {
// transition. // transition.
func TestStateMachineInternalEvents(t *testing.T) { func TestStateMachineInternalEvents(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
// First, we'll create our state machine given the env, and our // First, we'll create our state machine given the env, and our
// starting state. // starting state.
@ -296,12 +300,12 @@ func TestStateMachineInternalEvents(t *testing.T) {
stateSub := stateMachine.RegisterStateEvents() stateSub := stateMachine.RegisterStateEvents()
defer stateMachine.RemoveStateSub(stateSub) defer stateMachine.RemoveStateSub(stateSub)
stateMachine.Start() stateMachine.Start(ctx)
defer stateMachine.Stop() defer stateMachine.Stop()
// For this transition, we'll send in the emitInternal event, which'll // For this transition, we'll send in the emitInternal event, which'll
// send us back to the starting event, but emit an internal event. // send us back to the starting event, but emit an internal event.
stateMachine.SendEvent(&emitInternal{}) stateMachine.SendEvent(ctx, &emitInternal{})
// We'll now also assert the path we took to get here to ensure the // We'll now also assert the path we took to get here to ensure the
// internal events were processed. // internal events were processed.
@ -323,6 +327,7 @@ func TestStateMachineInternalEvents(t *testing.T) {
// daemon emitted as part of the state transition process. // daemon emitted as part of the state transition process.
func TestStateMachineDaemonEvents(t *testing.T) { func TestStateMachineDaemonEvents(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background()
// First, we'll create our state machine given the env, and our // First, we'll create our state machine given the env, and our
// starting state. // starting state.
@ -348,7 +353,7 @@ func TestStateMachineDaemonEvents(t *testing.T) {
stateSub := stateMachine.RegisterStateEvents() stateSub := stateMachine.RegisterStateEvents()
defer stateMachine.RemoveStateSub(stateSub) defer stateMachine.RemoveStateSub(stateSub)
stateMachine.Start() stateMachine.Start(ctx)
defer stateMachine.Stop() defer stateMachine.Stop()
// As soon as we send in the daemon event, we expect the // As soon as we send in the daemon event, we expect the
@ -360,7 +365,7 @@ func TestStateMachineDaemonEvents(t *testing.T) {
// We'll start off by sending in the daemon event, which'll trigger the // We'll start off by sending in the daemon event, which'll trigger the
// state machine to execute the series of daemon events. // state machine to execute the series of daemon events.
stateMachine.SendEvent(&daemonEvents{}) stateMachine.SendEvent(ctx, &daemonEvents{})
// We should transition back to the starting state now, after we // We should transition back to the starting state now, after we
// started from the very same state. // started from the very same state.
@ -402,6 +407,8 @@ func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] {
// TestStateMachineMsgMapper tests that given a message mapper, we can properly // TestStateMachineMsgMapper tests that given a message mapper, we can properly
// send in wire messages get mapped to FSM events. // send in wire messages get mapped to FSM events.
func TestStateMachineMsgMapper(t *testing.T) { func TestStateMachineMsgMapper(t *testing.T) {
ctx := context.Background()
// First, we'll create our state machine given the env, and our // First, we'll create our state machine given the env, and our
// starting state. // starting state.
env := &dummyEnv{} env := &dummyEnv{}
@ -436,7 +443,7 @@ func TestStateMachineMsgMapper(t *testing.T) {
stateSub := stateMachine.RegisterStateEvents() stateSub := stateMachine.RegisterStateEvents()
defer stateMachine.RemoveStateSub(stateSub) defer stateMachine.RemoveStateSub(stateSub)
stateMachine.Start() stateMachine.Start(ctx)
defer stateMachine.Stop() defer stateMachine.Stop()
// First, we'll verify that the CanHandle method works as expected. // First, we'll verify that the CanHandle method works as expected.
@ -445,7 +452,7 @@ func TestStateMachineMsgMapper(t *testing.T) {
// Next, we'll attempt to send the wire message into the state machine. // Next, we'll attempt to send the wire message into the state machine.
// We should transition to the final state. // We should transition to the final state.
require.True(t, stateMachine.SendMessage(wireError)) require.True(t, stateMachine.SendMessage(ctx, wireError))
// We should transition to the final state. // We should transition to the final state.
expectedStates := []State[dummyEvents, *dummyEnv]{ expectedStates := []State[dummyEvents, *dummyEnv]{