diff --git a/.golangci.yml b/.golangci.yml index e00edaa19..ca3e68818 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -45,6 +45,16 @@ linters: # trigger funlen problems that we may not want to solve at that time. - funlen + # Disable for now as we haven't tuned the sensitivity to our codebase + # yet. Enabling by default, for example, would also force new contributors to + # potentially extensively refactor code, when they want a smaller change to + # land. + - gocyclo + + # Instances of table driven tests that don't pre-allocate shouldn't trigger + # the linter. + - prealloc + issues: # Only show newly introduced problems. new-from-rev: 01f696afce2f9c0d4ed854edefa3846891d01d8a diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 42593537d..1789e775e 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -415,7 +415,19 @@ func (c *ChannelArbitrator) Start() error { if startingState == StateWaitingFullResolution && nextState == StateWaitingFullResolution { - if err := c.relaunchResolvers(); err != nil { + // In order to relaunch the resolvers, we'll need to fetch the + // set of HTLCs that were present in the commitment transaction + // at the time it was confirmed. commitSet.ConfCommitKey can't + // be nil at this point since we're in + // StateWaitingFullResolution. We can only be in + // StateWaitingFullResolution after we've transitioned from + // StateContractClosed which can only be triggered by the local + // or remote close trigger. This trigger is only fired when we + // receive a chain event from the chain watcher that the + // commitment has been confirmed on chain, and before we + // advance our state step, we call InsertConfirmedCommitSet. 
+ confCommitSet := commitSet.HtlcSets[*commitSet.ConfCommitKey] + if err := c.relaunchResolvers(confCommitSet); err != nil { c.cfg.BlockEpochs.Cancel() return err } @@ -431,7 +443,7 @@ func (c *ChannelArbitrator) Start() error { // starting the ChannelArbitrator. This information should ideally be stored in // the database, so this only serves as a intermediate work-around to prevent a // migration. -func (c *ChannelArbitrator) relaunchResolvers() error { +func (c *ChannelArbitrator) relaunchResolvers(confirmedHTLCs []channeldb.HTLC) error { // We'll now query our log to see if there are any active // unresolved contracts. If this is the case, then we'll // relaunch all contract resolvers. @@ -456,31 +468,22 @@ func (c *ChannelArbitrator) relaunchResolvers() error { // to prevent a db migration. We use all available htlc sets here in // order to ensure we have complete coverage. htlcMap := make(map[wire.OutPoint]*channeldb.HTLC) - for _, htlcs := range c.activeHTLCs { - for _, htlc := range htlcs.incomingHTLCs { - htlc := htlc - outpoint := wire.OutPoint{ - Hash: commitHash, - Index: uint32(htlc.OutputIndex), - } - htlcMap[outpoint] = &htlc - } - - for _, htlc := range htlcs.outgoingHTLCs { - htlc := htlc - outpoint := wire.OutPoint{ - Hash: commitHash, - Index: uint32(htlc.OutputIndex), - } - htlcMap[outpoint] = &htlc + for _, htlc := range confirmedHTLCs { + htlc := htlc + outpoint := wire.OutPoint{ + Hash: commitHash, + Index: uint32(htlc.OutputIndex), } + htlcMap[outpoint] = &htlc } log.Infof("ChannelArbitrator(%v): relaunching %v contract "+ "resolvers", c.cfg.ChanPoint, len(unresolvedContracts)) for _, resolver := range unresolvedContracts { - c.supplementResolver(resolver, htlcMap) + if err := c.supplementResolver(resolver, htlcMap); err != nil { + return err + } } c.launchResolvers(unresolvedContracts) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index f98fe1161..2bb8acc1b 100644 --- 
a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -3,12 +3,16 @@ package contractcourt import ( "errors" "fmt" + "io/ioutil" + "os" + "path/filepath" "sync" "testing" "time" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/input" @@ -127,6 +131,25 @@ func (b *mockArbitratorLog) WipeHistory() error { return nil } +// testArbLog is a wrapper around an existing (ideally fully concrete +// ArbitratorLog) that lets us intercept certain calls like transitioning to a +// new state. +type testArbLog struct { + ArbitratorLog + + newStates chan ArbitratorState +} + +func (t *testArbLog) CommitState(s ArbitratorState) error { + if err := t.ArbitratorLog.CommitState(s); err != nil { + return err + } + + t.newStates <- s + + return nil +} + type mockChainIO struct{} var _ lnwallet.BlockChainIO = (*mockChainIO)(nil) @@ -148,9 +171,101 @@ func (*mockChainIO) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) return nil, nil } -func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, - chan struct{}, chan []ResolutionMsg, chan *chainntnfs.BlockEpoch, error) { +type chanArbTestCtx struct { + t *testing.T + chanArb *ChannelArbitrator + + cleanUp func() + + resolvedChan chan struct{} + + blockEpochs chan *chainntnfs.BlockEpoch + + incubationRequests chan struct{} + + resolutions chan []ResolutionMsg + + log ArbitratorLog +} + +func (c *chanArbTestCtx) CleanUp() { + if err := c.chanArb.Stop(); err != nil { + c.t.Fatalf("unable to stop chan arb: %v", err) + } + + if c.cleanUp != nil { + c.cleanUp() + } +} + +// AssertStateTransitions asserts that the state machine steps through the +// passed states in order. 
+func (c *chanArbTestCtx) AssertStateTransitions(expectedStates ...ArbitratorState) { + c.t.Helper() + + var newStatesChan chan ArbitratorState + switch log := c.log.(type) { + case *mockArbitratorLog: + newStatesChan = log.newStates + + case *testArbLog: + newStatesChan = log.newStates + + default: + c.t.Fatalf("unable to assert state transitions with %T", log) + } + + for _, exp := range expectedStates { + var state ArbitratorState + select { + case state = <-newStatesChan: + case <-time.After(5 * time.Second): + c.t.Fatalf("new state not received") + } + + if state != exp { + c.t.Fatalf("expected new state %v, got %v", exp, state) + } + } +} + +// AssertState checks that the ChannelArbitrator is in the state we expect it +// to be. +func (c *chanArbTestCtx) AssertState(expected ArbitratorState) { + if c.chanArb.state != expected { + c.t.Fatalf("expected state %v, was %v", expected, c.chanArb.state) + } +} + +// Restart simulates a clean restart of the channel arbitrator, forcing it to +// walk through its recovery logic. If this function returns nil, then a +// restart was successful. Note that the restart process keeps the log in +// place, in order to simulate proper persistence of the log. The caller can +// optionally provide a restart closure which will be executed before the +// resolver is started again, but after it is created. 
+func (c *chanArbTestCtx) Restart(restartClosure func(*chanArbTestCtx)) (*chanArbTestCtx, error) { + if err := c.chanArb.Stop(); err != nil { + return nil, err + } + + newCtx, err := createTestChannelArbitrator(c.t, c.log) + if err != nil { + return nil, err + } + + if restartClosure != nil { + restartClosure(newCtx) + } + + if err := newCtx.chanArb.Start(); err != nil { + return nil, err + } + + return newCtx, nil +} + +func createTestChannelArbitrator(t *testing.T, log ArbitratorLog) (*chanArbTestCtx, error) { blockEpochs := make(chan *chainntnfs.BlockEpoch) blockEpoch := &chainntnfs.BlockEpochEvent{ Epochs: blockEpochs, @@ -167,6 +282,7 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, } resolutionChan := make(chan []ResolutionMsg, 1) + incubateChan := make(chan struct{}) chainIO := &mockChainIO{} chainArbCfg := ChainArbitratorConfig{ @@ -188,6 +304,8 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, IncubateOutputs: func(wire.OutPoint, *lnwallet.CommitOutputResolution, *lnwallet.OutgoingHtlcResolution, *lnwallet.IncomingHtlcResolution, uint32) error { + + incubateChan <- struct{}{} return nil }, } @@ -224,17 +342,49 @@ func createTestChannelArbitrator(log ArbitratorLog) (*ChannelArbitrator, ChainEvents: chanEvents, } - htlcSets := make(map[HtlcSetKey]htlcSet) - return NewChannelArbitrator(arbCfg, htlcSets, log), resolvedChan, - resolutionChan, blockEpochs, nil -} + var cleanUp func() + if log == nil { + dbDir, err := ioutil.TempDir("", "chanArb") + if err != nil { + return nil, err + } + dbPath := filepath.Join(dbDir, "testdb") + db, err := bbolt.Open(dbPath, 0600, nil) + if err != nil { + return nil, err + } -// assertState checks that the ChannelArbitrator is in the state we expect it -// to be. 
-func assertState(t *testing.T, c *ChannelArbitrator, expected ArbitratorState) { - if c.state != expected { - t.Fatalf("expected state %v, was %v", expected, c.state) + backingLog, err := newBoltArbitratorLog( + db, arbCfg, chainhash.Hash{}, chanPoint, + ) + if err != nil { + return nil, err + } + cleanUp = func() { + db.Close() + os.RemoveAll(dbDir) + } + + log = &testArbLog{ + ArbitratorLog: backingLog, + newStates: make(chan ArbitratorState), + } } + + htlcSets := make(map[HtlcSetKey]htlcSet) + + chanArb := NewChannelArbitrator(arbCfg, htlcSets, log) + + return &chanArbTestCtx{ + t: t, + chanArb: chanArb, + cleanUp: cleanUp, + resolvedChan: resolvedChan, + resolutions: resolutionChan, + blockEpochs: blockEpochs, + log: log, + incubationRequests: incubateChan, + }, nil } // TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor @@ -246,22 +396,26 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - if err := chanArb.Start(); err != nil { + if err := chanArbCtx.chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } - defer chanArb.Stop() + defer func() { + if err := chanArbCtx.chanArb.Stop(); err != nil { + t.Fatalf("unable to stop chan arb: %v", err) + } + }() // It should start out in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // We set up a channel to detect when MarkChannelClosed is called. 
closeInfos := make(chan *channeldb.ChannelCloseSummary) - chanArb.cfg.MarkChannelClosed = func( + chanArbCtx.chanArb.cfg.MarkChannelClosed = func( closeInfo *channeldb.ChannelCloseSummary) error { closeInfos <- closeInfo return nil @@ -272,7 +426,7 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { closeInfo := &CooperativeCloseInfo{ &channeldb.ChannelCloseSummary{}, } - chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo + chanArbCtx.chanArb.cfg.ChainEvents.CooperativeClosure <- closeInfo select { case c := <-closeInfos: @@ -285,31 +439,13 @@ func TestChannelArbitratorCooperativeClose(t *testing.T) { // It should mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") } } -func assertStateTransitions(t *testing.T, newStates <-chan ArbitratorState, - expectedStates ...ArbitratorState) { - t.Helper() - - for _, exp := range expectedStates { - var state ArbitratorState - select { - case state = <-newStates: - case <-time.After(5 * time.Second): - t.Fatalf("new state not received") - } - - if state != exp { - t.Fatalf("expected new state %v, got %v", exp, state) - } - } -} - // TestChannelArbitratorRemoteForceClose checks that the ChannelArbitrator goes // through the expected states if a remote force close is observed in the // chain. @@ -319,10 +455,11 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -330,7 +467,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { defer chanArb.Stop() // It should start out in the default state. 
- assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Send a remote force close event. commitSpend := &chainntnfs.SpendDetail{ @@ -351,13 +488,13 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { // It should transition StateDefault -> StateContractClosed -> // StateFullyResolved. - assertStateTransitions( - t, log.newStates, StateContractClosed, StateFullyResolved, + chanArbCtx.AssertStateTransitions( + StateContractClosed, StateFullyResolved, ) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -373,10 +510,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -384,7 +522,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { defer chanArb.Stop() // It should start out in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // We create a channel we can use to pause the ChannelArbitrator at the // point where it broadcasts the close tx, and check its state. @@ -411,7 +549,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { } // It should transition to StateBroadcastCommit. - assertStateTransitions(t, log.newStates, StateBroadcastCommit) + chanArbCtx.AssertStateTransitions(StateBroadcastCommit) // When it is broadcasting the force close, its state should be // StateBroadcastCommit. @@ -426,7 +564,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { // After broadcasting, transition should be to // StateCommitmentBroadcasted. 
- assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted) select { case <-respChan: @@ -445,7 +583,7 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { // After broadcasting the close tx, it should be in state // StateCommitmentBroadcasted. - assertState(t, chanArb, StateCommitmentBroadcasted) + chanArbCtx.AssertState(StateCommitmentBroadcasted) // Now notify about the local force close getting confirmed. chanArb.cfg.ChainEvents.LocalUnilateralClosure <- &LocalUnilateralCloseInfo{ @@ -458,12 +596,11 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { } // It should transition StateContractClosed -> StateFullyResolved. - assertStateTransitions(t, log.newStates, StateContractClosed, - StateFullyResolved) + chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -479,10 +616,11 @@ func TestChannelArbitratorBreachClose(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -494,19 +632,19 @@ func TestChannelArbitratorBreachClose(t *testing.T) { }() // It should start out in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Send a breach close event. chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{} // It should transition StateDefault -> StateFullyResolved. 
- assertStateTransitions( - t, log.newStates, StateFullyResolved, + chanArbCtx.AssertStateTransitions( + StateFullyResolved, ) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -517,29 +655,15 @@ func TestChannelArbitratorBreachClose(t *testing.T) { // ChannelArbitrator goes through the expected states in case we request it to // force close a channel that still has an HTLC pending. func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { - arbLog := &mockArbitratorLog{ - state: StateDefault, - newStates: make(chan ArbitratorState, 5), - resolvers: make(map[ContractResolver]struct{}), - } - - chanArb, resolved, resolutions, _, err := createTestChannelArbitrator( - arbLog, - ) + // We create a new test context for this channel arb, notice that we + // pass in a nil ArbitratorLog which means that a default one backed by + // a real DB will be created. We need this for our test as we want to + // test proper restart recovery and resolver population. + chanArbCtx, err := createTestChannelArbitrator(t, nil) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - - incubateChan := make(chan struct{}) - chanArb.cfg.IncubateOutputs = func(_ wire.OutPoint, - _ *lnwallet.CommitOutputResolution, - _ *lnwallet.OutgoingHtlcResolution, - _ *lnwallet.IncomingHtlcResolution, _ uint32) error { - - incubateChan <- struct{}{} - - return nil - } + chanArb := chanArbCtx.chanArb chanArb.cfg.PreimageDB = newMockWitnessBeacon() chanArb.cfg.Registry = &mockRegistry{} @@ -558,9 +682,10 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { chanArb.UpdateContractSignals(signals) // Add HTLC to channel arbitrator. 
+ htlcAmt := 10000 htlc := channeldb.HTLC{ Incoming: false, - Amt: 10000, + Amt: lnwire.MilliSatoshi(htlcAmt), HtlcIndex: 99, } @@ -599,8 +724,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // The force close request should trigger broadcast of the commitment // transaction. - assertStateTransitions( - t, arbLog.newStates, StateBroadcastCommit, + chanArbCtx.AssertStateTransitions( + StateBroadcastCommit, StateCommitmentBroadcasted, ) select { @@ -636,8 +761,8 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { Index: 0, } - // Set up the outgoing resolution. Populate SignedTimeoutTx because - // our commitment transaction got confirmed. + // Set up the outgoing resolution. Populate SignedTimeoutTx because our + // commitment transaction got confirmed. outgoingRes := lnwallet.OutgoingHtlcResolution{ Expiry: 10, SweepSignDesc: input.SignDescriptor{ @@ -675,15 +800,15 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { }, } - assertStateTransitions( - t, arbLog.newStates, StateContractClosed, + chanArbCtx.AssertStateTransitions( + StateContractClosed, StateWaitingFullResolution, ) // We expect an immediate resolution message for the outgoing dust htlc. // It is not resolvable on-chain. select { - case msgs := <-resolutions: + case msgs := <-chanArbCtx.resolutions: if len(msgs) != 1 { t.Fatalf("expected 1 message, instead got %v", len(msgs)) } @@ -696,34 +821,76 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { t.Fatalf("resolution msgs not sent") } + // We'll grab the old notifier here as our resolvers are still holding + // a reference to this instance, and a new one will be created when we + // restart the channel arb below. + oldNotifier := chanArb.cfg.Notifier.(*mockNotifier) + + // At this point, in order to simulate a restart, we'll re-create the + // channel arbitrator. We do this to ensure that all information + // required to properly resolve this HTLC are populated. 
+ if err := chanArb.Stop(); err != nil { + t.Fatalf("unable to stop chan arb: %v", err) + } + + // We'll now re-create the resolver, notice that we use the existing + // arbLog so it carries over the same on-disk state. + chanArbCtxNew, err := chanArbCtx.Restart(nil) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + chanArb = chanArbCtxNew.chanArb + defer chanArbCtxNew.CleanUp() + + // Post restart, it should be the case that our resolver was properly + // supplemented, and we only have a single resolver in the final set. + if len(chanArb.activeResolvers) != 1 { + t.Fatalf("expected single resolver, instead got: %v", + len(chanArb.activeResolvers)) + } + + // We'll now examine the in-memory state of the active resolvers to + // ensure they were populated properly. + resolver := chanArb.activeResolvers[0] + outgoingResolver, ok := resolver.(*htlcOutgoingContestResolver) + if !ok { + t.Fatalf("expected outgoing contest resolver, got %vT", + resolver) + } + + // The resolver should have its htlcAmt field populated as expected. + if int64(outgoingResolver.htlcAmt) != int64(htlcAmt) { + t.Fatalf("wrong htlc amount: expected %v, got %v,", + htlcAmt, int64(outgoingResolver.htlcAmt)) + } + + // htlcOutgoingContestResolver is now active and waiting for the HTLC to // expire. It should not yet have passed it on for incubation. select { - case <-incubateChan: + case <-chanArbCtx.incubationRequests: t.Fatalf("contract should not be incubated yet") default: } // Send a notification that the expiry height has been reached. - notifier := chanArb.cfg.Notifier.(*mockNotifier) - notifier.epochChan <- &chainntnfs.BlockEpoch{Height: 10} + oldNotifier.epochChan <- &chainntnfs.BlockEpoch{Height: 10} // htlcOutgoingContestResolver is now transforming into a // htlcTimeoutResolver and should send the contract off for incubation. 
select { - case <-incubateChan: + case <-chanArbCtx.incubationRequests: case <-time.After(5 * time.Second): t.Fatalf("no response received") } // Notify resolver that the HTLC output of the commitment has been // spent. - notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + oldNotifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} // Finally, we should also receive a resolution message instructing the // switch to cancel back the HTLC. select { - case msgs := <-resolutions: + case msgs := <-chanArbCtx.resolutions: if len(msgs) != 1 { t.Fatalf("expected 1 message, instead got %v", len(msgs)) } @@ -740,18 +907,18 @@ func TestChannelArbitratorLocalForceClosePendingHtlc(t *testing.T) { // to the second level. Channel arbitrator should still not be marked // as resolved. select { - case <-resolved: + case <-chanArbCtxNew.resolvedChan: t.Fatalf("channel resolved prematurely") default: } // Notify resolver that the second level transaction is spent. - notifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} + oldNotifier.spendChan <- &chainntnfs.SpendDetail{SpendingTx: closeTx} // At this point channel should be marked as resolved. 
- assertStateTransitions(t, arbLog.newStates, StateFullyResolved) + chanArbCtxNew.AssertStateTransitions(StateFullyResolved) select { - case <-resolved: + case <-chanArbCtxNew.resolvedChan: case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") } @@ -766,10 +933,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -777,7 +945,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { defer chanArb.Stop() // It should start out in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Create a channel we can use to assert the state when it publishes // the close tx. @@ -804,7 +972,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } // It should transition to StateBroadcastCommit. - assertStateTransitions(t, log.newStates, StateBroadcastCommit) + chanArbCtx.AssertStateTransitions(StateBroadcastCommit) // We expect it to be in state StateBroadcastCommit when publishing // the force close. @@ -819,7 +987,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { // After broadcasting, transition should be to // StateCommitmentBroadcasted. - assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted) // Wait for a response to the force close. select { @@ -838,7 +1006,7 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } // The state should be StateCommitmentBroadcasted. 
- assertState(t, chanArb, StateCommitmentBroadcasted) + chanArbCtx.AssertState(StateCommitmentBroadcasted) // Now notify about the _REMOTE_ commitment getting confirmed. commitSpend := &chainntnfs.SpendDetail{ @@ -853,12 +1021,11 @@ func TestChannelArbitratorLocalForceCloseRemoteConfirmed(t *testing.T) { } // It should transition StateContractClosed -> StateFullyResolved. - assertStateTransitions(t, log.newStates, StateContractClosed, - StateFullyResolved) + chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved) // It should resolve. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(15 * time.Second): t.Fatalf("contract was not resolved") @@ -875,10 +1042,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -886,7 +1054,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { defer chanArb.Stop() // It should start out in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Return ErrDoubleSpend when attempting to publish the tx. stateChan := make(chan ArbitratorState) @@ -912,7 +1080,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } // It should transition to StateBroadcastCommit. - assertStateTransitions(t, log.newStates, StateBroadcastCommit) + chanArbCtx.AssertStateTransitions(StateBroadcastCommit) // We expect it to be in state StateBroadcastCommit when publishing // the force close. 
@@ -927,7 +1095,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { // After broadcasting, transition should be to // StateCommitmentBroadcasted. - assertStateTransitions(t, log.newStates, StateCommitmentBroadcasted) + chanArbCtx.AssertStateTransitions(StateCommitmentBroadcasted) // Wait for a response to the force close. select { @@ -946,7 +1114,7 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } // The state should be StateCommitmentBroadcasted. - assertState(t, chanArb, StateCommitmentBroadcasted) + chanArbCtx.AssertState(StateCommitmentBroadcasted) // Now notify about the _REMOTE_ commitment getting confirmed. commitSpend := &chainntnfs.SpendDetail{ @@ -961,12 +1129,11 @@ func TestChannelArbitratorLocalForceDoubleSpend(t *testing.T) { } // It should transition StateContractClosed -> StateFullyResolved. - assertStateTransitions(t, log.newStates, StateContractClosed, - StateFullyResolved) + chanArbCtx.AssertStateTransitions(StateContractClosed, StateFullyResolved) // It should resolve. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(15 * time.Second): t.Fatalf("contract was not resolved") @@ -983,17 +1150,18 @@ func TestChannelArbitratorPersistence(t *testing.T) { failLog: true, } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } // It should start in StateDefault. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Send a remote force close event. 
commitSpend := &chainntnfs.SpendDetail{ @@ -1014,20 +1182,17 @@ func TestChannelArbitratorPersistence(t *testing.T) { if log.state != StateDefault { t.Fatalf("expected to stay in StateDefault") } - chanArb.Stop() - // Create a new arbitrator with the same log. - chanArb, resolved, _, _, err = createTestChannelArbitrator(log) + // Restart the channel arb, this'll use the same log and prior + // context. + chanArbCtx, err = chanArbCtx.Restart(nil) if err != nil { - t.Fatalf("unable to create ChannelArbitrator: %v", err) - } - - if err := chanArb.Start(); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) + t.Fatalf("unable to restart channel arb: %v", err) } + chanArb = chanArbCtx.chanArb // Again, it should start up in the default state. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Now we make the log succeed writing the resolutions, but fail when // attempting to close the channel. @@ -1047,20 +1212,16 @@ func TestChannelArbitratorPersistence(t *testing.T) { if log.state != StateDefault { t.Fatalf("expected to stay in StateDefault") } - chanArb.Stop() - // Create yet another arbitrator with the same log. - chanArb, resolved, _, _, err = createTestChannelArbitrator(log) + // Restart once again to simulate yet another restart. + chanArbCtx, err = chanArbCtx.Restart(nil) if err != nil { - t.Fatalf("unable to create ChannelArbitrator: %v", err) - } - - if err := chanArb.Start(); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) + t.Fatalf("unable to restart channel arb: %v", err) } + chanArb = chanArbCtx.chanArb // Starts out in StateDefault. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // Now make fetching the resolutions fail. 
log.failFetch = fmt.Errorf("intentional fetch failure") @@ -1070,9 +1231,7 @@ func TestChannelArbitratorPersistence(t *testing.T) { // Since logging the resolutions and closing the channel now succeeds, // it should advance to StateContractClosed. - assertStateTransitions( - t, log.newStates, StateContractClosed, - ) + chanArbCtx.AssertStateTransitions(StateContractClosed) // It should not advance further, however, as fetching resolutions // failed. @@ -1084,24 +1243,18 @@ func TestChannelArbitratorPersistence(t *testing.T) { // Create a new arbitrator, and now make fetching resolutions succeed. log.failFetch = nil - chanArb, resolved, _, _, err = createTestChannelArbitrator(log) + chanArbCtx, err = chanArbCtx.Restart(nil) if err != nil { - t.Fatalf("unable to create ChannelArbitrator: %v", err) + t.Fatalf("unable to restart channel arb: %v", err) } - - if err := chanArb.Start(); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } - defer chanArb.Stop() + defer chanArbCtx.CleanUp() // Finally it should advance to StateFullyResolved. - assertStateTransitions( - t, log.newStates, StateFullyResolved, - ) + chanArbCtx.AssertStateTransitions(StateFullyResolved) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -1119,17 +1272,18 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { newStates: make(chan ArbitratorState, 5), } - chanArb, _, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } // It should start in StateDefault. 
- assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) // We start by attempting a local force close. We'll return an // unexpected publication error, causing the state machine to halt. @@ -1157,7 +1311,7 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { } // It should transition to StateBroadcastCommit. - assertStateTransitions(t, log.newStates, StateBroadcastCommit) + chanArbCtx.AssertStateTransitions(StateBroadcastCommit) // We expect it to be in state StateBroadcastCommit when attempting // the force close. @@ -1181,43 +1335,25 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { t.Fatalf("no response received") } - // Stop the channel abitrator. - if err := chanArb.Stop(); err != nil { - t.Fatal(err) - } - // We mimic that the channel is breached while the channel arbitrator // is down. This means that on restart it will be started with a // pending close channel, of type BreachClose. - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) { + c.chanArb.cfg.IsPendingClose = true + c.chanArb.cfg.ClosingHeight = 100 + c.chanArb.cfg.CloseType = channeldb.BreachClose + }) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - - chanArb.cfg.IsPendingClose = true - chanArb.cfg.ClosingHeight = 100 - chanArb.cfg.CloseType = channeldb.BreachClose - - // Start the channel abitrator again, and make sure it goes straight to - // state fully resolved, as in case of breach there is nothing to - // handle. - if err := chanArb.Start(); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } - defer func() { - if err := chanArb.Stop(); err != nil { - t.Fatal(err) - } - }() + defer chanArbCtx.CleanUp() // Finally it should advance to StateFullyResolved. 
- assertStateTransitions( - t, log.newStates, StateFullyResolved, - ) + chanArbCtx.AssertStateTransitions(StateFullyResolved) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -1286,6 +1422,8 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { } for _, test := range testCases { + test := test + log := &mockArbitratorLog{ state: StateDefault, newStates: make(chan ArbitratorState, 5), @@ -1296,17 +1434,18 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { failCommitState: test.expectedStates[0], } - chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } // It should start in StateDefault. - assertState(t, chanArb, StateDefault) + chanArbCtx.AssertState(StateDefault) closed := make(chan struct{}) chanArb.cfg.MarkChannelClosed = func( @@ -1336,30 +1475,23 @@ func TestChannelArbitratorCommitFailure(t *testing.T) { // Start the arbitrator again, with IsPendingClose reporting // the channel closed in the database. 
- chanArb, resolved, _, _, err = createTestChannelArbitrator(log) + log.failCommit = false + chanArbCtx, err = chanArbCtx.Restart(func(c *chanArbTestCtx) { + c.chanArb.cfg.IsPendingClose = true + c.chanArb.cfg.ClosingHeight = 100 + c.chanArb.cfg.CloseType = test.closeType + }) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } - log.failCommit = false - - chanArb.cfg.IsPendingClose = true - chanArb.cfg.ClosingHeight = 100 - chanArb.cfg.CloseType = test.closeType - - if err := chanArb.Start(); err != nil { - t.Fatalf("unable to start ChannelArbitrator: %v", err) - } - // Since the channel is marked closed in the database, it // should advance to the expected states. - assertStateTransitions( - t, log.newStates, test.expectedStates..., - ) + chanArbCtx.AssertStateTransitions(test.expectedStates...) // It should also mark the channel as resolved. select { - case <-resolved: + case <-chanArbCtx.resolvedChan: // Expected. case <-time.After(5 * time.Second): t.Fatalf("contract was not resolved") @@ -1382,11 +1514,12 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { failFetch: errNoResolutions, } - chanArb, _, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb chanArb.cfg.IsPendingClose = true chanArb.cfg.ClosingHeight = 100 chanArb.cfg.CloseType = channeldb.RemoteForceClose @@ -1397,9 +1530,7 @@ func TestChannelArbitratorEmptyResolutions(t *testing.T) { // It should not advance its state beyond StateContractClosed, since // fetching resolutions fails. - assertStateTransitions( - t, log.newStates, StateContractClosed, - ) + chanArbCtx.AssertStateTransitions(StateContractClosed) // It should not advance further, however, as fetching resolutions // failed. 
@@ -1420,10 +1551,11 @@ func TestChannelArbitratorAlreadyForceClosed(t *testing.T) { log := &mockArbitratorLog{ state: StateCommitmentBroadcasted, } - chanArb, _, _, _, err := createTestChannelArbitrator(log) + chanArbCtx, err := createTestChannelArbitrator(t, log) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1515,12 +1647,13 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { resolvers: make(map[ContractResolver]struct{}), } - chanArb, _, resolutions, blockEpochs, err := createTestChannelArbitrator( - arbLog, + chanArbCtx, err := createTestChannelArbitrator( + t, arbLog, ) if err != nil { t.Fatalf("unable to create ChannelArbitrator: %v", err) } + chanArb := chanArbCtx.chanArb if err := chanArb.Start(); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) } @@ -1568,7 +1701,7 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // now mine a block (height 5), which is 5 blocks away // (our grace delta) from the expiry of that HTLC. case testCase.htlcExpired: - blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} + chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 5} // Otherwise, we'll just trigger a regular force close // request. @@ -1584,8 +1717,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // determined that it needs to go to chain in order to // block off the redemption path so it can cancel the // incoming HTLC. - assertStateTransitions( - t, arbLog.newStates, StateBroadcastCommit, + chanArbCtx.AssertStateTransitions( + StateBroadcastCommit, StateCommitmentBroadcasted, ) @@ -1646,15 +1779,15 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // The channel arb should now transition to waiting // until the HTLCs have been fully resolved. 
- assertStateTransitions( - t, arbLog.newStates, StateContractClosed, + chanArbCtx.AssertStateTransitions( + StateContractClosed, StateWaitingFullResolution, ) // Now that we've sent this signal, we should have that // HTLC be cancelled back immediately. select { - case msgs := <-resolutions: + case msgs := <-chanArbCtx.resolutions: if len(msgs) != 1 { t.Fatalf("expected 1 message, "+ "instead got %v", len(msgs)) @@ -1672,10 +1805,8 @@ func TestChannelArbitratorDanglingCommitForceClose(t *testing.T) { // so instead, we'll mine another block which'll cause // it to re-examine its state and realize there're no // more HTLCs. - blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} - assertStateTransitions( - t, arbLog.newStates, StateFullyResolved, - ) + chanArbCtx.blockEpochs <- &chainntnfs.BlockEpoch{Height: 6} + chanArbCtx.AssertStateTransitions(StateFullyResolved) }) } }