diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index 6b377eeb0..eb1489d56 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -54,8 +54,10 @@ type ArbitratorLog interface { // TODO(roasbeef): document on interface the errors expected to be // returned - // CurrentState returns the current state of the ChannelArbitrator. - CurrentState() (ArbitratorState, error) + // CurrentState returns the current state of the ChannelArbitrator. It + // takes an optional database transaction, which will be used if it is + // non-nil, otherwise the lookup will be done in its own transaction. + CurrentState(tx kvdb.RTx) (ArbitratorState, error) // CommitState persists, the current state of the chain attendant. CommitState(ArbitratorState) error @@ -96,8 +98,10 @@ type ArbitratorLog interface { InsertConfirmedCommitSet(c *CommitSet) error // FetchConfirmedCommitSet fetches the known confirmed active HTLC set - // from the database. - FetchConfirmedCommitSet() (*CommitSet, error) + // from the database. It takes an optional database transaction, which + // will be used if it is non-nil, otherwise the lookup will be done in + // its own transaction. + FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error) // FetchChainActions attempts to fetch the set of previously stored // chain actions. We'll use this upon restart to properly advance our @@ -412,27 +416,28 @@ func (b *boltArbitratorLog) writeResolver(contractBucket kvdb.RwBucket, return contractBucket.Put(resKey, buf.Bytes()) } -// CurrentState returns the current state of the ChannelArbitrator. +// CurrentState returns the current state of the ChannelArbitrator. It takes an +// optional database transaction, which will be used if it is non-nil, otherwise +// the lookup will be done in its own transaction. // // NOTE: Part of the ContractResolver interface. -func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { - var s ArbitratorState - err := kvdb.View(b.db, func(tx kvdb.RTx) error { - scopeBucket := tx.ReadBucket(b.scopeKey[:]) - if scopeBucket == nil { - return errScopeBucketNoExist - } +func (b *boltArbitratorLog) CurrentState(tx kvdb.RTx) (ArbitratorState, error) { + var ( + s ArbitratorState + err error + ) - stateBytes := scopeBucket.Get(stateKey) - if stateBytes == nil { - return nil - } + if tx != nil { + s, err = b.currentState(tx) + } else { + err = kvdb.View(b.db, func(tx kvdb.RTx) error { + s, err = b.currentState(tx) + return err + }, func() { + s = 0 + }) + } - s = ArbitratorState(stateBytes[0]) - return nil - }, func() { - s = 0 - }) if err != nil && err != errScopeBucketNoExist { return s, err } @@ -440,6 +445,20 @@ func (b *boltArbitratorLog) CurrentState() (ArbitratorState, error) { return s, nil } +func (b *boltArbitratorLog) currentState(tx kvdb.RTx) (ArbitratorState, error) { + scopeBucket := tx.ReadBucket(b.scopeKey[:]) + if scopeBucket == nil { + return 0, errScopeBucketNoExist + } + + stateBytes := scopeBucket.Get(stateKey) + if stateBytes == nil { + return 0, nil + } + + return ArbitratorState(stateBytes[0]), nil +} + // CommitState persists, the current state of the chain attendant. // // NOTE: Part of the ContractResolver interface. @@ -851,29 +870,20 @@ func (b *boltArbitratorLog) InsertConfirmedCommitSet(c *CommitSet) error { } // FetchConfirmedCommitSet fetches the known confirmed active HTLC set from the -// database. +// database. It takes an optional database transaction, which will be used if it +// is non-nil, otherwise the lookup will be done in its own transaction. // // NOTE: Part of the ContractResolver interface. -func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { +func (b *boltArbitratorLog) FetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, error) { + if tx != nil { + return b.fetchConfirmedCommitSet(tx) + } + var c *CommitSet err := kvdb.View(b.db, func(tx kvdb.RTx) error { - scopeBucket := tx.ReadBucket(b.scopeKey[:]) - if scopeBucket == nil { - return errScopeBucketNoExist - } - - commitSetBytes := scopeBucket.Get(commitSetKey) - if commitSetBytes == nil { - return errNoCommitSet - } - - commitSet, err := decodeCommitSet(bytes.NewReader(commitSetBytes)) - if err != nil { - return err - } - - c = commitSet - return nil + var err error + c, err = b.fetchConfirmedCommitSet(tx) + return err }, func() { c = nil }) @@ -884,6 +894,22 @@ func (b *boltArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { return c, nil } +func (b *boltArbitratorLog) fetchConfirmedCommitSet(tx kvdb.RTx) (*CommitSet, + error) { + + scopeBucket := tx.ReadBucket(b.scopeKey[:]) + if scopeBucket == nil { + return nil, errScopeBucketNoExist + } + + commitSetBytes := scopeBucket.Get(commitSetKey) + if commitSetBytes == nil { + return nil, errNoCommitSet + } + + return decodeCommitSet(bytes.NewReader(commitSetBytes)) +} + // WipeHistory is to be called ONLY once *all* contracts have been fully // resolved, and the channel closure if finalized. This method will delete all // on-disk state within the persistent log. diff --git a/contractcourt/briefcase_test.go b/contractcourt/briefcase_test.go index 1e88f6074..6c2936b4e 100644 --- a/contractcourt/briefcase_test.go +++ b/contractcourt/briefcase_test.go @@ -611,7 +611,7 @@ func TestStateMutation(t *testing.T) { defer cleanUp() // The default state of an arbitrator should be StateDefault. - arbState, err := testLog.CurrentState() + arbState, err := testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -625,7 +625,7 @@ func TestStateMutation(t *testing.T) { if err := testLog.CommitState(StateFullyResolved); err != nil { t.Fatalf("unable to write state: %v", err) } - arbState, err = testLog.CurrentState() + arbState, err = testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -643,7 +643,7 @@ func TestStateMutation(t *testing.T) { // If we try to query for the state again, we should get the default // state again. - arbState, err = testLog.CurrentState() + arbState, err = testLog.CurrentState(nil) if err != nil { t.Fatalf("unable to query current state: %v", err) } @@ -687,11 +687,11 @@ func TestScopeIsolation(t *testing.T) { // Querying each log, the states should be the prior one we set, and be // disjoint. - log1State, err := testLog1.CurrentState() + log1State, err := testLog1.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } - log2State, err := testLog2.CurrentState() + log2State, err := testLog2.CurrentState(nil) if err != nil { t.Fatalf("unable to read arb state: %v", err) } @@ -752,7 +752,7 @@ func TestCommitSetStorage(t *testing.T) { t.Fatalf("unable to write commit set: %v", err) } - diskCommitSet, err := testLog.FetchConfirmedCommitSet() + diskCommitSet, err := testLog.FetchConfirmedCommitSet(nil) if err != nil { t.Fatalf("unable to read commit set: %v", err) } diff --git a/contractcourt/channel_arbitrator.go b/contractcourt/channel_arbitrator.go index 86ddd87d7..8b4b3df7a 100644 --- a/contractcourt/channel_arbitrator.go +++ b/contractcourt/channel_arbitrator.go @@ -398,7 +398,7 @@ func (c *ChannelArbitrator) Start() error { // First, we'll read our last state from disk, so our internal state // machine can act accordingly. - c.state, err = c.log.CurrentState() + c.state, err = c.log.CurrentState(nil) if err != nil { return err } @@ -454,7 +454,7 @@ func (c *ChannelArbitrator) Start() error { // older nodes, this won't be found at all, and will rely on the // existing written chain actions. Additionally, if this channel hasn't // logged any actions in the log, then this field won't be present. - commitSet, err := c.log.FetchConfirmedCommitSet() + commitSet, err := c.log.FetchConfirmedCommitSet(nil) if err != nil && err != errNoCommitSet && err != errScopeBucketNoExist { return err } diff --git a/contractcourt/channel_arbitrator_test.go b/contractcourt/channel_arbitrator_test.go index 38970b6be..3371998ff 100644 --- a/contractcourt/channel_arbitrator_test.go +++ b/contractcourt/channel_arbitrator_test.go @@ -51,7 +51,7 @@ type mockArbitratorLog struct { // interface. var _ ArbitratorLog = (*mockArbitratorLog)(nil) -func (b *mockArbitratorLog) CurrentState() (ArbitratorState, error) { +func (b *mockArbitratorLog) CurrentState(kvdb.RTx) (ArbitratorState, error) { return b.state, nil } @@ -140,7 +140,7 @@ func (b *mockArbitratorLog) InsertConfirmedCommitSet(c *CommitSet) error { return nil } -func (b *mockArbitratorLog) FetchConfirmedCommitSet() (*CommitSet, error) { +func (b *mockArbitratorLog) FetchConfirmedCommitSet(kvdb.RTx) (*CommitSet, error) { return b.commitSet, nil }