mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-06-19 21:31:04 +02:00
protofsm: use new fn.GoroutineManager to manage goroutines
This fixes an isuse that can occur when we have concurrent calls to `Stop` while the state machine is driving forward.
This commit is contained in:
parent
6de0615cd5
commit
2e3c0b2a7d
@ -1,6 +1,7 @@
|
|||||||
package protofsm
|
package protofsm
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
|
|||||||
// query the internal state machine state.
|
// query the internal state machine state.
|
||||||
stateQuery chan stateQuery[Event, Env]
|
stateQuery chan stateQuery[Event, Env]
|
||||||
|
|
||||||
|
wg fn.GoroutineManager
|
||||||
|
quit chan struct{}
|
||||||
|
|
||||||
startOnce sync.Once
|
startOnce sync.Once
|
||||||
stopOnce sync.Once
|
stopOnce sync.Once
|
||||||
|
|
||||||
// TODO(roasbeef): also use that context guard here?
|
|
||||||
quit chan struct{}
|
|
||||||
wg sync.WaitGroup
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorReporter is an interface that's used to report errors that occur during
|
// ErrorReporter is an interface that's used to report errors that occur during
|
||||||
@ -194,8 +194,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
|
|||||||
cfg: cfg,
|
cfg: cfg,
|
||||||
events: make(chan Event, 1),
|
events: make(chan Event, 1),
|
||||||
stateQuery: make(chan stateQuery[Event, Env]),
|
stateQuery: make(chan stateQuery[Event, Env]),
|
||||||
quit: make(chan struct{}),
|
wg: *fn.NewGoroutineManager(context.Background()),
|
||||||
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
|
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
|
||||||
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -203,8 +204,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
|
|||||||
// the state machine to completion.
|
// the state machine to completion.
|
||||||
func (s *StateMachine[Event, Env]) Start() {
|
func (s *StateMachine[Event, Env]) Start() {
|
||||||
s.startOnce.Do(func() {
|
s.startOnce.Do(func() {
|
||||||
s.wg.Add(1)
|
_ = s.wg.Go(func(ctx context.Context) {
|
||||||
go s.driveMachine()
|
s.driveMachine()
|
||||||
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
|
|||||||
func (s *StateMachine[Event, Env]) Stop() {
|
func (s *StateMachine[Event, Env]) Stop() {
|
||||||
s.stopOnce.Do(func() {
|
s.stopOnce.Do(func() {
|
||||||
close(s.quit)
|
close(s.quit)
|
||||||
s.wg.Wait()
|
s.wg.Stop()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
|
|||||||
// executeDaemonEvent executes a daemon event, which is a special type of event
|
// executeDaemonEvent executes a daemon event, which is a special type of event
|
||||||
// that can be emitted as part of the state transition function of the state
|
// that can be emitted as part of the state transition function of the state
|
||||||
// machine. An error is returned if the type of event is unknown.
|
// machine. An error is returned if the type of event is unknown.
|
||||||
func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
func (s *StateMachine[Event, Env]) executeDaemonEvent(
|
||||||
event DaemonEvent) error {
|
event DaemonEvent) error {
|
||||||
|
|
||||||
switch daemonEvent := event.(type) {
|
switch daemonEvent := event.(type) {
|
||||||
@ -342,14 +344,10 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
err)
|
err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// If a post-send event was specified, then we'll
|
// If a post-send event was specified, then we'll funnel
|
||||||
// funnel that back into the main state machine now as
|
// that back into the main state machine now as well.
|
||||||
// well.
|
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll
|
||||||
daemonEvent.PostSendEvent.WhenSome(func(event Event) {
|
return s.wg.Go(func(ctx context.Context) {
|
||||||
s.wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
|
|
||||||
log.Debugf("FSM(%v): sending "+
|
log.Debugf("FSM(%v): sending "+
|
||||||
"post-send event: %v",
|
"post-send event: %v",
|
||||||
s.cfg.Env.Name(),
|
s.cfg.Env.Name(),
|
||||||
@ -357,10 +355,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
)
|
)
|
||||||
|
|
||||||
s.SendEvent(event)
|
s.SendEvent(event)
|
||||||
}()
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this doesn't have a SendWhen predicate, then we can just
|
// If this doesn't have a SendWhen predicate, then we can just
|
||||||
@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
// Otherwise, this has a SendWhen predicate, so we'll need
|
// Otherwise, this has a SendWhen predicate, so we'll need
|
||||||
// launch a goroutine to poll the SendWhen, then send only once
|
// launch a goroutine to poll the SendWhen, then send only once
|
||||||
// the predicate is true.
|
// the predicate is true.
|
||||||
s.wg.Add(1)
|
return s.wg.Go(func(ctx context.Context) {
|
||||||
go func() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
|
|
||||||
predicateTicker := time.NewTicker(
|
predicateTicker := time.NewTicker(
|
||||||
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
|
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
|
||||||
)
|
)
|
||||||
@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-s.quit:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
// If this is a broadcast transaction event, then we'll broadcast with
|
// If this is a broadcast transaction event, then we'll broadcast with
|
||||||
// the label attached.
|
// the label attached.
|
||||||
@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
return fmt.Errorf("unable to register spend: %w", err)
|
return fmt.Errorf("unable to register spend: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.wg.Add(1)
|
return s.wg.Go(func(ctx context.Context) {
|
||||||
go func() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case spend, ok := <-spendEvent.Spend:
|
case spend, ok := <-spendEvent.Spend:
|
||||||
@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-s.quit:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
|
|
||||||
return nil
|
|
||||||
|
|
||||||
// The state machine has requested a new event to be sent once a
|
// The state machine has requested a new event to be sent once a
|
||||||
// specified txid+pkScript pair has confirmed.
|
// specified txid+pkScript pair has confirmed.
|
||||||
@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
return fmt.Errorf("unable to register conf: %w", err)
|
return fmt.Errorf("unable to register conf: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.wg.Add(1)
|
return s.wg.Go(func(ctx context.Context) {
|
||||||
go func() {
|
|
||||||
defer s.wg.Done()
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-confEvent.Confirmed:
|
case <-confEvent.Confirmed:
|
||||||
@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
|
|||||||
|
|
||||||
return
|
return
|
||||||
|
|
||||||
case <-s.quit:
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return fmt.Errorf("unknown daemon event: %T", event)
|
return fmt.Errorf("unknown daemon event: %T", event)
|
||||||
@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
|
|||||||
// incoming events, and then drives the state machine forward until it reaches
|
// incoming events, and then drives the state machine forward until it reaches
|
||||||
// a terminal state.
|
// a terminal state.
|
||||||
func (s *StateMachine[Event, Env]) driveMachine() {
|
func (s *StateMachine[Event, Env]) driveMachine() {
|
||||||
defer s.wg.Done()
|
|
||||||
|
|
||||||
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
|
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
|
||||||
|
|
||||||
currentState := s.cfg.InitialState
|
currentState := s.cfg.InitialState
|
||||||
@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
|
|||||||
// An outside caller is querying our state, so we'll return the
|
// An outside caller is querying our state, so we'll return the
|
||||||
// latest state.
|
// latest state.
|
||||||
case stateQuery := <-s.stateQuery:
|
case stateQuery := <-s.stateQuery:
|
||||||
if !fn.SendOrQuit(
|
if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll
|
||||||
stateQuery.CurrentState, currentState, s.quit,
|
|
||||||
) {
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
case <-s.quit:
|
case <-s.wg.Done():
|
||||||
// TODO(roasbeef): logs, etc
|
|
||||||
// * something in env?
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user