diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 9008894f7..c12796afb 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -418,8 +418,6 @@ func (c *ChannelArbitrator) Start() error { } } - // TODO(roasbeef): cancel if breached - c.wg.Add(1) go c.channelAttendant(bestHeight) return nil @@ -649,6 +647,11 @@ const ( // coopCloseTrigger is a transition trigger driven by a cooperative // close transaction being confirmed. coopCloseTrigger + + // breachCloseTrigger is a transition trigger driven by a remote breach + // being confirmed. In this case the channel arbitrator won't have to + // do anything, so we'll just clean up and exit gracefully. + breachCloseTrigger ) // String returns a human readable string describing the passed @@ -670,6 +673,9 @@ func (t transitionTrigger) String() string { case coopCloseTrigger: return "coopCloseTrigger" + case breachCloseTrigger: + return "breachCloseTrigger" + default: return "unknown trigger" } @@ -748,8 +754,9 @@ func (c *ChannelArbitrator) stateStep( // If the trigger is a cooperative close being confirmed, then // we can go straight to StateFullyResolved, as there won't be - // any contracts to resolve. - case coopCloseTrigger: + // any contracts to resolve. The same is true in the case of a + // breach. + case coopCloseTrigger, breachCloseTrigger: nextState = StateFullyResolved // Otherwise, if this state advance was triggered by a @@ -773,7 +780,7 @@ func (c *ChannelArbitrator) stateStep( // StateBroadcastCommit via a user or chain trigger. On restart, // this state may be reexecuted after closing the channel, but // failing to commit to StateContractClosed or - // StateFullyResolved. In that case, one of the three close + // StateFullyResolved. In that case, one of the four close // triggers will be presented, signifying that we should skip // rebroadcasting, and go straight to resolving the on-chain // contract or marking the channel resolved. @@ -785,7 +792,7 @@ func (c *ChannelArbitrator) stateStep( c.cfg.ChanPoint, trigger, StateContractClosed) return StateContractClosed, closeTx, nil - case coopCloseTrigger: + case coopCloseTrigger, breachCloseTrigger: log.Infof("ChannelArbitrator(%v): detected %s "+ "close after closing channel, fast-forwarding "+ "to %s to resolve contract", @@ -861,7 +868,9 @@ func (c *ChannelArbitrator) stateStep( c.cfg.ChanPoint, trigger) nextState = StateContractClosed - case coopCloseTrigger: + // If a coop close or breach was confirmed, jump straight to + // the fully resolved state. + case coopCloseTrigger, breachCloseTrigger: log.Infof("ChannelArbitrator(%v): trigger %v, "+ " going to StateFullyResolved", c.cfg.ChanPoint, trigger) @@ -2026,7 +2035,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { uint32(bestHeight), chainTrigger, nil, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) } // If as a result of this trigger, the contract is @@ -2081,7 +2090,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { closeInfo.ChannelCloseSummary, ) if err != nil { - log.Errorf("unable to mark channel closed: "+ + log.Errorf("Unable to mark channel closed: "+ "%v", err) return } @@ -2092,7 +2101,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { closeInfo.CloseHeight, coopCloseTrigger, nil, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) return } @@ -2123,7 +2132,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // actions on restart. err := c.log.LogContractResolutions(contractRes) if err != nil { - log.Errorf("unable to write resolutions: %v", + log.Errorf("Unable to write resolutions: %v", err) return } @@ -2131,7 +2140,8 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { &closeInfo.CommitSet, ) if err != nil { - log.Errorf("unable to write commit set: %v", err) + log.Errorf("Unable to write commit set: %v", + err) return } @@ -2149,7 +2159,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { closeInfo.ChannelCloseSummary, ) if err != nil { - log.Errorf("unable to mark "+ + log.Errorf("Unable to mark "+ "channel closed: %v", err) return } @@ -2161,7 +2171,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { localCloseTrigger, &closeInfo.CommitSet, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) } // The remote party has broadcast the commitment on-chain. @@ -2188,7 +2198,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { // actions on restart. err := c.log.LogContractResolutions(contractRes) if err != nil { - log.Errorf("unable to write resolutions: %v", + log.Errorf("Unable to write resolutions: %v", err) return } @@ -2196,7 +2206,8 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { &uniClosure.CommitSet, ) if err != nil { - log.Errorf("unable to write commit set: %v", err) + log.Errorf("Unable to write commit set: %v", + err) return } @@ -2213,7 +2224,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { closeSummary := &uniClosure.ChannelCloseSummary err = c.cfg.MarkChannelClosed(closeSummary) if err != nil { - log.Errorf("unable to mark channel closed: %v", + log.Errorf("Unable to mark channel closed: %v", err) return } @@ -2225,7 +2236,24 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { remoteCloseTrigger, &uniClosure.CommitSet, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) + } + + // The remote has breached the channel. As this is handled by + // the ChainWatcher and BreachArbiter, we don't have to do + // anything in particular, so just advance our state and + // gracefully exit. + case <-c.cfg.ChainEvents.ContractBreach: + log.Infof("ChannelArbitrator(%v): remote party has "+ + "breached channel!", c.cfg.ChanPoint) + + // We'll advance our state machine until it reaches a + // terminal state. + _, _, err := c.advanceState( + uint32(bestHeight), breachCloseTrigger, nil, + ) + if err != nil { + log.Errorf("Unable to advance state: %v", err) } // A new contract has just been resolved, we'll now check our @@ -2239,7 +2267,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { uint32(bestHeight), chainTrigger, nil, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) } // If we don't have anything further to do after @@ -2273,7 +2301,7 @@ func (c *ChannelArbitrator) channelAttendant(bestHeight int32) { uint32(bestHeight), userTrigger, nil, ) if err != nil { - log.Errorf("unable to advance state: %v", err) + log.Errorf("Unable to advance state: %v", err) } select { diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 7b413669f..b1183260e 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -354,7 +354,7 @@ func TestChannelArbitratorRemoteForceClose(t *testing.T) { t, log.newStates, StateContractClosed, StateFullyResolved, ) - // It should alos mark the channel as resolved. + // It should also mark the channel as resolved. select { case <-resolved: // Expected. @@ -469,6 +469,49 @@ func TestChannelArbitratorLocalForceClose(t *testing.T) { } } +// TestChannelArbitratorBreachClose tests that the ChannelArbitrator goes +// through the expected states in case we notice a breach in the chain, and +// gracefully exits. +func TestChannelArbitratorBreachClose(t *testing.T) { + log := &mockArbitratorLog{ + state: StateDefault, + newStates: make(chan ArbitratorState, 5), + } + + chanArb, resolved, _, _, err := createTestChannelArbitrator(log) + if err != nil { + t.Fatalf("unable to create ChannelArbitrator: %v", err) + } + + if err := chanArb.Start(); err != nil { + t.Fatalf("unable to start ChannelArbitrator: %v", err) + } + defer func() { + if err := chanArb.Stop(); err != nil { + t.Fatal(err) + } + }() + + // It should start out in the default state. + assertState(t, chanArb, StateDefault) + + // Send a breach close event. + chanArb.cfg.ChainEvents.ContractBreach <- &lnwallet.BreachRetribution{} + + // It should transition StateDefault -> StateFullyResolved. + assertStateTransitions( + t, log.newStates, StateFullyResolved, + ) + + // It should also mark the channel as resolved. + select { + case <-resolved: + // Expected. + case <-time.After(5 * time.Second): + t.Fatalf("contract was not resolved") + } +} + // TestChannelArbitratorLocalForceClosePendingHtlc tests that the // ChannelArbitrator goes through the expected states in case we request it to // force close a channel that still has an HTLC pending.