mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-26 21:51:27 +02:00
protofsm: use updated GoroutineManager API
Update to use the latest version of the GoroutineManager which takes a context via the `Go` method instead of the constructor.
This commit is contained in:
@@ -193,14 +193,14 @@ type StateMachineCfg[Event any, Env Environment] struct {
|
||||
// an initial state, an environment, and an event to process as if emitted at
|
||||
// the onset of the state machine. Such an event can be used to set up tracking
|
||||
// state such as a txid confirmation event.
|
||||
func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:ll
|
||||
) StateMachine[Event, Env] {
|
||||
func NewStateMachine[Event any, Env Environment](
|
||||
cfg StateMachineCfg[Event, Env]) StateMachine[Event, Env] {
|
||||
|
||||
return StateMachine[Event, Env]{
|
||||
cfg: cfg,
|
||||
events: make(chan Event, 1),
|
||||
stateQuery: make(chan stateQuery[Event, Env]),
|
||||
wg: *fn.NewGoroutineManager(context.Background()),
|
||||
wg: *fn.NewGoroutineManager(),
|
||||
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
@@ -208,10 +208,10 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
|
||||
|
||||
// Start starts the state machine. This will spawn a goroutine that will drive
|
||||
// the state machine to completion.
|
||||
func (s *StateMachine[Event, Env]) Start() {
|
||||
func (s *StateMachine[Event, Env]) Start(ctx context.Context) {
|
||||
s.startOnce.Do(func() {
|
||||
_ = s.wg.Go(func(ctx context.Context) {
|
||||
s.driveMachine()
|
||||
_ = s.wg.Go(ctx, func(ctx context.Context) {
|
||||
s.driveMachine(ctx)
|
||||
})
|
||||
})
|
||||
}
|
||||
@@ -228,13 +228,15 @@ func (s *StateMachine[Event, Env]) Stop() {
|
||||
// SendEvent sends a new event to the state machine.
|
||||
//
|
||||
// TODO(roasbeef): bool if processed?
|
||||
func (s *StateMachine[Event, Env]) SendEvent(event Event) {
|
||||
func (s *StateMachine[Event, Env]) SendEvent(ctx context.Context, event Event) {
|
||||
log.Debugf("FSM(%v): sending event: %v", s.cfg.Env.Name(),
|
||||
lnutils.SpewLogClosure(event),
|
||||
)
|
||||
|
||||
select {
|
||||
case s.events <- event:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-s.quit:
|
||||
return
|
||||
}
|
||||
@@ -258,7 +260,9 @@ func (s *StateMachine[Event, Env]) Name() string {
|
||||
// message can be mapped using the default message mapper, then true is
|
||||
// returned indicating that the message was processed. Otherwise, false is
|
||||
// returned.
|
||||
func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool {
|
||||
func (s *StateMachine[Event, Env]) SendMessage(ctx context.Context,
|
||||
msg lnwire.Message) bool {
|
||||
|
||||
// If we have no message mapper, then return false as we can't process
|
||||
// this message.
|
||||
if !s.cfg.MsgMapper.IsSome() {
|
||||
@@ -277,7 +281,7 @@ func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool {
|
||||
event := mapper.MapMsg(msg)
|
||||
|
||||
event.WhenSome(func(event Event) {
|
||||
s.SendEvent(event)
|
||||
s.SendEvent(ctx, event)
|
||||
|
||||
processed = true
|
||||
})
|
||||
@@ -330,7 +334,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
|
||||
// machine. An error is returned if the type of event is unknown.
|
||||
//
|
||||
//nolint:funlen
|
||||
func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
func (s *StateMachine[Event, Env]) executeDaemonEvent(ctx context.Context,
|
||||
event DaemonEvent) error {
|
||||
|
||||
switch daemonEvent := event.(type) {
|
||||
@@ -355,15 +359,16 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
// If a post-send event was specified, then we'll funnel
|
||||
// that back into the main state machine now as well.
|
||||
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:ll
|
||||
launched := s.wg.Go(func(ctx context.Context) {
|
||||
log.Debugf("FSM(%v): sending "+
|
||||
"post-send event: %v",
|
||||
s.cfg.Env.Name(),
|
||||
lnutils.SpewLogClosure(event),
|
||||
)
|
||||
launched := s.wg.Go(
|
||||
ctx, func(ctx context.Context) {
|
||||
log.Debugf("FSM(%v): sending "+
|
||||
"post-send event: %v",
|
||||
s.cfg.Env.Name(),
|
||||
lnutils.SpewLogClosure(event))
|
||||
|
||||
s.SendEvent(event)
|
||||
})
|
||||
s.SendEvent(ctx, event)
|
||||
},
|
||||
)
|
||||
|
||||
if !launched {
|
||||
return ErrStateMachineShutdown
|
||||
@@ -382,7 +387,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
// Otherwise, this has a SendWhen predicate, so we'll need
|
||||
// launch a goroutine to poll the SendWhen, then send only once
|
||||
// the predicate is true.
|
||||
launched := s.wg.Go(func(ctx context.Context) {
|
||||
launched := s.wg.Go(ctx, func(ctx context.Context) {
|
||||
predicateTicker := time.NewTicker(
|
||||
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
|
||||
)
|
||||
@@ -456,7 +461,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
return fmt.Errorf("unable to register spend: %w", err)
|
||||
}
|
||||
|
||||
launched := s.wg.Go(func(ctx context.Context) {
|
||||
launched := s.wg.Go(ctx, func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case spend, ok := <-spendEvent.Spend:
|
||||
@@ -470,7 +475,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
postSpend := daemonEvent.PostSpendEvent
|
||||
postSpend.WhenSome(func(f SpendMapper[Event]) { //nolint:ll
|
||||
customEvent := f(spend)
|
||||
s.SendEvent(customEvent)
|
||||
s.SendEvent(ctx, customEvent)
|
||||
})
|
||||
|
||||
return
|
||||
@@ -502,7 +507,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
return fmt.Errorf("unable to register conf: %w", err)
|
||||
}
|
||||
|
||||
launched := s.wg.Go(func(ctx context.Context) {
|
||||
launched := s.wg.Go(ctx, func(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-confEvent.Confirmed:
|
||||
@@ -514,7 +519,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
// dispatchAfterRecv w/ above
|
||||
postConf := daemonEvent.PostConfEvent
|
||||
postConf.WhenSome(func(e Event) {
|
||||
s.SendEvent(e)
|
||||
s.SendEvent(ctx, e)
|
||||
})
|
||||
|
||||
return
|
||||
@@ -538,8 +543,9 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||
// applyEvents applies a new event to the state machine. This will continue
|
||||
// until no further events are emitted by the state machine. Along the way,
|
||||
// we'll also ensure to execute any daemon events that are emitted.
|
||||
func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
|
||||
newEvent Event) (State[Event, Env], error) {
|
||||
func (s *StateMachine[Event, Env]) applyEvents(ctx context.Context,
|
||||
currentState State[Event, Env], newEvent Event) (State[Event, Env],
|
||||
error) {
|
||||
|
||||
log.Debugf("FSM(%v): applying new event", s.cfg.Env.Name(),
|
||||
lnutils.SpewLogClosure(newEvent),
|
||||
@@ -575,7 +581,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
|
||||
// of this new state transition.
|
||||
for _, dEvent := range events.ExternalEvents {
|
||||
err := s.executeDaemonEvent(
|
||||
dEvent,
|
||||
ctx, dEvent,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -633,7 +639,7 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
|
||||
// driveMachine is the main event loop of the state machine. It accepts any new
|
||||
// incoming events, and then drives the state machine forward until it reaches
|
||||
// a terminal state.
|
||||
func (s *StateMachine[Event, Env]) driveMachine() {
|
||||
func (s *StateMachine[Event, Env]) driveMachine(ctx context.Context) {
|
||||
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
|
||||
|
||||
currentState := s.cfg.InitialState
|
||||
@@ -641,7 +647,7 @@ func (s *StateMachine[Event, Env]) driveMachine() {
|
||||
// Before we start, if we have an init daemon event specified, then
|
||||
// we'll handle that now.
|
||||
err := fn.MapOptionZ(s.cfg.InitEvent, func(event DaemonEvent) error {
|
||||
return s.executeDaemonEvent(event)
|
||||
return s.executeDaemonEvent(ctx, event)
|
||||
})
|
||||
if err != nil {
|
||||
log.Errorf("unable to execute init event: %w", err)
|
||||
@@ -658,7 +664,9 @@ func (s *StateMachine[Event, Env]) driveMachine() {
|
||||
// machine forward until we either run out of internal events,
|
||||
// or we reach a terminal state.
|
||||
case newEvent := <-s.events:
|
||||
newState, err := s.applyEvents(currentState, newEvent)
|
||||
newState, err := s.applyEvents(
|
||||
ctx, currentState, newEvent,
|
||||
)
|
||||
if err != nil {
|
||||
s.cfg.ErrorReporter.ReportError(err)
|
||||
|
||||
|
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/stretchr/testify/mock"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
type dummyEvents interface {
|
||||
@@ -223,6 +224,8 @@ func (d *dummyAdapters) RegisterSpendNtfn(outpoint *wire.OutPoint,
|
||||
// TestStateMachineOnInitDaemonEvent tests that the state machine will properly
|
||||
// execute any init-level daemon events passed into it.
|
||||
func TestStateMachineOnInitDaemonEvent(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// First, we'll create our state machine given the env, and our
|
||||
// starting state.
|
||||
env := &dummyEnv{}
|
||||
@@ -254,7 +257,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) {
|
||||
stateSub := stateMachine.RegisterStateEvents()
|
||||
defer stateMachine.RemoveStateSub(stateSub)
|
||||
|
||||
stateMachine.Start()
|
||||
stateMachine.Start(ctx)
|
||||
defer stateMachine.Stop()
|
||||
|
||||
// Assert that we go from the starting state to the final state. The
|
||||
@@ -275,6 +278,7 @@ func TestStateMachineOnInitDaemonEvent(t *testing.T) {
|
||||
// transition.
|
||||
func TestStateMachineInternalEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
// First, we'll create our state machine given the env, and our
|
||||
// starting state.
|
||||
@@ -296,12 +300,12 @@ func TestStateMachineInternalEvents(t *testing.T) {
|
||||
stateSub := stateMachine.RegisterStateEvents()
|
||||
defer stateMachine.RemoveStateSub(stateSub)
|
||||
|
||||
stateMachine.Start()
|
||||
stateMachine.Start(ctx)
|
||||
defer stateMachine.Stop()
|
||||
|
||||
// For this transition, we'll send in the emitInternal event, which'll
|
||||
// send us back to the starting event, but emit an internal event.
|
||||
stateMachine.SendEvent(&emitInternal{})
|
||||
stateMachine.SendEvent(ctx, &emitInternal{})
|
||||
|
||||
// We'll now also assert the path we took to get here to ensure the
|
||||
// internal events were processed.
|
||||
@@ -323,6 +327,7 @@ func TestStateMachineInternalEvents(t *testing.T) {
|
||||
// daemon emitted as part of the state transition process.
|
||||
func TestStateMachineDaemonEvents(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
// First, we'll create our state machine given the env, and our
|
||||
// starting state.
|
||||
@@ -348,7 +353,7 @@ func TestStateMachineDaemonEvents(t *testing.T) {
|
||||
stateSub := stateMachine.RegisterStateEvents()
|
||||
defer stateMachine.RemoveStateSub(stateSub)
|
||||
|
||||
stateMachine.Start()
|
||||
stateMachine.Start(ctx)
|
||||
defer stateMachine.Stop()
|
||||
|
||||
// As soon as we send in the daemon event, we expect the
|
||||
@@ -360,7 +365,7 @@ func TestStateMachineDaemonEvents(t *testing.T) {
|
||||
|
||||
// We'll start off by sending in the daemon event, which'll trigger the
|
||||
// state machine to execute the series of daemon events.
|
||||
stateMachine.SendEvent(&daemonEvents{})
|
||||
stateMachine.SendEvent(ctx, &daemonEvents{})
|
||||
|
||||
// We should transition back to the starting state now, after we
|
||||
// started from the very same state.
|
||||
@@ -402,6 +407,8 @@ func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] {
|
||||
// TestStateMachineMsgMapper tests that given a message mapper, we can properly
|
||||
// send in wire messages get mapped to FSM events.
|
||||
func TestStateMachineMsgMapper(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// First, we'll create our state machine given the env, and our
|
||||
// starting state.
|
||||
env := &dummyEnv{}
|
||||
@@ -436,7 +443,7 @@ func TestStateMachineMsgMapper(t *testing.T) {
|
||||
stateSub := stateMachine.RegisterStateEvents()
|
||||
defer stateMachine.RemoveStateSub(stateSub)
|
||||
|
||||
stateMachine.Start()
|
||||
stateMachine.Start(ctx)
|
||||
defer stateMachine.Stop()
|
||||
|
||||
// First, we'll verify that the CanHandle method works as expected.
|
||||
@@ -445,7 +452,7 @@ func TestStateMachineMsgMapper(t *testing.T) {
|
||||
|
||||
// Next, we'll attempt to send the wire message into the state machine.
|
||||
// We should transition to the final state.
|
||||
require.True(t, stateMachine.SendMessage(wireError))
|
||||
require.True(t, stateMachine.SendMessage(ctx, wireError))
|
||||
|
||||
// We should transition to the final state.
|
||||
expectedStates := []State[dummyEvents, *dummyEnv]{
|
||||
|
Reference in New Issue
Block a user