diff --git a/protofsm/daemon_events.go b/protofsm/daemon_events.go index 5a269c7f1..b65adf012 100644 --- a/protofsm/daemon_events.go +++ b/protofsm/daemon_events.go @@ -8,7 +8,7 @@ import ( "github.com/lightningnetwork/lnd/lnwire" ) -// DaemonEvent is a special event that can be emmitted by a state transition +// DaemonEvent is a special event that can be emitted by a state transition // function. A state machine can use this to perform side effects, such as // sending a message to a peer, or broadcasting a transaction. type DaemonEvent interface { diff --git a/protofsm/msg_mapper.go b/protofsm/msg_mapper.go new file mode 100644 index 000000000..b96d677e6 --- /dev/null +++ b/protofsm/msg_mapper.go @@ -0,0 +1,15 @@ +package protofsm + +import ( + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/lnwire" +) + +// MsgMapper is used to map incoming wire messages into a FSM event. This is +// useful to decouple the translation of an outside or wire message into an +// event type that can be understood by the FSM. +type MsgMapper[Event any] interface { + // MapMsg maps a wire message into a FSM event. If the message is not + // mappable, then an None is returned. + MapMsg(msg lnwire.Message) fn.Option[Event] +} diff --git a/protofsm/state_machine.go b/protofsm/state_machine.go index 6968e36e9..0dfa8e73e 100644 --- a/protofsm/state_machine.go +++ b/protofsm/state_machine.go @@ -64,7 +64,8 @@ type State[Event any, Env Environment] interface { // emitted. ProcessEvent(event Event, env Env) (*StateTransition[Event, Env], error) - // IsTerminal returns true if this state is terminal, and false otherwise. + // IsTerminal returns true if this state is terminal, and false + // otherwise. IsTerminal() bool // TODO(roasbeef): also add state serialization? @@ -159,13 +160,17 @@ type StateMachineCfg[Event any, Env Environment] struct { // can be used to set up tracking state such as a txid confirmation // event. InitEvent fn.Option[DaemonEvent] + + // MsgMapper is an optional message mapper that can be used to map + // normal wire messages into FSM events. + MsgMapper fn.Option[MsgMapper[Event]] } // NewStateMachine creates a new state machine given a set of daemon adapters, // an initial state, an environment, and an event to process as if emitted at // the onset of the state machine. Such an event can be used to set up tracking // state such as a txid confirmation event. -func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], +func NewStateMachine[Event any, Env Environment](cfg StateMachineCfg[Event, Env], //nolint:lll ) StateMachine[Event, Env] { return StateMachine[Event, Env]{ @@ -206,6 +211,43 @@ func (s *StateMachine[Event, Env]) SendEvent(event Event) { } } +// CanHandle returns true if the target message can be routed to the state +// machine. +func (s *StateMachine[Event, Env]) CanHandle(msg lnwire.Message) bool { + cfgMapper := s.cfg.MsgMapper + return fn.MapOptionZ(cfgMapper, func(mapper MsgMapper[Event]) bool { + return mapper.MapMsg(msg).IsSome() + }) +} + +// SendMessage attempts to send a wire message to the state machine. If the +// message can be mapped using the default message mapper, then true is +// returned indicating that the message was processed. Otherwise, false is +// returned. +func (s *StateMachine[Event, Env]) SendMessage(msg lnwire.Message) bool { + // If we have no message mapper, then return false as we can't process + // this message. + if !s.cfg.MsgMapper.IsSome() { + return false + } + + // Otherwise, try to map the message using the default message mapper. + // If we can't extract an event, then we'll return false to indicate + // that the message wasn't processed. + var processed bool + s.cfg.MsgMapper.WhenSome(func(mapper MsgMapper[Event]) { + event := mapper.MapMsg(msg) + + event.WhenSome(func(event Event) { + s.SendEvent(event) + + processed = true + }) + }) + + return processed +} + // CurrentState returns the current state of the state machine. func (s *StateMachine[Event, Env]) CurrentState() (State[Event, Env], error) { query := stateQuery[Event, Env]{ @@ -225,7 +267,9 @@ type StateSubscriber[E any, F Environment] *fn.EventReceiver[State[E, F]] // RegisterStateEvents registers a new event listener that will be notified of // new state transitions. -func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[Event, Env] { +func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[ + Event, Env] { + subscriber := fn.NewEventReceiver[State[Event, Env]](10) // TODO(roasbeef): instead give the state and the input event? @@ -237,8 +281,10 @@ func (s *StateMachine[Event, Env]) RegisterStateEvents() StateSubscriber[Event, // RemoveStateSub removes the target state subscriber from the set of active // subscribers. -func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[Event, Env]) { - s.newStateEvents.RemoveSubscriber(sub) +func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[ + Event, Env]) { + + _ = s.newStateEvents.RemoveSubscriber(sub) } // executeDaemonEvent executes a daemon event, which is a special type of event @@ -246,7 +292,6 @@ func (s *StateMachine[Event, Env]) RemoveStateSub(sub StateSubscriber[Event, Env // machine. An error is returned if the type of event is unknown. func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error { switch daemonEvent := event.(type) { - // This is a send message event, so we'll send the event, and also mind // any preconditions as well as post-send events. case *SendMsgEvent[Event]: @@ -255,7 +300,8 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error { daemonEvent.TargetPeer, daemonEvent.Msgs, ) if err != nil { - return fmt.Errorf("unable to send msgs: %w", err) + return fmt.Errorf("unable to send msgs: %w", + err) } // If a post-send event was specified, then we'll @@ -300,7 +346,12 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error { ) if canSend { - sendAndCleanUp() + err := sendAndCleanUp() + if err != nil { + //nolint:lll + log.Errorf("FSM(%v): unable to send message: %v", err) + } + return } @@ -319,8 +370,6 @@ func (s *StateMachine[Event, Env]) executeDaemonEvent(event DaemonEvent) error { daemonEvent.Tx, daemonEvent.Label, ) if err != nil { - // TODO(roasbeef): hook has channel read event event is - // hit? return fmt.Errorf("unable to broadcast txn: %w", err) } @@ -414,6 +463,8 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], // any new emitted internal events to our event queue. This continues // until we reach a terminal state, or we run out of internal events to // process. + // + //nolint:lll for nextEvent := eventQueue.Dequeue(); nextEvent.IsSome(); nextEvent = eventQueue.Dequeue() { err := fn.MapOptionZ(nextEvent, func(event Event) error { // Apply the state transition function of the current @@ -426,13 +477,17 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], } newEvents := transition.NewEvents - err = fn.MapOptionZ(newEvents, func(events EmittedEvent[Event]) error { + err = fn.MapOptionZ(newEvents, func(events EmittedEvent[Event]) error { //nolint:lll // With the event processed, we'll process any // new daemon events that were emitted as part // of this new state transition. + // + //nolint:lll err := fn.MapOptionZ(events.ExternalEvents, func(dEvents DaemonEventSet) error { for _, dEvent := range dEvents { - err := s.executeDaemonEvent(dEvent) + err := s.executeDaemonEvent( + dEvent, + ) if err != nil { return err } @@ -446,6 +501,8 @@ func (s *StateMachine[Event, Env]) applyEvents(currentState State[Event, Env], // Next, we'll add any new emitted events to // our event queue. + // + //nolint:lll events.InternalEvent.WhenSome(func(inEvent Event) { eventQueue.Enqueue(inEvent) }) @@ -516,7 +573,10 @@ func (s *StateMachine[Event, Env]) driveMachine() { // An outside caller is querying our state, so we'll return the // latest state. case stateQuery := <-s.stateQuery: - if !fn.SendOrQuit(stateQuery.CurrentState, currentState, s.quit) { + if !fn.SendOrQuit( + stateQuery.CurrentState, currentState, s.quit, + ) { + return } diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index 82d4431f2..bf7026f4e 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -174,13 +174,17 @@ func newDaemonAdapters() *dummyAdapters { } } -func (d *dummyAdapters) SendMessages(pub btcec.PublicKey, msgs []lnwire.Message) error { +func (d *dummyAdapters) SendMessages(pub btcec.PublicKey, + msgs []lnwire.Message) error { + args := d.Called(pub, msgs) return args.Error(0) } -func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx, label string) error { +func (d *dummyAdapters) BroadcastTransaction(tx *wire.MsgTx, + label string) error { + args := d.Called(tx, label) return args.Error(0) @@ -194,6 +198,7 @@ func (d *dummyAdapters) RegisterConfirmationsNtfn(txid *chainhash.Hash, args := d.Called(txid, pkScript, numConfs) err := args.Error(0) + return &chainntnfs.ConfirmationEvent{ Confirmed: d.confChan, }, err @@ -342,7 +347,9 @@ func TestStateMachineDaemonEvents(t *testing.T) { // As soon as we send in the daemon event, we expect the // disable+broadcast events to be processed, as they are unconditional. - adapters.On("BroadcastTransaction", mock.Anything, mock.Anything).Return(nil) + adapters.On( + "BroadcastTransaction", mock.Anything, mock.Anything, + ).Return(nil) adapters.On("SendMessages", *pub2, mock.Anything).Return(nil) // We'll start off by sending in the daemon event, which'll trigger the @@ -374,3 +381,70 @@ func TestStateMachineDaemonEvents(t *testing.T) { adapters.AssertExpectations(t) env.AssertExpectations(t) } + +type dummyMsgMapper struct { + mock.Mock +} + +func (d *dummyMsgMapper) MapMsg(wireMsg lnwire.Message) fn.Option[dummyEvents] { + args := d.Called(wireMsg) + + //nolint:forcetypeassert + return args.Get(0).(fn.Option[dummyEvents]) +} + +// TestStateMachineMsgMapper tests that given a message mapper, we can properly +// send in wire messages get mapped to FSM events. +func TestStateMachineMsgMapper(t *testing.T) { + // First, we'll create our state machine given the env, and our + // starting state. + env := &dummyEnv{} + startingState := &dummyStateStart{} + adapters := newDaemonAdapters() + + // We'll also provide a message mapper that only knows how to map a + // single wire message (error). + dummyMapper := &dummyMsgMapper{} + + // The only thing we know how to map is the error message, which'll + // terminate the state machine. + wireError := &lnwire.Error{} + initMsg := &lnwire.Init{} + dummyMapper.On("MapMsg", wireError).Return( + fn.Some(dummyEvents(&goToFin{})), + ) + dummyMapper.On("MapMsg", initMsg).Return(fn.None[dummyEvents]()) + + cfg := StateMachineCfg[dummyEvents, *dummyEnv]{ + Daemon: adapters, + InitialState: startingState, + Env: env, + MsgMapper: fn.Some[MsgMapper[dummyEvents]](dummyMapper), + } + stateMachine := NewStateMachine(cfg) + stateMachine.Start() + defer stateMachine.Stop() + + // As we're triggering internal events, we'll also subscribe to the set + // of new states so we can assert as we go. + stateSub := stateMachine.RegisterStateEvents() + defer stateMachine.RemoveStateSub(stateSub) + + // First, we'll verify that the CanHandle method works as expected. + require.True(t, stateMachine.CanHandle(wireError)) + require.False(t, stateMachine.CanHandle(&lnwire.Init{})) + + // Next, we'll attempt to send the wire message into the state machine. + // We should transition to the final state. + require.True(t, stateMachine.SendMessage(wireError)) + + // We should transition to the final state. + expectedStates := []State[dummyEvents, *dummyEnv]{ + &dummyStateStart{}, &dummyStateFin{}, + } + assertStateTransitions(t, stateSub, expectedStates) + + dummyMapper.AssertExpectations(t) + adapters.AssertExpectations(t) + env.AssertExpectations(t) +}