protofsm: use new fn.GoroutineManager to manage goroutines

This fixes an isuse that can occur when we have concurrent calls to
`Stop` while the state machine is driving forward.
This commit is contained in:
Olaoluwa Osuntokun 2024-11-13 17:10:30 -08:00
parent 6de0615cd5
commit 2e3c0b2a7d
No known key found for this signature in database
GPG Key ID: 90525F7DEEE0AD86

View File

@ -1,6 +1,7 @@
package protofsm package protofsm
import ( import (
"context"
"fmt" "fmt"
"sync" "sync"
"time" "time"
@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
// query the internal state machine state. // query the internal state machine state.
stateQuery chan stateQuery[Event, Env] stateQuery chan stateQuery[Event, Env]
wg fn.GoroutineManager
quit chan struct{}
startOnce sync.Once startOnce sync.Once
stopOnce sync.Once stopOnce sync.Once
// TODO(roasbeef): also use that context guard here?
quit chan struct{}
wg sync.WaitGroup
} }
// ErrorReporter is an interface that's used to report errors that occur during // ErrorReporter is an interface that's used to report errors that occur during
@ -194,8 +194,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
cfg: cfg, cfg: cfg,
events: make(chan Event, 1), events: make(chan Event, 1),
stateQuery: make(chan stateQuery[Event, Env]), stateQuery: make(chan stateQuery[Event, Env]),
quit: make(chan struct{}), wg: *fn.NewGoroutineManager(context.Background()),
newStateEvents: fn.NewEventDistributor[State[Event, Env]](), newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
quit: make(chan struct{}),
} }
} }
@ -203,8 +204,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
// the state machine to completion. // the state machine to completion.
func (s *StateMachine[Event, Env]) Start() { func (s *StateMachine[Event, Env]) Start() {
s.startOnce.Do(func() { s.startOnce.Do(func() {
s.wg.Add(1) _ = s.wg.Go(func(ctx context.Context) {
go s.driveMachine() s.driveMachine()
})
}) })
} }
@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
func (s *StateMachine[Event, Env]) Stop() { func (s *StateMachine[Event, Env]) Stop() {
s.stopOnce.Do(func() { s.stopOnce.Do(func() {
close(s.quit) close(s.quit)
s.wg.Wait() s.wg.Stop()
}) })
} }
@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
// executeDaemonEvent executes a daemon event, which is a special type of event // executeDaemonEvent executes a daemon event, which is a special type of event
// that can be emitted as part of the state transition function of the state // that can be emitted as part of the state transition function of the state
// machine. An error is returned if the type of event is unknown. // machine. An error is returned if the type of event is unknown.
func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen func (s *StateMachine[Event, Env]) executeDaemonEvent(
event DaemonEvent) error { event DaemonEvent) error {
switch daemonEvent := event.(type) { switch daemonEvent := event.(type) {
@ -342,14 +344,10 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
err) err)
} }
// If a post-send event was specified, then we'll // If a post-send event was specified, then we'll funnel
// funnel that back into the main state machine now as // that back into the main state machine now as well.
// well. return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll
daemonEvent.PostSendEvent.WhenSome(func(event Event) { return s.wg.Go(func(ctx context.Context) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
log.Debugf("FSM(%v): sending "+ log.Debugf("FSM(%v): sending "+
"post-send event: %v", "post-send event: %v",
s.cfg.Env.Name(), s.cfg.Env.Name(),
@ -357,10 +355,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
) )
s.SendEvent(event) s.SendEvent(event)
}() })
}) })
return nil
} }
// If this doesn't have a SendWhen predicate, then we can just // If this doesn't have a SendWhen predicate, then we can just
@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
// Otherwise, this has a SendWhen predicate, so we'll need // Otherwise, this has a SendWhen predicate, so we'll need
// launch a goroutine to poll the SendWhen, then send only once // launch a goroutine to poll the SendWhen, then send only once
// the predicate is true. // the predicate is true.
s.wg.Add(1) return s.wg.Go(func(ctx context.Context) {
go func() {
defer s.wg.Done()
predicateTicker := time.NewTicker( predicateTicker := time.NewTicker(
s.cfg.CustomPollInterval.UnwrapOr(pollInterval), s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
) )
@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return return
} }
case <-s.quit: case <-ctx.Done():
return return
} }
} }
}() })
return nil
// If this is a broadcast transaction event, then we'll broadcast with // If this is a broadcast transaction event, then we'll broadcast with
// the label attached. // the label attached.
@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register spend: %w", err) return fmt.Errorf("unable to register spend: %w", err)
} }
s.wg.Add(1) return s.wg.Go(func(ctx context.Context) {
go func() {
defer s.wg.Done()
for { for {
select { select {
case spend, ok := <-spendEvent.Spend: case spend, ok := <-spendEvent.Spend:
@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return return
case <-s.quit: case <-ctx.Done():
return return
} }
} }
}() })
return nil
// The state machine has requested a new event to be sent once a // The state machine has requested a new event to be sent once a
// specified txid+pkScript pair has confirmed. // specified txid+pkScript pair has confirmed.
@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register conf: %w", err) return fmt.Errorf("unable to register conf: %w", err)
} }
s.wg.Add(1) return s.wg.Go(func(ctx context.Context) {
go func() {
defer s.wg.Done()
for { for {
select { select {
case <-confEvent.Confirmed: case <-confEvent.Confirmed:
@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return return
case <-s.quit: case <-ctx.Done():
return return
} }
} }
}() })
} }
return fmt.Errorf("unknown daemon event: %T", event) return fmt.Errorf("unknown daemon event: %T", event)
@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
// incoming events, and then drives the state machine forward until it reaches // incoming events, and then drives the state machine forward until it reaches
// a terminal state. // a terminal state.
func (s *StateMachine[Event, Env]) driveMachine() { func (s *StateMachine[Event, Env]) driveMachine() {
defer s.wg.Done()
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name()) log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
currentState := s.cfg.InitialState currentState := s.cfg.InitialState
@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
// An outside caller is querying our state, so we'll return the // An outside caller is querying our state, so we'll return the
// latest state. // latest state.
case stateQuery := <-s.stateQuery: case stateQuery := <-s.stateQuery:
if !fn.SendOrQuit( if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll
stateQuery.CurrentState, currentState, s.quit,
) {
return return
} }
case <-s.quit: case <-s.wg.Done():
// TODO(roasbeef): logs, etc
// * something in env?
return return
} }
} }