diff --git a/protofsm/state_machine_test.go b/protofsm/state_machine_test.go index 42e7a80d4..3dcb35099 100644 --- a/protofsm/state_machine_test.go +++ b/protofsm/state_machine_test.go @@ -321,12 +321,20 @@ func (d *dummyStateSpent) IsTerminal() bool { return true } -func assertState[Event any, Env Environment](t *testing.T, - m *StateMachine[Event, Env], expectedState State[Event, Env]) { +// assertState asserts that the state machine is currently in the expected +// state type and returns the state cast to that type. +func assertState[Event any, Env Environment, S State[Event, Env]](t *testing.T, + m *StateMachine[Event, Env], expectedState S) S { state, err := m.CurrentState() require.NoError(t, err) require.IsType(t, expectedState, state) + + // Perform the type assertion to return the concrete type. + concreteState, ok := state.(S) + require.True(t, ok, "state type assertion failed") + + return concreteState } func assertStateTransitions[Event any, Env Environment]( @@ -626,18 +634,15 @@ func TestStateMachineConfMapper(t *testing.T) { assertStateTransitions(t, stateSub, expectedStates) // Final state assertion. - finalState, err := stateMachine.CurrentState() - require.NoError(t, err) - require.IsType(t, &dummyStateConfirmed{}, finalState) + finalState := assertState(t, &stateMachine, &dummyStateConfirmed{}) // Assert that the details from the confirmation event were correctly // propagated to the final state. - finalStateDetails := finalState.(*dummyStateConfirmed) require.Equal(t, - *simulatedConf.BlockHash, finalStateDetails.blockHash, + *simulatedConf.BlockHash, finalState.blockHash, ) require.Equal(t, - simulatedConf.BlockHeight, finalStateDetails.blockHeight, + simulatedConf.BlockHeight, finalState.blockHeight, ) adapters.AssertExpectations(t) @@ -706,18 +711,15 @@ func TestStateMachineSpendMapper(t *testing.T) { assertStateTransitions(t, stateSub, expectedStates) // Final state assertion. - finalState, err := stateMachine.CurrentState() - require.NoError(t, err) - require.IsType(t, &dummyStateSpent{}, finalState) + finalState := assertState(t, &stateMachine, &dummyStateSpent{}) // Assert that the details from the spend event were correctly // propagated to the final state. - finalStateDetails := finalState.(*dummyStateSpent) require.Equal(t, - *simulatedSpend.SpenderTxHash, finalStateDetails.spenderTxHash, + *simulatedSpend.SpenderTxHash, finalState.spenderTxHash, ) require.Equal(t, - simulatedSpend.SpendingHeight, finalStateDetails.spendingHeight, + simulatedSpend.SpendingHeight, finalState.spendingHeight, ) adapters.AssertExpectations(t)