protofsm: use new fn.GoroutineManager to manage goroutines

This fixes an isuse that can occur when we have concurrent calls to
`Stop` while the state machine is driving forward.
This commit is contained in:
Olaoluwa Osuntokun
2024-11-13 17:10:30 -08:00
parent 6de0615cd5
commit 2e3c0b2a7d

View File

@@ -1,6 +1,7 @@
package protofsm
import (
"context"
"fmt"
"sync"
"time"
@@ -135,12 +136,11 @@ type StateMachine[Event any, Env Environment] struct {
// query the internal state machine state.
stateQuery chan stateQuery[Event, Env]
wg fn.GoroutineManager
quit chan struct{}
startOnce sync.Once
stopOnce sync.Once
// TODO(roasbeef): also use that context guard here?
quit chan struct{}
wg sync.WaitGroup
}
// ErrorReporter is an interface that's used to report errors that occur during
@@ -194,8 +194,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
cfg: cfg,
events: make(chan Event, 1),
stateQuery: make(chan stateQuery[Event, Env]),
quit: make(chan struct{}),
wg: *fn.NewGoroutineManager(context.Background()),
newStateEvents: fn.NewEventDistributor[State[Event, Env]](),
quit: make(chan struct{}),
}
}
@@ -203,8 +204,9 @@ func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env]
// the state machine to completion.
func (s *StateMachine[Event, Env]) Start() {
s.startOnce.Do(func() {
s.wg.Add(1)
go s.driveMachine()
_ = s.wg.Go(func(ctx context.Context) {
s.driveMachine()
})
})
}
@@ -213,7 +215,7 @@ func (s *StateMachine[Event, Env]) Start() {
func (s *StateMachine[Event, Env]) Stop() {
s.stopOnce.Do(func() {
close(s.quit)
s.wg.Wait()
s.wg.Stop()
})
}
@@ -320,7 +322,7 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[
// executeDaemonEvent executes a daemon event, which is a special type of event
// that can be emitted as part of the state transition function of the state
// machine. An error is returned if the type of event is unknown.
func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
func (s *StateMachine[Event, Env]) executeDaemonEvent(
event DaemonEvent) error {
switch daemonEvent := event.(type) {
@@ -342,14 +344,10 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
err)
}
// If a post-send event was specified, then we'll
// funnel that back into the main state machine now as
// well.
daemonEvent.PostSendEvent.WhenSome(func(event Event) {
s.wg.Add(1)
go func() {
defer s.wg.Done()
// If a post-send event was specified, then we'll funnel
// that back into the main state machine now as well.
return fn.MapOptionZ(daemonEvent.PostSendEvent, func(event Event) error { //nolint:lll
return s.wg.Go(func(ctx context.Context) {
log.Debugf("FSM(%v): sending "+
"post-send event: %v",
s.cfg.Env.Name(),
@@ -357,10 +355,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
)
s.SendEvent(event)
}()
})
})
return nil
}
// If this doesn't have a SendWhen predicate, then we can just
@@ -372,10 +368,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
// Otherwise, this has a SendWhen predicate, so we'll need
// launch a goroutine to poll the SendWhen, then send only once
// the predicate is true.
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
predicateTicker := time.NewTicker(
s.cfg.CustomPollInterval.UnwrapOr(pollInterval),
)
@@ -408,13 +401,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
}
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
return nil
})
// If this is a broadcast transaction event, then we'll broadcast with
// the label attached.
@@ -445,9 +436,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register spend: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
for {
select {
case spend, ok := <-spendEvent.Spend:
@@ -466,13 +455,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
return nil
})
// The state machine has requested a new event to be sent once a
// specified txid+pkScript pair has confirmed.
@@ -489,9 +476,7 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return fmt.Errorf("unable to register conf: %w", err)
}
s.wg.Add(1)
go func() {
defer s.wg.Done()
return s.wg.Go(func(ctx context.Context) {
for {
select {
case <-confEvent.Confirmed:
@@ -508,11 +493,11 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent( //nolint:funlen
return
case <-s.quit:
case <-ctx.Done():
return
}
}
}()
})
}
return fmt.Errorf("unknown daemon event: %T", event)
@@ -632,8 +617,6 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env],
// incoming events, and then drives the state machine forward until it reaches
// a terminal state.
func (s *StateMachine[Event, Env]) driveMachine() {
defer s.wg.Done()
log.Debugf("FSM(%v): starting state machine", s.cfg.Env.Name())
currentState := s.cfg.InitialState
@@ -676,16 +659,11 @@ func (s *StateMachine[Event, Env]) driveMachine() {
// An outside caller is querying our state, so we'll return the
// latest state.
case stateQuery := <-s.stateQuery:
if !fn.SendOrQuit(
stateQuery.CurrentState, currentState, s.quit,
) {
if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { //nolint:lll
return
}
case <-s.quit:
// TODO(roasbeef): logs, etc
// * something in env?
case <-s.wg.Done():
return
}
}