From b887c1cc5d29fc777c9fb25f3d00ea9daf85297e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 9 Dec 2024 20:41:43 +0200 Subject: [PATCH] protofsm: use updated GoroutineManager API Update to use the latest version of the GoroutineManager which takes a context via the `Go` method instead of the constructor. --- go.mod | 2 +- go.sum | 4 +- lnwallet/chancloser/rbf_coop_test.go | 73 +++++++++++++++++----------- protofsm/state_machine.go | 66 ++++++++++++++----------- protofsm/state_machine_test.go | 21 +++++--- 5 files changed, 99 insertions(+), 67 deletions(-) diff --git a/go.mod b/go.mod index 35f871376..8ed20be32 100644 --- a/go.mod +++ b/go.mod @@ -36,7 +36,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20240712235311-98bd56499dfb github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn/v2 v2.0.2 + github.com/lightningnetwork/lnd/fn/v2 v2.0.4 github.com/lightningnetwork/lnd/healthcheck v1.2.6 github.com/lightningnetwork/lnd/kvdb v1.4.12 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index c9fb7ff82..0433b87e5 100644 --- a/go.sum +++ b/go.sum @@ -456,8 +456,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn/v2 v2.0.2 h1:M7o2lYrh/zCp+lntPB3WP/rWTu5U+4ssyHW+kqNJ0fs= -github.com/lightningnetwork/lnd/fn/v2 v2.0.2/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= +github.com/lightningnetwork/lnd/fn/v2 v2.0.4 h1:DiC/AEa7DhnY4qOEQBISu1cp+1+51LjbVDzNLVBwNjI= +github.com/lightningnetwork/lnd/fn/v2 v2.0.4/go.mod h1:TOzwrhjB/Azw1V7aa8t21ufcQmdsQOQMDtxVOQWNl8s= github.com/lightningnetwork/lnd/healthcheck v1.2.6 h1:1sWhqr93GdkWy4+6U7JxBfcyZIE78MhIHTJZfPx7qqI= github.com/lightningnetwork/lnd/healthcheck v1.2.6/go.mod h1:Mu02um4CWY/zdTOvFje7WJgJcHyX2zq/FG3MhOAiGaQ= github.com/lightningnetwork/lnd/kvdb v1.4.12 h1:Y0WY5Tbjyjn6eCYh068qkWur5oFtioJlfxc8w5SlJeQ= diff --git a/lnwallet/chancloser/rbf_coop_test.go b/lnwallet/chancloser/rbf_coop_test.go index 07bc32e3d..17d783eae 100644 --- a/lnwallet/chancloser/rbf_coop_test.go +++ b/lnwallet/chancloser/rbf_coop_test.go @@ -2,6 +2,7 @@ package chancloser import ( "bytes" + "context" "encoding/hex" "errors" "fmt" @@ -144,7 +145,9 @@ func assertUnknownEventFail(t *testing.T, startingState ProtocolState) { closeHarness.expectFailure(ErrInvalidStateTransition) - closeHarness.chanCloser.SendEvent(&unknownEvent{}) + closeHarness.chanCloser.SendEvent( + context.Background(), &unknownEvent{}, + ) // There should be no further state transitions. closeHarness.assertNoStateTransitions() @@ -481,6 +484,7 @@ func (r *rbfCloserTestHarness) expectHalfSignerIteration( initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, dustExpect dustExpectation) { + ctx := context.Background() numFeeCalls := 2 // If we're using the SendOfferEvent as a trigger, we only need to call @@ -527,7 +531,7 @@ func (r *rbfCloserTestHarness) expectHalfSignerIteration( }) r.expectMsgSent(msgExpect) - r.chanCloser.SendEvent(initEvent) + r.chanCloser.SendEvent(ctx, initEvent) // Based on the init event, we'll either just go to the closing // negotiation state, or go through the channel flushing state first. @@ -582,6 +586,8 @@ func (r *rbfCloserTestHarness) assertSingleRbfIteration( initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, dustExpect dustExpectation) { + ctx := context.Background() + // We'll now send in the send offer event, which should trigger 1/2 of // the RBF loop, ending us in the LocalOfferSent state. r.expectHalfSignerIteration( @@ -607,7 +613,7 @@ func (r *rbfCloserTestHarness) assertSingleRbfIteration( balanceAfterClose, true, ) - r.chanCloser.SendEvent(localSigEvent) + r.chanCloser.SendEvent(ctx, localSigEvent) // We should transition to the pending closing state now. r.assertLocalClosePending() @@ -617,6 +623,8 @@ func (r *rbfCloserTestHarness) assertSingleRemoteRbfIteration( initEvent ProtocolEvent, balanceAfterClose, absoluteFee btcutil.Amount, sequence uint32, iteration bool) { + ctx := context.Background() + // If this is an iteration, then we expect some intermediate states, // before we enter the main RBF/sign loop. if iteration { @@ -635,7 +643,7 @@ func (r *rbfCloserTestHarness) assertSingleRemoteRbfIteration( absoluteFee, balanceAfterClose, false, ) - r.chanCloser.SendEvent(initEvent) + r.chanCloser.SendEvent(ctx, initEvent) // Our outer state should transition to ClosingNegotiation state. r.assertStateTransitions(&ClosingNegotiation{}) @@ -668,6 +676,8 @@ func assertStateT[T ProtocolState](h *rbfCloserTestHarness) T { func newRbfCloserTestHarness(t *testing.T, cfg *harnessCfg) *rbfCloserTestHarness { + ctx := context.Background() + startingHeight := 200 chanPoint := randOutPoint(t) @@ -747,7 +757,7 @@ func newRbfCloserTestHarness(t *testing.T, ).Return(nil) chanCloser := protofsm.NewStateMachine(protoCfg) - chanCloser.Start() + chanCloser.Start(ctx) harness.stateSub = chanCloser.RegisterStateEvents() @@ -769,6 +779,7 @@ func newCloser(t *testing.T, cfg *harnessCfg) *rbfCloserTestHarness { // TestRbfChannelActiveTransitions tests the transitions of from the // ChannelActive state. func TestRbfChannelActiveTransitions(t *testing.T) { + ctx := context.Background() localAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x01}, 20)) remoteAddr := lnwire.DeliveryAddress(bytes.Repeat([]byte{0x02}, 20)) @@ -782,7 +793,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) { }) defer closeHarness.stopAndAssert() - closeHarness.chanCloser.SendEvent(&SpendEvent{}) + closeHarness.chanCloser.SendEvent(ctx, &SpendEvent{}) closeHarness.assertStateTransitions(&CloseFin{}) }) @@ -799,7 +810,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) { // We don't specify an upfront shutdown addr, and don't specify // on here in the vent, so we should call new addr, but then // fail. - closeHarness.chanCloser.SendEvent(&SendShutdown{}) + closeHarness.chanCloser.SendEvent(ctx, &SendShutdown{}) // We shouldn't have transitioned to a new state. closeHarness.assertNoStateTransitions() @@ -824,9 +835,9 @@ func TestRbfChannelActiveTransitions(t *testing.T) { // If we send the shutdown event, we should transition to the // shutdown pending state. - closeHarness.chanCloser.SendEvent(&SendShutdown{ - IdealFeeRate: feeRate, - }) + closeHarness.chanCloser.SendEvent( + ctx, &SendShutdown{IdealFeeRate: feeRate}, + ) closeHarness.assertStateTransitions(&ShutdownPending{}) // If we examine the internal state, it should be consistent @@ -869,9 +880,9 @@ func TestRbfChannelActiveTransitions(t *testing.T) { // Next, we'll emit the recv event, with the addr of the remote // party. - closeHarness.chanCloser.SendEvent(&ShutdownReceived{ - ShutdownScript: remoteAddr, - }) + closeHarness.chanCloser.SendEvent( + ctx, &ShutdownReceived{ShutdownScript: remoteAddr}, + ) // We should transition to the shutdown pending state. closeHarness.assertStateTransitions(&ShutdownPending{}) @@ -899,6 +910,7 @@ func TestRbfChannelActiveTransitions(t *testing.T) { // shutdown ourselves. func TestRbfShutdownPendingTransitions(t *testing.T) { t.Parallel() + ctx := context.Background() startingState := &ShutdownPending{} @@ -913,7 +925,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) { }) defer closeHarness.stopAndAssert() - closeHarness.chanCloser.SendEvent(&SpendEvent{}) + closeHarness.chanCloser.SendEvent(ctx, &SpendEvent{}) closeHarness.assertStateTransitions(&CloseFin{}) }) @@ -936,7 +948,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) { // We'll now send in a ShutdownReceived event, but with a // different address provided in the shutdown message. This // should result in an error. - closeHarness.chanCloser.SendEvent(&ShutdownReceived{ + closeHarness.chanCloser.SendEvent(ctx, &ShutdownReceived{ ShutdownScript: localAddr, }) @@ -972,9 +984,9 @@ func TestRbfShutdownPendingTransitions(t *testing.T) { // We'll send in a shutdown received event, with the expected // co-op close addr. - closeHarness.chanCloser.SendEvent(&ShutdownReceived{ - ShutdownScript: remoteAddr, - }) + closeHarness.chanCloser.SendEvent( + ctx, &ShutdownReceived{ShutdownScript: remoteAddr}, + ) // We should transition to the channel flushing state. closeHarness.assertStateTransitions(&ChannelFlushing{}) @@ -1015,7 +1027,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) { closeHarness.expectFinalBalances(fn.None[ShutdownBalances]()) // We'll send in a shutdown received event. - closeHarness.chanCloser.SendEvent(&ShutdownComplete{}) + closeHarness.chanCloser.SendEvent(ctx, &ShutdownComplete{}) // We should transition to the channel flushing state. closeHarness.assertStateTransitions(&ChannelFlushing{}) @@ -1030,6 +1042,7 @@ func TestRbfShutdownPendingTransitions(t *testing.T) { // transition to the negotiation state. func TestRbfChannelFlushingTransitions(t *testing.T) { t.Parallel() + ctx := context.Background() localBalance := lnwire.NewMSatFromSatoshis(10_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000) @@ -1082,7 +1095,9 @@ func TestRbfChannelFlushingTransitions(t *testing.T) { // We'll now send in the event which should trigger // this code path. - closeHarness.chanCloser.SendEvent(&chanFlushedEvent) + closeHarness.chanCloser.SendEvent( + ctx, &chanFlushedEvent, + ) // With the event sent, we should now transition // straight to the ClosingNegotiation state, with no @@ -1149,6 +1164,7 @@ func TestRbfChannelFlushingTransitions(t *testing.T) { // rate. func TestRbfCloseClosingNegotiationLocal(t *testing.T) { t.Parallel() + ctx := context.Background() localBalance := lnwire.NewMSatFromSatoshis(40_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000) @@ -1232,7 +1248,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) { // We should fail as the remote party sent us more than one // signature. - closeHarness.chanCloser.SendEvent(localSigEvent) + closeHarness.chanCloser.SendEvent(ctx, localSigEvent) }) // Next, we'll verify that if the balance of the remote party is dust, @@ -1333,7 +1349,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) { singleMsgMatcher[*lnwire.Shutdown](nil), ) - closeHarness.chanCloser.SendEvent(sendShutdown) + closeHarness.chanCloser.SendEvent(ctx, sendShutdown) // We should first transition to the Channel Active state // momentarily, before transitioning to the shutdown pending @@ -1367,6 +1383,7 @@ func TestRbfCloseClosingNegotiationLocal(t *testing.T) { // party. func TestRbfCloseClosingNegotiationRemote(t *testing.T) { t.Parallel() + ctx := context.Background() localBalance := lnwire.NewMSatFromSatoshis(40_000) remoteBalance := lnwire.NewMSatFromSatoshis(50_000) @@ -1416,7 +1433,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) { FeeSatoshis: absoluteFee * 10, }, } - closeHarness.chanCloser.SendEvent(feeOffer) + closeHarness.chanCloser.SendEvent(ctx, feeOffer) // We shouldn't have transitioned to a new state. closeHarness.assertNoStateTransitions() @@ -1460,7 +1477,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) { }, }, } - closeHarness.chanCloser.SendEvent(feeOffer) + closeHarness.chanCloser.SendEvent(ctx, feeOffer) // We shouldn't have transitioned to a new state. closeHarness.assertNoStateTransitions() @@ -1489,7 +1506,7 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) { }, }, } - closeHarness.chanCloser.SendEvent(feeOffer) + closeHarness.chanCloser.SendEvent(ctx, feeOffer) // We shouldn't have transitioned to a new state. closeHarness.assertNoStateTransitions() @@ -1561,9 +1578,9 @@ func TestRbfCloseClosingNegotiationRemote(t *testing.T) { // We'll now simulate the start of the RBF loop, by receiving a // new Shutdown message from the remote party. This signals // that they want to obtain a new commit sig. - closeHarness.chanCloser.SendEvent(&ShutdownReceived{ - ShutdownScript: remoteAddr, - }) + closeHarness.chanCloser.SendEvent( + ctx, &ShutdownReceived{ShutdownScript: remoteAddr}, + ) // Next, we'll receive an offer from the remote party, and // drive another RBF iteration. This time, we'll increase the diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index 2cc121902..ed87bc550 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -193,14 +193,14 @@ type StateMachineCfg[Event any, Env Environment] struct { // an initial state, an environment, and an event to process as if emitted at // the onset of the state machine. Such an event can be used to set up tracking // state such as a txid confirmation event. -func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:ll -) StateMachine[Event, Env] { +func NewStateMachine[Event any, Env Environment]( + cfg StateMachineCfg[Event, Env]) StateMachine[Event, Env] { return StateMachine[Event, Env]{ cfg: cfg, events: make(chan Event, 1), stateQuery: make(chan stateQuery[Event, Env]), - wg: *fn.NewGoroutineManager(context.Background()), + wg: *fn.NewGoroutineManager(), newStateEvents: fn.NewEventDistributor[State[Event, Env]](), quit: make(chan struct{}), } @@ -208,10 +208,10 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env] // Start starts the state machine. This will spawn a goroutine that will drive // the state machine to completion. -func (s *StateMachine[Event, Env]) Start() { +func (s *StateMachine[Event, Env]) Start(ctx context.Context) { s.startOnce.Do(func() { - _ = s.wg.Go(func(ctx context.Context) { - s.driveMachine() + _ = s.wg.Go(ctx, func(ctx context.Context) { + s.driveMachine(ctx) }) }) } @@ -228,13 +228,15 @@ func (s *StateMachine[Event, Env]) Stop() { // SendEvent sends a new event to the state machine. // // TODO(roasbeef): bool if processed? -func (s *StateMachine[Event, Env]) SendEvent(event Event) { +func (s *StateMachine[Event, Env]) SendEvent(ctx context.Context, event Event) { log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(), lnutils.SpewLogClosure(event), ) select { case s.events <- event: + case <-ctx.Done(): + return case <-s.quit: return } @@ -258,7 +260,9 @@ func (s *StateMachine[Event, Env]) Name() string { // message can be mapped using the default message mapper, then true is // returned indicating that the message was processed. Otherwise, false is // returned. -func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { +func (s *StateMachine[Event, Env]) SendMessage(ctx context.Context, + msg lnwire.Message) bool { + // If we have no message mapper, then return false as we can't process // this message. if !s.cfg.MsgMapper.IsSome() { @@ -277,7 +281,7 @@ func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { event := mapper.MapMsg(msg) event.WhenSome(func(event Event) { - s.SendEvent(event) + s.SendEvent(ctx, event) processed = true }) @@ -330,7 +334,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ // machine. An error is returned if the type of event is unknown. // //nolint:funlen -func (s *StateMachine[Event, Env]) executeDaemonEvent( +func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context, event DaemonEvent) error { switch daemonEvent := event.(type) { @@ -355,15 +359,16 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // If a post-send event was specified, then we'll funnel // that back into the main state machine now as well. return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll - launched := s.wg.Go(func(ctx context.Context) { - log.Debugf("FSM(%v): sending "+ - "post-send event: %v", - s.cfg.Env.Name(), - lnutils.SpewLogClosure(event), - ) + launched := s.wg.Go( + ctx, func(ctx context.Context) { + log.Debugf("FSM(%v): sending "+ + "post-send event: %v", + s.cfg.Env.Name(), + lnutils.SpewLogClosure(event)) - s.SendEvent(event) - }) + s.SendEvent(ctx, event) + }, + ) if !launched { return ErrStateMachineShutdown @@ -382,7 +387,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // Otherwise, this has a SendWhen predicate, so we'll need // launch a goroutine to poll the SendWhen, then send only once // the predicate is true. - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(ctx, func(ctx context.Context) { predicateTicker := time.NewTicker( s.cfg.CustomPollInterval.UnwrapOr(pollInterval), ) @@ -456,7 +461,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register spend: %w", err) } - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(ctx, func(ctx context.Context) { for { select { case spend, ok := <-spendEvent.Spend: @@ -470,7 +475,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( postSpend := daemonEvent.PostSpendEvent postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:ll customEvent := f(spend) - s.SendEvent(customEvent) + s.SendEvent(ctx, customEvent) }) return @@ -502,7 +507,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( return fmt.Errorf("unable to register conf: %w", err) } - launched := s.wg.Go(func(ctx context.Context) { + launched := s.wg.Go(ctx, func(ctx context.Context) { for { select { case <-confEvent.Confirmed: @@ -514,7 +519,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // dispatchAfterRecv w/ above postConf := daemonEvent.PostConfEvent postConf.WhenSome(func(e Event) { - s.SendEvent(e) + s.SendEvent(ctx, e) }) return @@ -538,8 +543,9 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( // applyEvents applies a new event to the state machine. This will continue // until no further events are emitted by the state machine. Along the way, // we'll also ensure to execute any daemon events that are emitted. -func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], - newEvent Event) (State[Event, Env], error) { +func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context, + currentState State[Event, Env], newEvent Event) (State[Event, Env], + error) { log.Debugf("FSM(%v): applying new event", s.cfg.Env.Name(), lnutils.SpewLogClosure(newEvent), @@ -575,7 +581,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], // of this new state transition. for _, dEvent := range events.ExternalEvents { err := s.executeDaemonEvent( - dEvent, + ctx, dEvent, ) if err != nil { return err @@ -633,7 +639,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], // driveMachine is the main event loop of the state machine. It accepts any new // incoming events, and then drives the state machine forward until it reaches // a terminal state. -func (s *StateMachine[Event, Env]) driveMachine() { +func (s *StateMachine[Event, Env]) driveMachine(ctx context.Context) { log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name()) currentState := s.cfg.InitialState @@ -641,7 +647,7 @@ func (s *StateMachine[Event, Env]) driveMachine() { // Before we start, if we have an init daemon event specified, then // we'll handle that now. err := fn.MapOptionZ(s.cfg.InitEvent, func(event DaemonEvent) error { - return s.executeDaemonEvent(event) + return s.executeDaemonEvent(ctx, event) }) if err != nil { log.Errorf("unable to execute init event: %w", err) @@ -658,7 +664,9 @@ func (s *StateMachine[Event, Env]) driveMachine() { // machine forward until we either run out of internal events, // or we reach a terminal state. case newEvent := <-s.events: - newState, err := s.applyEvents(currentState, newEvent) + newState, err := s.applyEvents( + ctx, currentState, newEvent, + ) if err != nil { s.cfg.ErrorReporter.ReportError(err) diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index ea596dc25..1ff0217d9 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -14,6 +14,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "golang.org/x/net/context" ) type dummyEvents interface { @@ -223,6 +224,8 @@ func (d *dummyAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint, // TestStateMachineOnInitDaemonEvent tests that the state machine will properly // execute any init-level daemon events passed into it. func TestStateMachineOnInitDaemonEvent(t *testing.T) { + ctx := context.Background() + // First, we'll create our state machine given the env, and our // starting state. env := &dummyEnv{} @@ -254,7 +257,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) { stateSub := stateMachine.RegisterStateEvents() defer stateMachine.RemoveStateSub(stateSub) - stateMachine.Start() + stateMachine.Start(ctx) defer stateMachine.Stop() // Assert that we go from the starting state to the final state. The @@ -275,6 +278,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) { // transition. func TestStateMachineInternalEvents(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create our state machine given the env, and our // starting state. @@ -296,12 +300,12 @@ func TestStateMachineInternalEvents(t *testing.T) { stateSub := stateMachine.RegisterStateEvents() defer stateMachine.RemoveStateSub(stateSub) - stateMachine.Start() + stateMachine.Start(ctx) defer stateMachine.Stop() // For this transition, we'll send in the emitInternal event, which'll // send us back to the starting event, but emit an internal event. - stateMachine.SendEvent(&emitInternal{}) + stateMachine.SendEvent(ctx, &emitInternal{}) // We'll now also assert the path we took to get here to ensure the // internal events were processed. @@ -323,6 +327,7 @@ func TestStateMachineInternalEvents(t *testing.T) { // daemon emitted as part of the state transition process. func TestStateMachineDaemonEvents(t *testing.T) { t.Parallel() + ctx := context.Background() // First, we'll create our state machine given the env, and our // starting state. @@ -348,7 +353,7 @@ func TestStateMachineDaemonEvents(t *testing.T) { stateSub := stateMachine.RegisterStateEvents() defer stateMachine.RemoveStateSub(stateSub) - stateMachine.Start() + stateMachine.Start(ctx) defer stateMachine.Stop() // As soon as we send in the daemon event, we expect the @@ -360,7 +365,7 @@ func TestStateMachineDaemonEvents(t *testing.T) { // We'll start off by sending in the daemon event, which'll trigger the // state machine to execute the series of daemon events. - stateMachine.SendEvent(&daemonEvents{}) + stateMachine.SendEvent(ctx, &daemonEvents{}) // We should transition back to the starting state now, after we // started from the very same state. @@ -402,6 +407,8 @@ func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] { // TestStateMachineMsgMapper tests that given a message mapper, we can properly // send in wire messages get mapped to FSM events. func TestStateMachineMsgMapper(t *testing.T) { + ctx := context.Background() + // First, we'll create our state machine given the env, and our // starting state. env := &dummyEnv{} @@ -436,7 +443,7 @@ func TestStateMachineMsgMapper(t *testing.T) { stateSub := stateMachine.RegisterStateEvents() defer stateMachine.RemoveStateSub(stateSub) - stateMachine.Start() + stateMachine.Start(ctx) defer stateMachine.Stop() // First, we'll verify that the CanHandle method works as expected. @@ -445,7 +452,7 @@ func TestStateMachineMsgMapper(t *testing.T) { // Next, we'll attempt to send the wire message into the state machine. // We should transition to the final state. - require.True(t, stateMachine.SendMessage(wireError)) + require.True(t, stateMachine.SendMessage(ctx, wireError)) // We should transition to the final state. expectedStates := []State[dummyEvents, *dummyEnv]{