diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index c434513d7..364bcd7e3 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -100,14 +100,12 @@ type ChainArbitratorConfig struct { MarkLinkInactive func(wire.OutPoint) error // ContractBreach is a function closure that the ChainArbitrator will - // use to notify the breachArbiter about a contract breach. A callback - // should be passed that when called will mark the channel pending - // close in the database. It should only return a non-nil error when the - // breachArbiter has preserved the necessary breach info for this - // channel point, and the callback has succeeded, meaning it is safe to - // stop watching the channel. - ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution, - func() error) error + // use to notify the breachArbiter about a contract breach. It should + // only return a non-nil error when the breachArbiter has preserved + // the necessary breach info for this channel point. Once the breach + // resolution is persisted in the channel arbitrator, it will be safe + // to mark the channel closed. + ContractBreach func(wire.OutPoint, *lnwallet.BreachRetribution) error // IsOurAddress is a function that returns true if the passed address // is known to the underlying wallet. Otherwise, false should be @@ -512,19 +510,17 @@ func (c *ChainArbitrator) Start() error { // First, we'll create an active chainWatcher for this channel // to ensure that we detect any relevant on chain events. + breachClosure := func(ret *lnwallet.BreachRetribution) error { + return c.cfg.ContractBreach(chanPoint, ret) + } + chainWatcher, err := newChainWatcher( chainWatcherConfig{ - chanState: channel, - notifier: c.cfg.Notifier, - signer: c.cfg.Signer, - isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution, - markClosed func() error) error { - - return c.cfg.ContractBreach( - chanPoint, retInfo, markClosed, - ) - }, + chanState: channel, + notifier: c.cfg.Notifier, + signer: c.cfg.Signer, + isOurAddr: c.cfg.IsOurAddress, + contractBreach: breachClosure, extractStateNumHint: lnwallet.GetStateNumHint, }, ) @@ -1122,11 +1118,11 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error notifier: c.cfg.Notifier, signer: c.cfg.Signer, isOurAddr: c.cfg.IsOurAddress, - contractBreach: func(retInfo *lnwallet.BreachRetribution, - markClosed func() error) error { + contractBreach: func( + retInfo *lnwallet.BreachRetribution) error { return c.cfg.ContractBreach( - chanPoint, retInfo, markClosed, + chanPoint, retInfo, ) }, extractStateNumHint: lnwallet.GetStateNumHint, diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index 2d8dcce89..7e813ee4a 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -9,6 +9,7 @@ import ( "github.com/btcsuite/btcd/btcec" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" @@ -62,6 +63,24 @@ type BreachResolution struct { FundingOutPoint wire.OutPoint } +// BreachCloseInfo wraps the BreachResolution with a CommitSet for the latest, +// non-breached state, with the AnchorResolution for the breached state. +type BreachCloseInfo struct { + *BreachResolution + *lnwallet.AnchorResolution + + // CommitHash is the hash of the commitment transaction. + CommitHash chainhash.Hash + + // CommitSet is the set of known valid commitments at the time the + // breach occurred on-chain. + CommitSet CommitSet + + // CloseSummary gives the recipient of the BreachCloseInfo information + // to mark the channel closed in the database. + CloseSummary channeldb.ChannelCloseSummary +} + // CommitSet is a collection of the set of known valid commitments at a given // instant. If ConfCommitKey is set, then the commitment identified by the // HtlcSetKey has hit the chain. This struct will be used to examine all live @@ -129,7 +148,7 @@ type ChainEventSubscription struct { // ContractBreach is a channel that will be sent upon if we detect a // contract breach. The struct sent across the channel contains all the // material required to bring the cheating channel peer to justice. - ContractBreach chan *lnwallet.BreachRetribution + ContractBreach chan *BreachCloseInfo // Cancel cancels the subscription to the event stream for a particular // channel. This method should be called once the caller no longer needs to @@ -155,13 +174,10 @@ type chainWatcherConfig struct { signer input.Signer // contractBreach is a method that will be called by the watcher if it - // detects that a contract breach transaction has been confirmed. A - // callback should be passed that when called will mark the channel - // pending close in the database. It will only return a non-nil error - // when the breachArbiter has preserved the necessary breach info for - // this channel point, and the callback has succeeded, meaning it is - // safe to stop watching the channel. - contractBreach func(*lnwallet.BreachRetribution, func() error) error + // detects that a contract breach transaction has been confirmed. It + // will only return a non-nil error when the breachArbiter has + // preserved the necessary breach info for this channel point. + contractBreach func(*lnwallet.BreachRetribution) error // isOurAddr is a function that returns true if the passed address is // known to us. @@ -316,7 +332,7 @@ func (c *chainWatcher) SubscribeChannelEvents() *ChainEventSubscription { RemoteUnilateralClosure: make(chan *RemoteUnilateralCloseInfo, 1), LocalUnilateralClosure: make(chan *LocalUnilateralCloseInfo, 1), CooperativeClosure: make(chan *CooperativeCloseInfo, 1), - ContractBreach: make(chan *lnwallet.BreachRetribution, 1), + ContractBreach: make(chan *BreachCloseInfo, 1), Cancel: func() { c.Lock() delete(c.clientSubscriptions, clientID) @@ -790,12 +806,27 @@ func (c *chainWatcher) handleKnownRemoteState( return false, nil } + // Create an AnchorResolution for the breached state. + anchorRes, err := lnwallet.NewAnchorResolution( + c.cfg.chanState, commitSpend.SpendingTx, + ) + if err != nil { + return false, fmt.Errorf("unable to create anchor "+ + "resolution: %v", err) + } + + // We'll set the ConfCommitKey here as the remote htlc set. This is + // only used to ensure a nil-pointer-dereference doesn't occur and is + // not used otherwise. The HTLC's may not exist for the + // RemotePendingHtlcSet. + chainSet.commitSet.ConfCommitKey = &RemoteHtlcSet + // THEY'RE ATTEMPTING TO VIOLATE THE CONTRACT LAID OUT WITHIN THE // PAYMENT CHANNEL. Therefore we close the signal indicating a revoked // broadcast to allow subscribers to swiftly dispatch justice!!! err = c.dispatchContractBreach( - commitSpend, &chainSet.remoteCommit, - broadcastStateNum, retribution, + commitSpend, chainSet, broadcastStateNum, retribution, + anchorRes, ) if err != nil { return false, fmt.Errorf("unable to handle channel "+ @@ -1088,8 +1119,9 @@ func (c *chainWatcher) dispatchRemoteForceClose( // materials required to bring the cheater to justice, then notify all // registered subscribers of this event. func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail, - remoteCommit *channeldb.ChannelCommitment, broadcastStateNum uint64, - retribution *lnwallet.BreachRetribution) error { + chainSet *chainSet, broadcastStateNum uint64, + retribution *lnwallet.BreachRetribution, + anchorRes *lnwallet.AnchorResolution) error { log.Warnf("Remote peer has breached the channel contract for "+ "ChannelPoint(%v). Revoked state #%v was broadcast!!!", @@ -1130,7 +1162,7 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail return spew.Sdump(retribution) })) - settledBalance := remoteCommit.LocalBalance.ToSatoshis() + settledBalance := chainSet.remoteCommit.LocalBalance.ToSatoshis() closeSummary := channeldb.ChannelCloseSummary{ ChanPoint: c.cfg.chanState.FundingOutpoint, ChainHash: c.cfg.chanState.ChainHash, @@ -1156,38 +1188,35 @@ func (c *chainWatcher) dispatchContractBreach(spendEvent *chainntnfs.SpendDetail closeSummary.LastChanSyncMsg = chanSync } - // We create a function closure that will mark the channel as pending - // close in the database. We pass it to the contracBreach method such - // that it can ensure safe handoff of the breach before we close the - // channel. - markClosed := func() error { - // At this point, we've successfully received an ack for the - // breach close, and we can mark the channel as pending force - // closed. - if err := c.cfg.chanState.CloseChannel( - &closeSummary, channeldb.ChanStatusRemoteCloseInitiator, - ); err != nil { - return err - } - - log.Infof("Breached channel=%v marked pending-closed", - c.cfg.chanState.FundingOutpoint) - return nil - } - - // Hand the retribution info over to the breach arbiter. - if err := c.cfg.contractBreach(retribution, markClosed); err != nil { + // Hand the retribution info over to the breach arbiter. This function + // will wait for a response from the breach arbiter and then proceed to + // send a BreachCloseInfo to the channel arbitrator. The channel arb + // will then mark the channel as closed after resolutions and the + // commit set are logged in the arbitrator log. + if err := c.cfg.contractBreach(retribution); err != nil { log.Errorf("unable to hand breached contract off to "+ "breachArbiter: %v", err) return err } + breachRes := &BreachResolution{ + FundingOutPoint: c.cfg.chanState.FundingOutpoint, + } + + breachInfo := &BreachCloseInfo{ + CommitHash: spendEvent.SpendingTx.TxHash(), + BreachResolution: breachRes, + AnchorResolution: anchorRes, + CommitSet: chainSet.commitSet, + CloseSummary: closeSummary, + } + // With the event processed and channel closed, we'll now notify all // subscribers of the event. c.Lock() for _, sub := range c.clientSubscriptions { select { - case sub.ContractBreach <- retribution: + case sub.ContractBreach <- breachInfo: case <-c.quit: c.Unlock() return fmt.Errorf("quitting") diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 4c02b346c..a67c5e1e1 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -2646,14 +2646,59 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // the ChainWatcher and BreachArbiter, we don't have to do // anything in particular, so just advance our state and // gracefully exit. - case <-c.cfg.ChainEvents.ContractBreach: + case breachInfo := <-c.cfg.ChainEvents.ContractBreach: log.Infof("ChannelArbitrator(%v): remote party has "+ "breached channel!", c.cfg.ChanPoint) + // In the breach case, we'll only have anchor and + // breach resolutions. + contractRes := &ContractResolutions{ + CommitHash: breachInfo.CommitHash, + BreachResolution: breachInfo.BreachResolution, + AnchorResolution: breachInfo.AnchorResolution, + } + + // We'll transition to the ContractClosed state and log + // the set of resolutions such that they can be turned + // into resolvers later on. We'll also insert the + // CommitSet of the latest set of commitments. + err := c.log.LogContractResolutions(contractRes) + if err != nil { + log.Errorf("Unable to write resolutions: %v", + err) + return + } + err = c.log.InsertConfirmedCommitSet( + &breachInfo.CommitSet, + ) + if err != nil { + log.Errorf("Unable to write commit set: %v", + err) + return + } + + // The channel is finally marked pending closed here as + // the breacharbiter and channel arbitrator have + // persisted the relevant states. + closeSummary := &breachInfo.CloseSummary + err = c.cfg.MarkChannelClosed( + closeSummary, + channeldb.ChanStatusRemoteCloseInitiator, + ) + if err != nil { + log.Errorf("Unable to mark channel closed: %v", + err) + return + } + + log.Infof("Breached channel=%v marked pending-closed", + breachInfo.BreachResolution.FundingOutPoint) + // We'll advance our state machine until it reaches a // terminal state. - _, _, err := c.advanceState( - uint32(bestHeight), breachCloseTrigger, nil, + _, _, err = c.advanceState( + uint32(bestHeight), breachCloseTrigger, + &breachInfo.CommitSet, ) if err != nil { log.Errorf("Unable to advance state: %v", err) diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 46742a384..6a0af45d3 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -205,6 +205,9 @@ type chanArbTestCtx struct { log ArbitratorLog sweeper *mockSweeper + + breachSubscribed chan struct{} + breachResolutionChan chan struct{} } func (c *chanArbTestCtx) CleanUp() { @@ -303,13 +306,17 @@ func withMarkClosed(markClosed func(*channeldb.ChannelCloseSummary, func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, opts ...testChanArbOption) (*chanArbTestCtx, error) { + chanArbCtx := &chanArbTestCtx{ + breachSubscribed: make(chan struct{}), + } + chanPoint := wire.OutPoint{} shortChanID := lnwire.ShortChannelID{} chanEvents := &ChainEventSubscription{ RemoteUnilateralClosure: make(chan *RemoteUnilateralCloseInfo, 1), LocalUnilateralClosure: make(chan *LocalUnilateralCloseInfo, 1), CooperativeClosure: make(chan *CooperativeCloseInfo, 1), - ContractBreach: make(chan *lnwallet.BreachRetribution, 1), + ContractBreach: make(chan *BreachCloseInfo, 1), } resolutionChan := make(chan []ResolutionMsg, 1) @@ -346,6 +353,13 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, return true }, + SubscribeBreachComplete: func(op *wire.OutPoint, + c chan struct{}) (bool, error) { + + chanArbCtx.breachResolutionChan = c + chanArbCtx.breachSubscribed <- struct{}{} + return false, nil + }, Clock: clock.NewDefaultClock(), Sweeper: mockSweeper, } @@ -425,16 +439,16 @@ func createTestChannelArbitrator(t *testing.T, log ArbitratorLog, chanArb := NewChannelArbitrator(*arbCfg, htlcSets, log) - return &chanArbTestCtx{ - t: t, - chanArb: chanArb, - cleanUp: cleanUp, - resolvedChan: resolvedChan, - resolutions: resolutionChan, - log: log, - incubationRequests: incubateChan, - sweeper: mockSweeper, - }, nil + chanArbCtx.t = t + chanArbCtx.chanArb = chanArb + chanArbCtx.cleanUp = cleanUp + chanArbCtx.resolvedChan = resolvedChan + chanArbCtx.resolutions = resolutionChan + chanArbCtx.log = log + chanArbCtx.incubationRequests = incubateChan + chanArbCtx.sweeper = mockSweeper + + return chanArbCtx, nil } // TestChannelArbitratorCooperativeClose tests that the ChannelArbitertor @@ -661,11 +675,13 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { // TestChannelArbitratorBreachClose tests that the ChannelArbitrator goes // through the expected states in case we notice a breach in the chain, and -// gracefully exits. +// is able to properly progress the breachResolver and anchorResolver to a +// successful resolution. func TestChannelArbitratorBreachClose(t *testing.T) { log := &mockArbitratorLog{ state: StateDefault, newStates: make(chan ArbitratorState, 5), + resolvers: make(map[ContractResolver]struct{}), } chanArbCtx, err := createTestChannelArbitrator(t, log) @@ -673,6 +689,8 @@ func TestChannelArbitratorBreachClose(t *testing.T) { t.Fatalf("unable to create ChannelArbitrator: %v", err) } chanArb := chanArbCtx.chanArb + chanArb.cfg.PreimageDB = newMockWitnessBeacon() + chanArb.cfg.Registry = &mockRegistry{} if err := chanArb.Start(nil); err != nil { t.Fatalf("unable to start ChannelArbitrator: %v", err) @@ -686,13 +704,99 @@ func TestChannelArbitratorBreachClose(t *testing.T) { // It should start out in the default state. chanArbCtx.AssertState(StateDefault) - // Send a breach close event. - chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{} + // We create two HTLCs, one incoming and one outgoing. We will later + // assert that we only receive a ResolutionMsg for the outgoing HTLC. + outgoingIdx := uint64(2) - // It should transition StateDefault -> StateFullyResolved. - chanArbCtx.AssertStateTransitions( - StateFullyResolved, - ) + rHash1 := [lntypes.PreimageSize]byte{1, 2, 3} + htlc1 := channeldb.HTLC{ + RHash: rHash1, + OutputIndex: 2, + Incoming: false, + HtlcIndex: outgoingIdx, + LogIndex: 2, + } + + rHash2 := [lntypes.PreimageSize]byte{2, 2, 2} + htlc2 := channeldb.HTLC{ + RHash: rHash2, + OutputIndex: 3, + Incoming: true, + HtlcIndex: 3, + LogIndex: 3, + } + + anchorRes := &lnwallet.AnchorResolution{ + AnchorSignDescriptor: input.SignDescriptor{ + Output: &wire.TxOut{Value: 1}, + }, + } + + // Create the BreachCloseInfo that the chain_watcher would normally + // send to the channel_arbitrator. + breachInfo := &BreachCloseInfo{ + BreachResolution: &BreachResolution{ + FundingOutPoint: wire.OutPoint{}, + }, + AnchorResolution: anchorRes, + CommitSet: CommitSet{ + ConfCommitKey: &RemoteHtlcSet, + HtlcSets: map[HtlcSetKey][]channeldb.HTLC{ + RemoteHtlcSet: {htlc1, htlc2}, + }, + }, + CommitHash: chainhash.Hash{}, + } + + // Send a breach close event. + chanArb.cfg.ChainEvents.ContractBreach <- breachInfo + + // It should transition StateDefault -> StateContractClosed. + chanArbCtx.AssertStateTransitions(StateContractClosed) + + // We should receive one ResolutionMsg as there was only one outgoing + // HTLC at the time of the breach. + select { + case res := <-chanArbCtx.resolutions: + require.Equal(t, 1, len(res)) + require.Equal(t, outgoingIdx, res[0].HtlcIndex) + case <-time.After(5 * time.Second): + t.Fatal("expected to receive a resolution msg") + } + + // We should now transition from StateContractClosed to + // StateWaitingFullResolution. + chanArbCtx.AssertStateTransitions(StateWaitingFullResolution) + + // One of the resolvers should be an anchor resolver and the other + // should be a breach resolver. + require.Equal(t, 2, len(chanArb.activeResolvers)) + + var anchorExists, breachExists bool + for _, resolver := range chanArb.activeResolvers { + switch resolver.(type) { + case *anchorResolver: + anchorExists = true + case *breachResolver: + breachExists = true + default: + t.Fatalf("did not expect resolver %T", resolver) + } + } + require.True(t, anchorExists && breachExists) + + // The anchor resolver is expected to re-offer the anchor input to the + // sweeper. + <-chanArbCtx.sweeper.sweptInputs + + // Wait for SubscribeBreachComplete to be called. + <-chanArbCtx.breachSubscribed + + // We'll now close the breach channel so that the state transitions to + // StateFullyResolved. + close(chanArbCtx.breachResolutionChan) + + chanArbCtx.AssertStateTransitions(StateFullyResolved) // It should also mark the channel as resolved. select { @@ -1318,12 +1422,14 @@ func TestChannelArbitratorPersistence(t *testing.T) { // TestChannelArbitratorForceCloseBreachedChannel tests that the channel // arbitrator is able to handle a channel in the process of being force closed // is breached by the remote node. In these cases we expect the -// ChannelArbitrator to gracefully exit, as the breach is handled by other -// subsystems. +// ChannelArbitrator to properly execute the breachResolver flow and then +// gracefully exit once the breachResolver receives the signal from what would +// normally be the breacharbiter. func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { log := &mockArbitratorLog{ state: StateDefault, newStates: make(chan ArbitratorState, 5), + resolvers: make(map[ContractResolver]struct{}), } chanArbCtx, err := createTestChannelArbitrator(t, log) @@ -1389,6 +1495,20 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { t.Fatalf("no response received") } + // Before restarting, we'll need to modify the arbitrator log to have + // a set of contract resolutions and a commit set. + log.resolutions = &ContractResolutions{ + BreachResolution: &BreachResolution{ + FundingOutPoint: wire.OutPoint{}, + }, + } + log.commitSet = &CommitSet{ + ConfCommitKey: &RemoteHtlcSet, + HtlcSets: map[HtlcSetKey][]channeldb.HTLC{ + RemoteHtlcSet: {}, + }, + } + // We mimic that the channel is breached while the channel arbitrator // is down. This means that on restart it will be started with a // pending close channel, of type BreachClose. @@ -1402,7 +1522,18 @@ func TestChannelArbitratorForceCloseBreachedChannel(t *testing.T) { } defer chanArbCtx.CleanUp() - // Finally it should advance to StateFullyResolved. + // We should transition to StateContractClosed. + chanArbCtx.AssertStateTransitions( + StateContractClosed, StateWaitingFullResolution, + ) + + // Wait for SubscribeBreachComplete to be called. + <-chanArbCtx.breachSubscribed + + // We'll close the breachResolutionChan to cleanup the breachResolver + // and make the state transition to StateFullyResolved. + close(chanArbCtx.breachResolutionChan) + chanArbCtx.AssertStateTransitions(StateFullyResolved) // It should also mark the channel as resolved. diff --git a/server.go b/server.go index 40c1a4a24..61687c42f 100644 --- a/server.go +++ b/server.go @@ -1088,8 +1088,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, IsOurAddress: cc.Wallet.IsOurAddress, ContractBreach: func(chanPoint wire.OutPoint, - breachRet *lnwallet.BreachRetribution, - markClosed func() error) error { + breachRet *lnwallet.BreachRetribution) error { // processACK will handle the breachArbiter ACKing the // event. @@ -1101,8 +1100,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, } // If the breachArbiter successfully handled - // the event, we can mark the channel closed. - finalErr <- markClosed() + // the event, we can signal that the handoff + // was successful. + finalErr <- nil } event := &contractcourt.ContractBreachEvent{ @@ -1118,9 +1118,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return ErrServerShuttingDown } - // We'll wait for a final error to be available, either - // from the breachArbiter or from our markClosed - // function closure. + // We'll wait for a final error to be available from + // the breachArbiter. select { case err := <-finalErr: return err