diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index bff2f742c..1685401c8 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -62,8 +62,10 @@ type ArbitratorLog interface { // InsertUnresolvedContracts inserts a set of unresolved contracts into // the log. The log will then persistently store each contract until - // they've been swapped out, or resolved. - InsertUnresolvedContracts(...ContractResolver) error + // they've been swapped out, or resolved. It takes a set of report which + // should be written to disk if as well if it is non-nil. + InsertUnresolvedContracts(reports []*channeldb.ResolverReport, + resolvers ...ContractResolver) error // FetchUnresolvedContracts returns all unresolved contracts that have // been previously written to the log. @@ -533,7 +535,9 @@ func (b *boltArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, erro // swapped out, or resolved. // // NOTE: Part of the ContractResolver interface. -func (b *boltArbitratorLog) InsertUnresolvedContracts(resolvers ...ContractResolver) error { +func (b *boltArbitratorLog) InsertUnresolvedContracts(reports []*channeldb.ResolverReport, + resolvers ...ContractResolver) error { + return kvdb.Batch(b.db, func(tx kvdb.RwTx) error { contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) if err != nil { @@ -547,6 +551,14 @@ func (b *boltArbitratorLog) InsertUnresolvedContracts(resolvers ...ContractResol } } + // Persist any reports that are present. + for _, report := range reports { + err := b.cfg.PutResolverReport(tx, report) + if err != nil { + return err + } + } + return nil }) } @@ -908,15 +920,28 @@ func (b *boltArbitratorLog) WipeHistory() error { // checkpointContract is a private method that will be fed into // ContractResolver instances to checkpoint their state once they reach -// milestones during contract resolution. -func (b *boltArbitratorLog) checkpointContract(c ContractResolver) error { +// milestones during contract resolution. If the report provided is non-nil, +// it should also be recorded. +func (b *boltArbitratorLog) checkpointContract(c ContractResolver, + reports ...*channeldb.ResolverReport) error { + return kvdb.Update(b.db, func(tx kvdb.RwTx) error { contractBucket, err := fetchContractWriteBucket(tx, b.scopeKey[:]) if err != nil { return err } - return b.writeResolver(contractBucket, c) + if err := b.writeResolver(contractBucket, c); err != nil { + return err + } + + for _, report := range reports { + if err := b.cfg.PutResolverReport(tx, report); err != nil { + return err + } + } + + return nil }) } diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 446110dd3..1e88f6074 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -338,8 +338,10 @@ func TestContractInsertionRetrieval(t *testing.T) { resolverMap[string(resolvers[3].ResolverKey())] = resolvers[3] resolverMap[string(resolvers[4].ResolverKey())] = resolvers[4] - // Now, we'll insert the resolver into the log. - if err := testLog.InsertUnresolvedContracts(resolvers...); err != nil { + // Now, we'll insert the resolver into the log, we do not need to apply + // any closures, so we will pass in nil. + err = testLog.InsertUnresolvedContracts(nil, resolvers...) + if err != nil { t.Fatalf("unable to insert resolvers: %v", err) } @@ -419,8 +421,9 @@ func TestContractResolution(t *testing.T) { } // First, we'll insert the resolver into the database and ensure that - // we get the same resolver out the other side. - err = testLog.InsertUnresolvedContracts(timeoutResolver) + // we get the same resolver out the other side. We do not need to apply + // any closures. + err = testLog.InsertUnresolvedContracts(nil, timeoutResolver) if err != nil { t.Fatalf("unable to insert contract into db: %v", err) } @@ -482,8 +485,9 @@ func TestContractSwapping(t *testing.T) { htlcTimeoutResolver: timeoutResolver, } - // We'll first insert the contest resolver into the log. - err = testLog.InsertUnresolvedContracts(contestResolver) + // We'll first insert the contest resolver into the log with no + // additional updates. + err = testLog.InsertUnresolvedContracts(nil, contestResolver) if err != nil { t.Fatalf("unable to insert contract into db: %v", err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index eba962e8b..6cbf5f1a7 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -976,7 +976,7 @@ func (c *ChannelArbitrator) stateStep( log.Debugf("ChannelArbitrator(%v): inserting %v contract "+ "resolvers", c.cfg.ChanPoint, len(htlcResolvers)) - err = c.log.InsertUnresolvedContracts(htlcResolvers...) + err = c.log.InsertUnresolvedContracts(nil, htlcResolvers...) if err != nil { return StateError, closeTx, err } @@ -1744,8 +1744,10 @@ func (c *ChannelArbitrator) prepContractResolutions( // resolver so they each can do their duty. resolverCfg := ResolverConfig{ ChannelArbitratorConfig: c.cfg, - Checkpoint: func(res ContractResolver) error { - return c.log.InsertUnresolvedContracts(res) + Checkpoint: func(res ContractResolver, + reports ...*channeldb.ResolverReport) error { + + return c.log.InsertUnresolvedContracts(reports, res) }, } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index c9af5416b..fe2d1bb47 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -78,7 +78,7 @@ func (b *mockArbitratorLog) FetchUnresolvedContracts() ([]ContractResolver, return v, nil } -func (b *mockArbitratorLog) InsertUnresolvedContracts( +func (b *mockArbitratorLog) InsertUnresolvedContracts(_ []*channeldb.ResolverReport, resolvers ...ContractResolver) error { b.Lock() diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 5c4990778..b7d297dad 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -277,7 +277,7 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) { c.reportLock.Unlock() c.resolved = true - return nil, c.Checkpoint(c) + return nil, c.Checkpoint(c, nil) } // Stop signals the resolver to cancel any current resolution processes, and diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index 109f62fda..d995b2dec 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -50,7 +50,9 @@ func newCommitSweepResolverTestContext(t *testing.T, cfg := ResolverConfig{ ChannelArbitratorConfig: chainCfg, - Checkpoint: func(_ ContractResolver) error { + Checkpoint: func(_ ContractResolver, + _ ...*channeldb.ResolverReport) error { + checkPointChan <- struct{}{} return nil }, diff --git a/contractcourt/contract_resolvers.go b/contractcourt/contract_resolvers.go index a5fe119ad..cac40bace 100644 --- a/contractcourt/contract_resolvers.go +++ b/contractcourt/contract_resolvers.go @@ -86,8 +86,10 @@ type ResolverConfig struct { // Checkpoint allows a resolver to check point its state. This function // should write the state of the resolver to persistent storage, and - // return a non-nil error upon success. - Checkpoint func(ContractResolver) error + // return a non-nil error upon success. It takes a resolver report, + // which contains information about the outcome and should be written + // to disk if non-nil. + Checkpoint func(ContractResolver, ...*channeldb.ResolverReport) error } // contractResolverKit is meant to be used as a mix-in struct to be embedded within a diff --git a/contractcourt/htlc_incoming_resolver_test.go b/contractcourt/htlc_incoming_resolver_test.go index 9b6555aeb..059e1ccc2 100644 --- a/contractcourt/htlc_incoming_resolver_test.go +++ b/contractcourt/htlc_incoming_resolver_test.go @@ -260,7 +260,9 @@ func newIncomingResolverTestContext(t *testing.T) *incomingResolverTestContext { cfg := ResolverConfig{ ChannelArbitratorConfig: chainCfg, - Checkpoint: func(_ ContractResolver) error { + Checkpoint: func(_ ContractResolver, + _ ...*channeldb.ResolverReport) error { + checkPointChan <- struct{}{} return nil }, diff --git a/contractcourt/htlc_outgoing_contest_resolver_test.go b/contractcourt/htlc_outgoing_contest_resolver_test.go index a689f5d7b..c68436db8 100644 --- a/contractcourt/htlc_outgoing_contest_resolver_test.go +++ b/contractcourt/htlc_outgoing_contest_resolver_test.go @@ -134,7 +134,9 @@ func newOutgoingResolverTestContext(t *testing.T) *outgoingResolverTestContext { cfg := ResolverConfig{ ChannelArbitratorConfig: chainCfg, - Checkpoint: func(_ ContractResolver) error { + Checkpoint: func(_ ContractResolver, + _ ...*channeldb.ResolverReport) error { + checkPointChan <- struct{}{} return nil }, diff --git a/contractcourt/htlc_timeout_resolver_test.go b/contractcourt/htlc_timeout_resolver_test.go index 9445b1a29..293db46fc 100644 --- a/contractcourt/htlc_timeout_resolver_test.go +++ b/contractcourt/htlc_timeout_resolver_test.go @@ -259,7 +259,9 @@ func TestHtlcTimeoutResolver(t *testing.T) { cfg := ResolverConfig{ ChannelArbitratorConfig: chainCfg, - Checkpoint: func(_ ContractResolver) error { + Checkpoint: func(_ ContractResolver, + _ ...*channeldb.ResolverReport) error { + checkPointChan <- struct{}{} return nil },