mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 01:33:02 +01:00
Merge pull request #8222 from ellemouton/wtclientStartupPerf
wtclient+migration: start storing chan max height in channel details bucket
This commit is contained in:
commit
80684eccbd
@ -123,6 +123,9 @@
|
||||
## Breaking Changes
|
||||
## Performance Improvements
|
||||
|
||||
* Watchtower client DB migration to massively [improve the start-up
|
||||
performance](https://github.com/lightningnetwork/lnd/pull/8222) of a client.
|
||||
|
||||
# Technical and Architectural Updates
|
||||
## BOLT Spec Updates
|
||||
|
||||
|
@ -17,6 +17,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/chainntnfs"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/channelnotifier"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
@ -295,9 +296,8 @@ type TowerClient struct {
|
||||
|
||||
closableSessionQueue *sessionCloseMinHeap
|
||||
|
||||
backupMu sync.Mutex
|
||||
summaries wtdb.ChannelSummaries
|
||||
chanCommitHeights map[lnwire.ChannelID]uint64
|
||||
backupMu sync.Mutex
|
||||
chanInfos wtdb.ChannelInfos
|
||||
|
||||
statTicker *time.Ticker
|
||||
stats *ClientStats
|
||||
@ -339,9 +339,7 @@ func New(config *Config) (*TowerClient, error) {
|
||||
|
||||
plog := build.NewPrefixLog(prefix, log)
|
||||
|
||||
// Load the sweep pkscripts that have been generated for all previously
|
||||
// registered channels.
|
||||
chanSummaries, err := cfg.DB.FetchChanSummaries()
|
||||
chanInfos, err := cfg.DB.FetchChanInfos()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -358,9 +356,8 @@ func New(config *Config) (*TowerClient, error) {
|
||||
cfg: cfg,
|
||||
log: plog,
|
||||
pipeline: queue,
|
||||
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
|
||||
activeSessions: newSessionQueueSet(),
|
||||
summaries: chanSummaries,
|
||||
chanInfos: chanInfos,
|
||||
closableSessionQueue: newSessionCloseMinHeap(),
|
||||
statTicker: time.NewTicker(DefaultStatInterval),
|
||||
stats: new(ClientStats),
|
||||
@ -369,44 +366,6 @@ func New(config *Config) (*TowerClient, error) {
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
// perUpdate is a callback function that will be used to inspect the
|
||||
// full set of candidate client sessions loaded from disk, and to
|
||||
// determine the highest known commit height for each channel. This
|
||||
// allows the client to reject backups that it has already processed for
|
||||
// its active policy.
|
||||
perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID,
|
||||
commitHeight uint64) {
|
||||
|
||||
// We only want to consider accepted updates that have been
|
||||
// accepted under an identical policy to the client's current
|
||||
// policy.
|
||||
if policy != c.cfg.Policy {
|
||||
return
|
||||
}
|
||||
|
||||
c.backupMu.Lock()
|
||||
defer c.backupMu.Unlock()
|
||||
|
||||
// Take the highest commit height found in the session's acked
|
||||
// updates.
|
||||
height, ok := c.chanCommitHeights[chanID]
|
||||
if !ok || commitHeight > height {
|
||||
c.chanCommitHeights[chanID] = commitHeight
|
||||
}
|
||||
}
|
||||
|
||||
perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID,
|
||||
height uint64) {
|
||||
|
||||
perUpdate(s.Policy, chanID, height)
|
||||
}
|
||||
|
||||
perCommittedUpdate := func(s *wtdb.ClientSession,
|
||||
u *wtdb.CommittedUpdate) {
|
||||
|
||||
perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight)
|
||||
}
|
||||
|
||||
candidateTowers := newTowerListIterator()
|
||||
perActiveTower := func(tower *Tower) {
|
||||
// If the tower has already been marked as active, then there is
|
||||
@ -429,8 +388,6 @@ func New(config *Config) (*TowerClient, error) {
|
||||
candidateSessions, err := getTowerAndSessionCandidates(
|
||||
cfg.DB, cfg.SecretKeyRing, perActiveTower,
|
||||
wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)),
|
||||
wtdb.WithPerMaxHeight(perMaxHeight),
|
||||
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
|
||||
wtdb.WithPostEvalFilterFn(ExhaustedSessionFilter()),
|
||||
)
|
||||
if err != nil {
|
||||
@ -594,7 +551,7 @@ func (c *TowerClient) Start() error {
|
||||
|
||||
// Iterate over the list of registered channels and check if
|
||||
// any of them can be marked as closed.
|
||||
for id := range c.summaries {
|
||||
for id := range c.chanInfos {
|
||||
isClosed, closedHeight, err := c.isChannelClosed(id)
|
||||
if err != nil {
|
||||
returnErr = err
|
||||
@ -615,7 +572,7 @@ func (c *TowerClient) Start() error {
|
||||
|
||||
// Since the channel has been marked as closed, we can
|
||||
// also remove it from the channel summaries map.
|
||||
delete(c.summaries, id)
|
||||
delete(c.chanInfos, id)
|
||||
}
|
||||
|
||||
// Load all closable sessions.
|
||||
@ -732,7 +689,7 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
|
||||
// If a pkscript for this channel already exists, the channel has been
|
||||
// previously registered.
|
||||
if _, ok := c.summaries[chanID]; ok {
|
||||
if _, ok := c.chanInfos[chanID]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -752,8 +709,10 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error {
|
||||
|
||||
// Finally, cache the pkscript in our in-memory cache to avoid db
|
||||
// lookups for the remainder of the daemon's execution.
|
||||
c.summaries[chanID] = wtdb.ClientChanSummary{
|
||||
SweepPkScript: pkScript,
|
||||
c.chanInfos[chanID] = &wtdb.ChannelInfo{
|
||||
ClientChanSummary: wtdb.ClientChanSummary{
|
||||
SweepPkScript: pkScript,
|
||||
},
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -770,16 +729,23 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
|
||||
// Make sure that this channel is registered with the tower client.
|
||||
c.backupMu.Lock()
|
||||
if _, ok := c.summaries[*chanID]; !ok {
|
||||
info, ok := c.chanInfos[*chanID]
|
||||
if !ok {
|
||||
c.backupMu.Unlock()
|
||||
|
||||
return ErrUnregisteredChannel
|
||||
}
|
||||
|
||||
// Ignore backups that have already been presented to the client.
|
||||
height, ok := c.chanCommitHeights[*chanID]
|
||||
if ok && stateNum <= height {
|
||||
var duplicate bool
|
||||
info.MaxHeight.WhenSome(func(maxHeight uint64) {
|
||||
if stateNum <= maxHeight {
|
||||
duplicate = true
|
||||
}
|
||||
})
|
||||
if duplicate {
|
||||
c.backupMu.Unlock()
|
||||
|
||||
c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+
|
||||
"height=%d", chanID, stateNum)
|
||||
|
||||
@ -789,7 +755,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
|
||||
// This backup has a higher commit height than any known backup for this
|
||||
// channel. We'll update our tip so that we won't accept it again if the
|
||||
// link flaps.
|
||||
c.chanCommitHeights[*chanID] = stateNum
|
||||
c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum)
|
||||
c.backupMu.Unlock()
|
||||
|
||||
id := &wtdb.BackupID{
|
||||
@ -899,7 +865,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
|
||||
defer c.backupMu.Unlock()
|
||||
|
||||
// We only care about channels registered with the tower client.
|
||||
if _, ok := c.summaries[chanID]; !ok {
|
||||
if _, ok := c.chanInfos[chanID]; !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -924,8 +890,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID,
|
||||
return fmt.Errorf("could not track closable sessions: %w", err)
|
||||
}
|
||||
|
||||
delete(c.summaries, chanID)
|
||||
delete(c.chanCommitHeights, chanID)
|
||||
delete(c.chanInfos, chanID)
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -1332,7 +1297,7 @@ func (c *TowerClient) backupDispatcher() {
|
||||
// the prevTask, and should be reprocessed after obtaining a new sessionQueue.
|
||||
func (c *TowerClient) processTask(task *wtdb.BackupID) {
|
||||
c.backupMu.Lock()
|
||||
summary, ok := c.summaries[task.ChanID]
|
||||
summary, ok := c.chanInfos[task.ChanID]
|
||||
if !ok {
|
||||
c.backupMu.Unlock()
|
||||
|
||||
|
@ -81,10 +81,10 @@ type DB interface {
|
||||
// successfully backed up using the given session.
|
||||
NumAckedUpdates(id *wtdb.SessionID) (uint64, error)
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to
|
||||
// their channel summaries. Only the channels that have not yet been
|
||||
// FetchChanInfos loads a mapping from all registered channels to
|
||||
// their wtdb.ChannelInfo. Only the channels that have not yet been
|
||||
// marked as closed will be loaded.
|
||||
FetchChanSummaries() (wtdb.ChannelSummaries, error)
|
||||
FetchChanInfos() (wtdb.ChannelInfos, error)
|
||||
|
||||
// MarkChannelClosed will mark a registered channel as closed by setting
|
||||
// its closed-height as the given block height. It returns a list of
|
||||
|
@ -3,11 +3,29 @@ package wtdb
|
||||
import (
|
||||
"io"
|
||||
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
)
|
||||
|
||||
// ChannelSummaries is a map for a given channel id to it's ClientChanSummary.
|
||||
type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary
|
||||
// ChannelInfos is a map for a given channel id to it's ChannelInfo.
|
||||
type ChannelInfos map[lnwire.ChannelID]*ChannelInfo
|
||||
|
||||
// ChannelInfo contains various useful things about a registered channel.
|
||||
//
|
||||
// NOTE: the reason for adding this struct which wraps ClientChanSummary
|
||||
// instead of extending ClientChanSummary is for faster look-up of added fields.
|
||||
// If we were to extend ClientChanSummary instead then we would need to decode
|
||||
// the entire struct each time we want to read the new fields and then re-encode
|
||||
// the struct each time we want to write to a new field.
|
||||
type ChannelInfo struct {
|
||||
ClientChanSummary
|
||||
|
||||
// MaxHeight is the highest commitment height that the tower has been
|
||||
// handed for this channel. An Option type is used to store this since
|
||||
// a commitment height of zero is valid, and we need a way of knowing if
|
||||
// we have seen a new height yet or not.
|
||||
MaxHeight fn.Option[uint64]
|
||||
}
|
||||
|
||||
// ClientChanSummary tracks channel-specific information. A new
|
||||
// ClientChanSummary is inserted in the database the first time the client
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
"sync"
|
||||
|
||||
"github.com/btcsuite/btcd/btcec/v2"
|
||||
"github.com/lightningnetwork/lnd/fn"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
@ -25,6 +26,7 @@ var (
|
||||
// => cChanDBID -> db-assigned-id
|
||||
// => cChanSessions => db-session-id -> 1
|
||||
// => cChanClosedHeight -> block-height
|
||||
// => cChanMaxCommitmentHeight -> commitment-height
|
||||
cChanDetailsBkt = []byte("client-channel-detail-bucket")
|
||||
|
||||
// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
|
||||
@ -45,6 +47,13 @@ var (
|
||||
// body of ClientChanSummary.
|
||||
cChannelSummary = []byte("client-channel-summary")
|
||||
|
||||
// cChanMaxCommitmentHeight is a key used in the cChanDetailsBkt used
|
||||
// to store the highest commitment height for this channel that the
|
||||
// tower has been handed.
|
||||
cChanMaxCommitmentHeight = []byte(
|
||||
"client-channel-max-commitment-height",
|
||||
)
|
||||
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionDBID -> db-assigned-id
|
||||
@ -1300,11 +1309,11 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
|
||||
return numAcked, nil
|
||||
}
|
||||
|
||||
// FetchChanSummaries loads a mapping from all registered channels to their
|
||||
// channel summaries. Only the channels that have not yet been marked as closed
|
||||
// will be loaded.
|
||||
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
||||
var summaries map[lnwire.ChannelID]ClientChanSummary
|
||||
// FetchChanInfos loads a mapping from all registered channels to their
|
||||
// ChannelInfo. Only the channels that have not yet been marked as closed will
|
||||
// be loaded.
|
||||
func (c *ClientDB) FetchChanInfos() (ChannelInfos, error) {
|
||||
var infos ChannelInfos
|
||||
|
||||
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
|
||||
chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
|
||||
@ -1317,34 +1326,47 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
|
||||
if chanDetails == nil {
|
||||
return ErrCorruptChanDetails
|
||||
}
|
||||
|
||||
// If this channel has already been marked as closed,
|
||||
// then its summary does not need to be loaded.
|
||||
closedHeight := chanDetails.Get(cChanClosedHeight)
|
||||
if len(closedHeight) > 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var chanID lnwire.ChannelID
|
||||
copy(chanID[:], k)
|
||||
|
||||
summary, err := getChanSummary(chanDetails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
summaries[chanID] = *summary
|
||||
info := &ChannelInfo{
|
||||
ClientChanSummary: *summary,
|
||||
}
|
||||
|
||||
maxHeightBytes := chanDetails.Get(
|
||||
cChanMaxCommitmentHeight,
|
||||
)
|
||||
if len(maxHeightBytes) != 0 {
|
||||
height, err := readBigSize(maxHeightBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
info.MaxHeight = fn.Some(height)
|
||||
}
|
||||
|
||||
infos[chanID] = info
|
||||
|
||||
return nil
|
||||
})
|
||||
}, func() {
|
||||
summaries = make(map[lnwire.ChannelID]ClientChanSummary)
|
||||
infos = make(ChannelInfos)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return summaries, nil
|
||||
return infos, nil
|
||||
}
|
||||
|
||||
// RegisterChannel registers a channel for use within the client database. For
|
||||
@ -1963,6 +1985,12 @@ func (c *ClientDB) CommitUpdate(id *SessionID,
|
||||
return err
|
||||
}
|
||||
|
||||
// Update the channel's max commitment height if needed.
|
||||
err = maybeUpdateMaxCommitHeight(tx, update.BackupID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Finally, capture the session's last applied value so it can
|
||||
// be sent in the next state update to the tower.
|
||||
lastApplied = session.TowerLastApplied
|
||||
@ -2178,9 +2206,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
|
||||
|
||||
// GetDBQueue returns a BackupID Queue instance under the given namespace.
|
||||
func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] {
|
||||
return NewQueueDB[*BackupID](
|
||||
return NewQueueDB(
|
||||
c.db, namespace, func() *BackupID {
|
||||
return &BackupID{}
|
||||
}, func(tx kvdb.RwTx, item *BackupID) error {
|
||||
return maybeUpdateMaxCommitHeight(tx, *item)
|
||||
},
|
||||
)
|
||||
}
|
||||
@ -2720,6 +2750,58 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64,
|
||||
return id, idBytes, nil
|
||||
}
|
||||
|
||||
// maybeUpdateMaxCommitHeight updates the given channel details bucket with the
|
||||
// given height if it is larger than the current max height stored for the
|
||||
// channel.
|
||||
func maybeUpdateMaxCommitHeight(tx kvdb.RwTx, backupID BackupID) error {
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
// If an entry for this channel does not exist in the channel details
|
||||
// bucket then we exit here as this means that the channel has been
|
||||
// closed.
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(backupID.ChanID[:])
|
||||
if chanDetails == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
putHeight := func() error {
|
||||
b, err := writeBigSize(backupID.CommitHeight)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return chanDetails.Put(
|
||||
cChanMaxCommitmentHeight, b,
|
||||
)
|
||||
}
|
||||
|
||||
// Get current height.
|
||||
heightBytes := chanDetails.Get(cChanMaxCommitmentHeight)
|
||||
|
||||
// The height might have not been set yet, in which case
|
||||
// we can just write the new height.
|
||||
if len(heightBytes) == 0 {
|
||||
return putHeight()
|
||||
}
|
||||
|
||||
// Otherwise, read in the current max commitment height for the channel.
|
||||
currentHeight, err := readBigSize(heightBytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If the new height is not larger than the current persisted height,
|
||||
// then there is nothing left for us to do.
|
||||
if backupID.CommitHeight <= currentHeight {
|
||||
return nil
|
||||
}
|
||||
|
||||
return putHeight()
|
||||
}
|
||||
|
||||
func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID,
|
||||
error) {
|
||||
|
||||
|
@ -156,13 +156,13 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID,
|
||||
return tower
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary {
|
||||
func (h *clientDBHarness) fetchChanInfos() wtdb.ChannelInfos {
|
||||
h.t.Helper()
|
||||
|
||||
summaries, err := h.db.FetchChanSummaries()
|
||||
infos, err := h.db.FetchChanInfos()
|
||||
require.NoError(h.t, err)
|
||||
|
||||
return summaries
|
||||
return infos
|
||||
}
|
||||
|
||||
func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID,
|
||||
@ -552,7 +552,7 @@ func testRemoveTower(h *clientDBHarness) {
|
||||
func testChanSummaries(h *clientDBHarness) {
|
||||
// First, assert that this channel is not already registered.
|
||||
var chanID lnwire.ChannelID
|
||||
_, ok := h.fetchChanSummaries()[chanID]
|
||||
_, ok := h.fetchChanInfos()[chanID]
|
||||
require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
|
||||
@ -565,7 +565,7 @@ func testChanSummaries(h *clientDBHarness) {
|
||||
|
||||
// Assert that the channel exists and that its sweep pkscript matches
|
||||
// the one we registered.
|
||||
summary, ok := h.fetchChanSummaries()[chanID]
|
||||
summary, ok := h.fetchChanInfos()[chanID]
|
||||
require.Truef(h.t, ok, "pkscript for channel %x should not exist yet",
|
||||
chanID)
|
||||
require.Equal(h.t, expPkScript, summary.SweepPkScript)
|
||||
@ -767,6 +767,58 @@ func testRogueUpdates(h *clientDBHarness) {
|
||||
require.Len(h.t, closableSessionsMap, 1)
|
||||
}
|
||||
|
||||
// testMaxCommitmentHeights tests that the max known commitment height of a
|
||||
// channel is properly persisted.
|
||||
func testMaxCommitmentHeights(h *clientDBHarness) {
|
||||
const maxUpdates = 5
|
||||
t := h.t
|
||||
|
||||
// Initially, we expect no channels.
|
||||
infos := h.fetchChanInfos()
|
||||
require.Empty(t, infos)
|
||||
|
||||
// Create a new tower.
|
||||
tower := h.newTower()
|
||||
|
||||
// Create and insert a new session.
|
||||
session1 := h.randSession(t, tower.ID, maxUpdates)
|
||||
h.insertSession(session1, nil)
|
||||
|
||||
// Create a new channel and register it.
|
||||
chanID1 := randChannelID(t)
|
||||
h.registerChan(chanID1, nil, nil)
|
||||
|
||||
// At this point, we expect one channel to be returned from
|
||||
// fetchChanInfos but with an unset max height.
|
||||
infos = h.fetchChanInfos()
|
||||
require.Len(t, infos, 1)
|
||||
|
||||
info, ok := infos[chanID1]
|
||||
require.True(t, ok)
|
||||
require.True(t, info.MaxHeight.IsNone())
|
||||
|
||||
// Commit and ACK some updates for this channel.
|
||||
for i := 1; i <= maxUpdates; i++ {
|
||||
update := randCommittedUpdateForChanWithHeight(
|
||||
t, chanID1, uint16(i), uint64(i-1),
|
||||
)
|
||||
lastApplied := h.commitUpdate(&session1.ID, update, nil)
|
||||
h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil)
|
||||
}
|
||||
|
||||
// Assert that the max height has now been set accordingly for this
|
||||
// channel.
|
||||
infos = h.fetchChanInfos()
|
||||
require.Len(t, infos, 1)
|
||||
|
||||
info, ok = infos[chanID1]
|
||||
require.True(t, ok)
|
||||
require.True(t, info.MaxHeight.IsSome())
|
||||
info.MaxHeight.WhenSome(func(u uint64) {
|
||||
require.EqualValues(t, maxUpdates-1, u)
|
||||
})
|
||||
}
|
||||
|
||||
// testMarkChannelClosed asserts the behaviour of MarkChannelClosed.
|
||||
func testMarkChannelClosed(h *clientDBHarness) {
|
||||
tower := h.newTower()
|
||||
@ -1097,6 +1149,10 @@ func TestClientDB(t *testing.T) {
|
||||
name: "rogue updates",
|
||||
run: testRogueUpdates,
|
||||
},
|
||||
{
|
||||
name: "max commitment heights",
|
||||
run: testMaxCommitmentHeights,
|
||||
},
|
||||
}
|
||||
|
||||
for _, database := range dbs {
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration8"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized with no output filters. This
|
||||
@ -40,6 +41,7 @@ func UseLogger(logger btclog.Logger) {
|
||||
migration5.UseLogger(logger)
|
||||
migration6.UseLogger(logger)
|
||||
migration7.UseLogger(logger)
|
||||
migration8.UseLogger(logger)
|
||||
}
|
||||
|
||||
// logClosure is used to provide a closure over expensive logging operations so
|
||||
|
234
watchtower/wtdb/migration8/codec.go
Normal file
234
watchtower/wtdb/migration8/codec.go
Normal file
@ -0,0 +1,234 @@
|
||||
package migration8
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
// BreachHintSize is the length of the identifier used to detect remote
|
||||
// commitment broadcasts.
|
||||
const BreachHintSize = 16
|
||||
|
||||
// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify
|
||||
// the breach transaction.
|
||||
type BreachHint [BreachHintSize]byte
|
||||
|
||||
// ChannelID is a series of 32-bytes that uniquely identifies all channels
|
||||
// within the network. The ChannelID is computed using the outpoint of the
|
||||
// funding transaction (the txid, and output index). Given a funding output the
|
||||
// ChannelID can be calculated by XOR'ing the big-endian serialization of the
|
||||
// txid and the big-endian serialization of the output index, truncated to
|
||||
// 2 bytes.
|
||||
type ChannelID [32]byte
|
||||
|
||||
// writeBigSize will encode the given uint64 as a BigSize byte slice.
|
||||
func writeBigSize(i uint64) ([]byte, error) {
|
||||
var b bytes.Buffer
|
||||
err := tlv.WriteVarInt(&b, i, &[8]byte{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// readBigSize converts the given byte slice into a uint64 and assumes that the
|
||||
// bytes slice is using BigSize encoding.
|
||||
func readBigSize(b []byte) (uint64, error) {
|
||||
r := bytes.NewReader(b)
|
||||
i, err := tlv.ReadVarInt(r, &[8]byte{})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return i, nil
|
||||
}
|
||||
|
||||
// CommittedUpdate holds a state update sent by a client along with its
|
||||
// allocated sequence number and the exact remote commitment the encrypted
|
||||
// justice transaction can rectify.
|
||||
type CommittedUpdate struct {
|
||||
// SeqNum is the unique sequence number allocated by the session to this
|
||||
// update.
|
||||
SeqNum uint16
|
||||
|
||||
CommittedUpdateBody
|
||||
}
|
||||
|
||||
// BackupID identifies a particular revoked, remote commitment by channel id and
|
||||
// commitment height.
|
||||
type BackupID struct {
|
||||
// ChanID is the channel id of the revoked commitment.
|
||||
ChanID ChannelID
|
||||
|
||||
// CommitHeight is the commitment height of the revoked commitment.
|
||||
CommitHeight uint64
|
||||
}
|
||||
|
||||
// Encode writes the BackupID from the passed io.Writer.
|
||||
func (b *BackupID) Encode(w io.Writer) error {
|
||||
return WriteElements(w,
|
||||
b.ChanID,
|
||||
b.CommitHeight,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a BackupID from the passed io.Reader.
|
||||
func (b *BackupID) Decode(r io.Reader) error {
|
||||
return ReadElements(r,
|
||||
&b.ChanID,
|
||||
&b.CommitHeight,
|
||||
)
|
||||
}
|
||||
|
||||
// String returns a human-readable encoding of a BackupID.
|
||||
func (b BackupID) String() string {
|
||||
return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight)
|
||||
}
|
||||
|
||||
// WriteElements serializes a variadic list of elements into the given
|
||||
// io.Writer.
|
||||
func WriteElements(w io.Writer, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := WriteElement(w, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElements deserializes the provided io.Reader into a variadic list of
|
||||
// target elements.
|
||||
func ReadElements(r io.Reader, elements ...interface{}) error {
|
||||
for _, element := range elements {
|
||||
if err := ReadElement(r, element); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WriteElement serializes a single element into the provided io.Writer.
|
||||
func WriteElement(w io.Writer, element interface{}) error {
|
||||
switch e := element.(type) {
|
||||
case ChannelID:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case uint64:
|
||||
if err := binary.Write(w, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case BreachHint:
|
||||
if _, err := w.Write(e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case []byte:
|
||||
if err := wire.WriteVarBytes(w, 0, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ReadElement deserializes a single element from the provided io.Reader.
|
||||
func ReadElement(r io.Reader, element interface{}) error {
|
||||
switch e := element.(type) {
|
||||
case *ChannelID:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *uint64:
|
||||
if err := binary.Read(r, byteOrder, e); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *BreachHint:
|
||||
if _, err := io.ReadFull(r, e[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case *[]byte:
|
||||
bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*e = bytes
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unexpected type")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CommittedUpdateBody represents the primary components of a CommittedUpdate.
|
||||
// On disk, this is stored under the sequence number, which acts as its key.
|
||||
type CommittedUpdateBody struct {
|
||||
// BackupID identifies the breached commitment that the encrypted blob
|
||||
// can spend from.
|
||||
BackupID BackupID
|
||||
|
||||
// Hint is the 16-byte prefix of the revoked commitment transaction ID.
|
||||
Hint BreachHint
|
||||
|
||||
// EncryptedBlob is a ciphertext containing the sweep information for
|
||||
// exacting justice if the commitment transaction matching the breach
|
||||
// hint is broadcast.
|
||||
EncryptedBlob []byte
|
||||
}
|
||||
|
||||
// Encode writes the CommittedUpdateBody to the passed io.Writer.
|
||||
func (u *CommittedUpdateBody) Encode(w io.Writer) error {
|
||||
err := u.BackupID.Encode(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return WriteElements(w,
|
||||
u.Hint,
|
||||
u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
||||
// Decode reads a CommittedUpdateBody from the passed io.Reader.
|
||||
func (u *CommittedUpdateBody) Decode(r io.Reader) error {
|
||||
err := u.BackupID.Decode(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return ReadElements(r,
|
||||
&u.Hint,
|
||||
&u.EncryptedBlob,
|
||||
)
|
||||
}
|
||||
|
||||
// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
|
||||
const SessionIDSize = 33
|
||||
|
||||
// SessionID is created from the remote public key of a client, and serves as a
|
||||
// unique identifier and authentication for sending state updates.
|
||||
type SessionID [SessionIDSize]byte
|
||||
|
||||
// String returns a hex encoding of the session id.
|
||||
func (s SessionID) String() string {
|
||||
return hex.EncodeToString(s[:])
|
||||
}
|
14
watchtower/wtdb/migration8/log.go
Normal file
14
watchtower/wtdb/migration8/log.go
Normal file
@ -0,0 +1,14 @@
|
||||
package migration8
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btclog"
|
||||
)
|
||||
|
||||
// log is a logger that is initialized as disabled. This means the package will
|
||||
// not perform any logging by default until a logger is set.
|
||||
var log = btclog.Disabled
|
||||
|
||||
// UseLogger uses a specified Logger to output package logging info.
|
||||
func UseLogger(logger btclog.Logger) {
|
||||
log = logger
|
||||
}
|
223
watchtower/wtdb/migration8/migration.go
Normal file
223
watchtower/wtdb/migration8/migration.go
Normal file
@ -0,0 +1,223 @@
|
||||
package migration8
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
)
|
||||
|
||||
var (
|
||||
// cSessionBkt is a top-level bucket storing:
|
||||
// session-id => cSessionBody -> encoded ClientSessionBody
|
||||
// => cSessionDBID -> db-assigned-id
|
||||
// => cSessionCommits => seqnum -> encoded CommittedUpdate
|
||||
// => cSessionAckRangeIndex => db-chan-id => start -> end
|
||||
// => cSessionRogueUpdateCount -> count
|
||||
cSessionBkt = []byte("client-session-bucket")
|
||||
|
||||
// cChanIDIndexBkt is a top-level bucket storing:
|
||||
// db-assigned-id -> channel-ID
|
||||
cChanIDIndexBkt = []byte("client-channel-id-index")
|
||||
|
||||
// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing
|
||||
// chan-id => start -> end
|
||||
cSessionAckRangeIndex = []byte("client-session-ack-range-index")
|
||||
|
||||
// cSessionBody is a sub-bucket of cSessionBkt storing:
|
||||
// seqnum -> encoded CommittedUpdate.
|
||||
cSessionCommits = []byte("client-session-commits")
|
||||
|
||||
// cChanDetailsBkt is a top-level bucket storing:
|
||||
// channel-id => cChannelSummary -> encoded ClientChanSummary.
|
||||
// => cChanDBID -> db-assigned-id
|
||||
// => cChanSessions => db-session-id -> 1
|
||||
// => cChanClosedHeight -> block-height
|
||||
// => cChanMaxCommitmentHeight -> commitment-height
|
||||
cChanDetailsBkt = []byte("client-channel-detail-bucket")
|
||||
|
||||
cChanMaxCommitmentHeight = []byte(
|
||||
"client-channel-max-commitment-height",
|
||||
)
|
||||
|
||||
// ErrUninitializedDB signals that top-level buckets for the database
|
||||
// have not been initialized.
|
||||
ErrUninitializedDB = errors.New("db not initialized")
|
||||
|
||||
byteOrder = binary.BigEndian
|
||||
)
|
||||
|
||||
// MigrateChannelMaxHeights migrates the tower client db by collecting all the
|
||||
// max commitment heights that have been backed up for each channel and then
|
||||
// storing those heights alongside the channel info.
|
||||
func MigrateChannelMaxHeights(tx kvdb.RwTx) error {
|
||||
log.Infof("Migrating the tower client DB for quick channel max " +
|
||||
"commitment height lookup")
|
||||
|
||||
heights, err := collectChanMaxHeights(tx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return writeChanMaxHeights(tx, heights)
|
||||
}
|
||||
|
||||
// writeChanMaxHeights iterates over the given channel ID to height map and
|
||||
// writes an entry under the cChanMaxCommitmentHeight key for each channel.
|
||||
func writeChanMaxHeights(tx kvdb.RwTx, heights map[ChannelID]uint64) error {
|
||||
chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
|
||||
if chanDetailsBkt == nil {
|
||||
return ErrUninitializedDB
|
||||
}
|
||||
|
||||
for chanID, maxHeight := range heights {
|
||||
chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:])
|
||||
|
||||
// If the details bucket for this channel ID does not exist,
|
||||
// it is probably a channel that has been closed and deleted
|
||||
// already. So we can skip this height.
|
||||
if chanDetails == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
b, err := writeBigSize(maxHeight)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = chanDetails.Put(cChanMaxCommitmentHeight, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// collectChanMaxHeights iterates over all the sessions in the DB. For each
|
||||
// session, it iterates over all the Acked updates and the committed updates
|
||||
// to collect the maximum commitment height for each channel.
|
||||
func collectChanMaxHeights(tx kvdb.RwTx) (map[ChannelID]uint64, error) {
|
||||
sessionsBkt := tx.ReadBucket(cSessionBkt)
|
||||
if sessionsBkt == nil {
|
||||
return nil, ErrUninitializedDB
|
||||
}
|
||||
|
||||
chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
|
||||
if chanIDIndexBkt == nil {
|
||||
return nil, ErrUninitializedDB
|
||||
}
|
||||
|
||||
heights := make(map[ChannelID]uint64)
|
||||
|
||||
// For each update we consider, we will only update the heights map if
|
||||
// the commitment height for the channel is larger than the current
|
||||
// max height stored for the channel.
|
||||
cb := func(chanID ChannelID, commitHeight uint64) {
|
||||
if commitHeight > heights[chanID] {
|
||||
heights[chanID] = commitHeight
|
||||
}
|
||||
}
|
||||
|
||||
err := sessionsBkt.ForEach(func(sessIDBytes, _ []byte) error {
|
||||
sessBkt := sessionsBkt.NestedReadBucket(sessIDBytes)
|
||||
if sessBkt == nil {
|
||||
return fmt.Errorf("bucket not found for session %x",
|
||||
sessIDBytes)
|
||||
}
|
||||
|
||||
err := forEachCommittedUpdate(sessBkt, cb)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return forEachAckedUpdate(sessBkt, chanIDIndexBkt, cb)
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return heights, nil
|
||||
}
|
||||
|
||||
// forEachCommittedUpdate iterates over all the given session's committed
|
||||
// updates and calls the call-back for each.
|
||||
func forEachCommittedUpdate(sessBkt kvdb.RBucket,
|
||||
cb func(chanID ChannelID, commitHeight uint64)) error {
|
||||
|
||||
sessionCommits := sessBkt.NestedReadBucket(cSessionCommits)
|
||||
if sessionCommits == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return sessionCommits.ForEach(func(k, v []byte) error {
|
||||
var update CommittedUpdate
|
||||
err := update.Decode(bytes.NewReader(v))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cb(update.BackupID.ChanID, update.BackupID.CommitHeight)
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// forEachAckedUpdate iterates over all the given session's acked update range
|
||||
// indices and calls the call-back for each.
|
||||
func forEachAckedUpdate(sessBkt, chanIDIndexBkt kvdb.RBucket,
|
||||
cb func(chanID ChannelID, commitHeight uint64)) error {
|
||||
|
||||
sessionAcksRanges := sessBkt.NestedReadBucket(cSessionAckRangeIndex)
|
||||
if sessionAcksRanges == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error {
|
||||
rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID)
|
||||
if rangeBkt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
index, err := readRangeIndex(rangeBkt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
chanIDBytes := chanIDIndexBkt.Get(dbChanID)
|
||||
var chanID ChannelID
|
||||
copy(chanID[:], chanIDBytes)
|
||||
|
||||
cb(chanID, index.MaxHeight())
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// readRangeIndex reads a persisted RangeIndex from the passed bucket and into
|
||||
// a new in-memory RangeIndex.
|
||||
func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) {
|
||||
ranges := make(map[uint64]uint64)
|
||||
err := rangesBkt.ForEach(func(k, v []byte) error {
|
||||
start, err := readBigSize(k)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
end, err := readBigSize(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ranges[start] = end
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize))
|
||||
}
|
214
watchtower/wtdb/migration8/migration_test.go
Normal file
214
watchtower/wtdb/migration8/migration_test.go
Normal file
@ -0,0 +1,214 @@
|
||||
package migration8
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/migtest"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
chan1ID = 10
|
||||
chan2ID = 20
|
||||
chan3ID = 30
|
||||
chan4ID = 40
|
||||
|
||||
chan1DBID = 111
|
||||
chan2DBID = 222
|
||||
chan3DBID = 333
|
||||
)
|
||||
|
||||
var (
|
||||
// preDetails is the expected data of the channel details bucket before
|
||||
// the migration.
|
||||
preDetails = map[string]interface{}{
|
||||
channelIDString(chan1ID): map[string]interface{}{},
|
||||
channelIDString(chan2ID): map[string]interface{}{},
|
||||
channelIDString(chan3ID): map[string]interface{}{},
|
||||
}
|
||||
|
||||
// channelIDIndex is the data in the channelID index that is used to
|
||||
// find the mapping between the db-assigned channel ID and the real
|
||||
// channel ID.
|
||||
channelIDIndex = map[string]interface{}{
|
||||
uint64ToStr(chan1DBID): channelIDString(chan1ID),
|
||||
uint64ToStr(chan2DBID): channelIDString(chan2ID),
|
||||
uint64ToStr(chan3DBID): channelIDString(chan3ID),
|
||||
}
|
||||
|
||||
// postDetails is the expected data in the channel details bucket after
|
||||
// the migration.
|
||||
postDetails = map[string]interface{}{
|
||||
channelIDString(chan1ID): map[string]interface{}{
|
||||
string(cChanMaxCommitmentHeight): uint64ToStr(105),
|
||||
},
|
||||
channelIDString(chan2ID): map[string]interface{}{
|
||||
string(cChanMaxCommitmentHeight): uint64ToStr(205),
|
||||
},
|
||||
channelIDString(chan3ID): map[string]interface{}{
|
||||
string(cChanMaxCommitmentHeight): uint64ToStr(304),
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex
|
||||
// function correctly builds the new channel-to-sessionID index to the tower
|
||||
// client DB.
|
||||
func TestMigrateChannelToSessionIndex(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
update1 := &CommittedUpdate{
|
||||
SeqNum: 1,
|
||||
CommittedUpdateBody: CommittedUpdateBody{
|
||||
BackupID: BackupID{
|
||||
ChanID: intToChannelID(chan1ID),
|
||||
CommitHeight: 105,
|
||||
},
|
||||
},
|
||||
}
|
||||
var update1B bytes.Buffer
|
||||
require.NoError(t, update1.Encode(&update1B))
|
||||
|
||||
update3 := &CommittedUpdate{
|
||||
SeqNum: 1,
|
||||
CommittedUpdateBody: CommittedUpdateBody{
|
||||
BackupID: BackupID{
|
||||
ChanID: intToChannelID(chan3ID),
|
||||
CommitHeight: 304,
|
||||
},
|
||||
},
|
||||
}
|
||||
var update3B bytes.Buffer
|
||||
require.NoError(t, update3.Encode(&update3B))
|
||||
|
||||
update4 := &CommittedUpdate{
|
||||
SeqNum: 1,
|
||||
CommittedUpdateBody: CommittedUpdateBody{
|
||||
BackupID: BackupID{
|
||||
ChanID: intToChannelID(chan4ID),
|
||||
CommitHeight: 400,
|
||||
},
|
||||
},
|
||||
}
|
||||
var update4B bytes.Buffer
|
||||
require.NoError(t, update4.Encode(&update4B))
|
||||
|
||||
// sessions is the expected data in the sessions bucket before and
|
||||
// after the migration.
|
||||
sessions := map[string]interface{}{
|
||||
// A session with both acked and committed updates.
|
||||
sessionIDString("1"): map[string]interface{}{
|
||||
string(cSessionAckRangeIndex): map[string]interface{}{
|
||||
// This range index gives channel 1 a max height
|
||||
// of 104.
|
||||
uint64ToStr(chan1DBID): map[string]interface{}{
|
||||
uint64ToStr(100): uint64ToStr(101),
|
||||
uint64ToStr(104): uint64ToStr(104),
|
||||
},
|
||||
// This range index gives channel 2 a max height
|
||||
// of 200.
|
||||
uint64ToStr(chan2DBID): map[string]interface{}{
|
||||
uint64ToStr(200): uint64ToStr(200),
|
||||
},
|
||||
},
|
||||
string(cSessionCommits): map[string]interface{}{
|
||||
// This committed update gives channel 1 a max
|
||||
// height of 105 and so it overrides the heights
|
||||
// from the range index.
|
||||
uint64ToStr(1): update1B.String(),
|
||||
},
|
||||
},
|
||||
// A session with only acked updates.
|
||||
sessionIDString("2"): map[string]interface{}{
|
||||
string(cSessionAckRangeIndex): map[string]interface{}{
|
||||
// This range index gives channel 2 a max height
|
||||
// of 205.
|
||||
uint64ToStr(chan2DBID): map[string]interface{}{
|
||||
uint64ToStr(201): uint64ToStr(205),
|
||||
},
|
||||
},
|
||||
},
|
||||
// A session with only committed updates.
|
||||
sessionIDString("3"): map[string]interface{}{
|
||||
string(cSessionCommits): map[string]interface{}{
|
||||
// This committed update gives channel 3 a max
|
||||
// height of 304.
|
||||
uint64ToStr(1): update3B.String(),
|
||||
},
|
||||
},
|
||||
// This session only contains heights for channel 4 which has
|
||||
// been closed and so this should have no effect.
|
||||
sessionIDString("4"): map[string]interface{}{
|
||||
string(cSessionAckRangeIndex): map[string]interface{}{
|
||||
uint64ToStr(444): map[string]interface{}{
|
||||
uint64ToStr(400): uint64ToStr(402),
|
||||
uint64ToStr(403): uint64ToStr(405),
|
||||
},
|
||||
},
|
||||
string(cSessionCommits): map[string]interface{}{
|
||||
uint64ToStr(1): update4B.String(),
|
||||
},
|
||||
},
|
||||
// A session with no updates.
|
||||
sessionIDString("5"): map[string]interface{}{},
|
||||
}
|
||||
|
||||
// Before the migration we have a channel details
|
||||
// bucket, a sessions bucket, a session ID index bucket
|
||||
// and a channel ID index bucket.
|
||||
before := func(tx kvdb.RwTx) error {
|
||||
err := migtest.RestoreDB(tx, cChanDetailsBkt, preDetails)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migtest.RestoreDB(tx, cSessionBkt, sessions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migtest.RestoreDB(tx, cChanIDIndexBkt, channelIDIndex)
|
||||
}
|
||||
|
||||
after := func(tx kvdb.RwTx) error {
|
||||
err := migtest.VerifyDB(tx, cSessionBkt, sessions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migtest.VerifyDB(tx, cChanDetailsBkt, postDetails)
|
||||
}
|
||||
|
||||
migtest.ApplyMigration(
|
||||
t, before, after, MigrateChannelMaxHeights, false,
|
||||
)
|
||||
}
|
||||
|
||||
func sessionIDString(id string) string {
|
||||
var sessID SessionID
|
||||
copy(sessID[:], id)
|
||||
return sessID.String()
|
||||
}
|
||||
|
||||
func channelIDString(id uint64) string {
|
||||
var chanID ChannelID
|
||||
byteOrder.PutUint64(chanID[:], id)
|
||||
return string(chanID[:])
|
||||
}
|
||||
|
||||
func uint64ToStr(id uint64) string {
|
||||
b, err := writeBigSize(id)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func intToChannelID(id uint64) ChannelID {
|
||||
var chanID ChannelID
|
||||
byteOrder.PutUint64(chanID[:], id)
|
||||
return chanID
|
||||
}
|
619
watchtower/wtdb/migration8/range_index.go
Normal file
619
watchtower/wtdb/migration8/range_index.go
Normal file
@ -0,0 +1,619 @@
|
||||
package migration8
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// rangeItem represents the start and end values of a range.
|
||||
type rangeItem struct {
|
||||
start uint64
|
||||
end uint64
|
||||
}
|
||||
|
||||
// RangeIndexOption describes the signature of a functional option that can be
|
||||
// used to modify the behaviour of a RangeIndex.
|
||||
type RangeIndexOption func(*RangeIndex)
|
||||
|
||||
// WithSerializeUint64Fn is a functional option that can be used to set the
|
||||
// function to be used to do the serialization of a uint64 into a byte slice.
|
||||
func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption {
|
||||
return func(index *RangeIndex) {
|
||||
index.serializeUint64 = fn
|
||||
}
|
||||
}
|
||||
|
||||
// RangeIndex can be used to keep track of which numbers have been added to a
|
||||
// set. It does so by keeping track of a sorted list of rangeItems. Each
|
||||
// rangeItem has a start and end value of a range where all values in-between
|
||||
// have been added to the set. It works well in situations where it is expected
|
||||
// numbers in the set are not sparse.
|
||||
type RangeIndex struct {
|
||||
// set is a sorted list of rangeItem.
|
||||
set []rangeItem
|
||||
|
||||
// mu is used to ensure safe access to set.
|
||||
mu sync.Mutex
|
||||
|
||||
// serializeUint64 is the function that can be used to convert a uint64
|
||||
// to a byte slice.
|
||||
serializeUint64 func(uint64) ([]byte, error)
|
||||
}
|
||||
|
||||
// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be
|
||||
// passed to the function in the form of a map.
|
||||
func NewRangeIndex(ranges map[uint64]uint64,
|
||||
opts ...RangeIndexOption) (*RangeIndex, error) {
|
||||
|
||||
index := &RangeIndex{
|
||||
serializeUint64: defaultSerializeUint64,
|
||||
set: make([]rangeItem, 0),
|
||||
}
|
||||
|
||||
// Apply any functional options.
|
||||
for _, o := range opts {
|
||||
o(index)
|
||||
}
|
||||
|
||||
for s, e := range ranges {
|
||||
if err := index.addRange(s, e); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// addRange can be used to add an entire new range to the set. This method
|
||||
// should only ever be called by NewRangeIndex to initialise the in-memory
|
||||
// structure and so the RangeIndex mutex is not held during this method.
|
||||
func (a *RangeIndex) addRange(start, end uint64) error {
|
||||
// Check that the given range is valid.
|
||||
if start > end {
|
||||
return fmt.Errorf("invalid range. Start height %d is larger "+
|
||||
"than end height %d", start, end)
|
||||
}
|
||||
|
||||
// min is a helper closure that will return the minimum of two uint64s.
|
||||
min := func(a, b uint64) uint64 {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// max is a helper closure that will return the maximum of two uint64s.
|
||||
max := func(a, b uint64) uint64 {
|
||||
if a > b {
|
||||
return a
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Collect the ranges that fall before and after the new range along
|
||||
// with the start and end values of the new range.
|
||||
var before, after []rangeItem
|
||||
for _, x := range a.set {
|
||||
// If the new start value can't extend the current ranges end
|
||||
// value, then the two cannot be merged. The range is added to
|
||||
// the group of ranges that fall before the new range.
|
||||
if x.end+1 < start {
|
||||
before = append(before, x)
|
||||
continue
|
||||
}
|
||||
|
||||
// If the current ranges start value does not follow on directly
|
||||
// from the new end value, then the two cannot be merged. The
|
||||
// range is added to the group of ranges that fall after the new
|
||||
// range.
|
||||
if end+1 < x.start {
|
||||
after = append(after, x)
|
||||
continue
|
||||
}
|
||||
|
||||
// Otherwise, there is an overlap and so the two can be merged.
|
||||
start = min(start, x.start)
|
||||
end = max(end, x.end)
|
||||
}
|
||||
|
||||
// Re-construct the range index set.
|
||||
a.set = append(append(before, rangeItem{
|
||||
start: start,
|
||||
end: end,
|
||||
}), after...)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsInIndex returns true if the given number is in the range set.
|
||||
func (a *RangeIndex) IsInIndex(n uint64) bool {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
_, isCovered := a.lowerBoundIndex(n)
|
||||
|
||||
return isCovered
|
||||
}
|
||||
|
||||
// NumInSet returns the number of items covered by the range set.
|
||||
func (a *RangeIndex) NumInSet() uint64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
var numItems uint64
|
||||
for _, r := range a.set {
|
||||
numItems += r.end - r.start + 1
|
||||
}
|
||||
|
||||
return numItems
|
||||
}
|
||||
|
||||
// MaxHeight returns the highest number covered in the range.
|
||||
func (a *RangeIndex) MaxHeight() uint64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
if len(a.set) == 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return a.set[len(a.set)-1].end
|
||||
}
|
||||
|
||||
// GetAllRanges returns a copy of the range set in the form of a map.
|
||||
func (a *RangeIndex) GetAllRanges() map[uint64]uint64 {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
cp := make(map[uint64]uint64, len(a.set))
|
||||
for _, item := range a.set {
|
||||
cp[item.start] = item.end
|
||||
}
|
||||
|
||||
return cp
|
||||
}
|
||||
|
||||
// lowerBoundIndex returns the index of the RangeIndex that is most appropriate
|
||||
// for the new value, n. In other words, it returns the index of the rangeItem
|
||||
// set of the range where the start value is the highest start value in the set
|
||||
// that is still lower than or equal to the given number, n. The returned
|
||||
// boolean is true if the given number is already covered in the RangeIndex.
|
||||
// A returned index of -1 indicates that no lower bound range exists in the set.
|
||||
// Since the most likely case is that the new number will just extend the
|
||||
// highest range, a check is first done to see if this is the case which will
|
||||
// make the methods' computational complexity O(1). Otherwise, a binary search
|
||||
// is done which brings the computational complexity to O(log N).
|
||||
func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) {
|
||||
// If the set is empty, then there is no such index and the value
|
||||
// definitely is not in the set.
|
||||
if len(a.set) == 0 {
|
||||
return -1, false
|
||||
}
|
||||
|
||||
// In most cases, the last index item will be the one we want. So just
|
||||
// do a quick check on that index first to avoid doing the binary
|
||||
// search.
|
||||
lastIndex := len(a.set) - 1
|
||||
lastRange := a.set[lastIndex]
|
||||
if lastRange.start <= n {
|
||||
return lastIndex, lastRange.end >= n
|
||||
}
|
||||
|
||||
// Otherwise, do a binary search to find the index of interest.
|
||||
var (
|
||||
low = 0
|
||||
high = len(a.set) - 1
|
||||
rangeIndex = -1
|
||||
)
|
||||
for {
|
||||
mid := (low + high) / 2
|
||||
currentRange := a.set[mid]
|
||||
|
||||
switch {
|
||||
case currentRange.start > n:
|
||||
// If the start of the range is greater than n, we can
|
||||
// completely cut out that entire part of the array.
|
||||
high = mid
|
||||
|
||||
case currentRange.start < n:
|
||||
// If the range already includes the given height, we
|
||||
// can stop searching now.
|
||||
if currentRange.end >= n {
|
||||
return mid, true
|
||||
}
|
||||
|
||||
// If the start of the range is smaller than n, we can
|
||||
// store this as the new best index to return.
|
||||
rangeIndex = mid
|
||||
|
||||
// If low and mid are already equal, then increment low
|
||||
// by 1. Exit if this means that low is now greater than
|
||||
// high.
|
||||
if low == mid {
|
||||
low = mid + 1
|
||||
if low > high {
|
||||
return rangeIndex, false
|
||||
}
|
||||
} else {
|
||||
low = mid
|
||||
}
|
||||
|
||||
continue
|
||||
|
||||
default:
|
||||
// If the height is equal to the start value of the
|
||||
// current range that mid is pointing to, then the
|
||||
// height is already covered.
|
||||
return mid, true
|
||||
}
|
||||
|
||||
// Exit if we have checked all the ranges.
|
||||
if low == high {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return rangeIndex, false
|
||||
}
|
||||
|
||||
// KVStore is an interface representing a key-value store.
|
||||
type KVStore interface {
|
||||
// Put saves the specified key/value pair to the store. Keys that do not
|
||||
// already exist are added and keys that already exist are overwritten.
|
||||
Put(key, value []byte) error
|
||||
|
||||
// Delete removes the specified key from the bucket. Deleting a key that
|
||||
// does not exist does not return an error.
|
||||
Delete(key []byte) error
|
||||
}
|
||||
|
||||
// Add adds a single number to the range set. It first attempts to apply the
|
||||
// necessary changes to the passed KV store and then only if this succeeds, will
|
||||
// the changes be applied to the in-memory structure.
|
||||
func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error {
|
||||
a.mu.Lock()
|
||||
defer a.mu.Unlock()
|
||||
|
||||
// Compute the changes that will need to be applied to both the sorted
|
||||
// rangeItem array representation and the key-value store representation
|
||||
// of the range index.
|
||||
arrayChanges, kvStoreChanges := a.getChanges(newHeight)
|
||||
|
||||
// First attempt to apply the KV store changes. Only if this succeeds
|
||||
// will we apply the changes to our in-memory range index structure.
|
||||
err := a.applyKVChanges(kv, kvStoreChanges)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Since the DB changes were successful, we can now commit the
|
||||
// changes to our in-memory representation of the range set.
|
||||
a.applyArrayChanges(arrayChanges)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyKVChanges applies the given set of kvChanges to a KV store. It is
|
||||
// assumed that a transaction is being held on the kv store so that if any
|
||||
// of the actions of the function fails, the changes will be reverted.
|
||||
func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error {
|
||||
// Exit early if there are no changes to apply.
|
||||
if kv == nil || changes == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Check if any range pair needs to be deleted.
|
||||
if changes.deleteKVKey != nil {
|
||||
del, err := a.serializeUint64(*changes.deleteKVKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := kv.Delete(del); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
start, err := a.serializeUint64(changes.key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
end, err := a.serializeUint64(changes.value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return kv.Put(start, end)
|
||||
}
|
||||
|
||||
// applyArrayChanges applies the given arrayChanges to the in-memory RangeIndex
|
||||
// itself. This should only be done once the persisted kv store changes have
|
||||
// already been applied.
|
||||
func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) {
|
||||
if changes == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if changes.indexToDelete != nil {
|
||||
a.set = append(
|
||||
a.set[:*changes.indexToDelete],
|
||||
a.set[*changes.indexToDelete+1:]...,
|
||||
)
|
||||
}
|
||||
|
||||
if changes.newIndex != nil {
|
||||
switch {
|
||||
case *changes.newIndex == 0:
|
||||
a.set = append([]rangeItem{{
|
||||
start: changes.start,
|
||||
end: changes.end,
|
||||
}}, a.set...)
|
||||
|
||||
case *changes.newIndex == len(a.set):
|
||||
a.set = append(a.set, rangeItem{
|
||||
start: changes.start,
|
||||
end: changes.end,
|
||||
})
|
||||
|
||||
default:
|
||||
a.set = append(
|
||||
a.set[:*changes.newIndex+1],
|
||||
a.set[*changes.newIndex:]...,
|
||||
)
|
||||
a.set[*changes.newIndex] = rangeItem{
|
||||
start: changes.start,
|
||||
end: changes.end,
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if changes.indexToEdit != nil {
|
||||
a.set[*changes.indexToEdit] = rangeItem{
|
||||
start: changes.start,
|
||||
end: changes.end,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// arrayChanges encompasses the diff to apply to the sorted rangeItem array
|
||||
// representation of a range index. Such a diff will either include adding a
|
||||
// new range or editing an existing range. If an existing range is edited, then
|
||||
// the diff might also include deleting an index (this will be the case if the
|
||||
// editing of the one range results in the merge of another range).
|
||||
type arrayChanges struct {
|
||||
start uint64
|
||||
end uint64
|
||||
|
||||
// newIndex, if set, is the index of the in-memory range array where a
|
||||
// new range, [start:end], should be added. newIndex should never be
|
||||
// set at the same time as indexToEdit or indexToDelete.
|
||||
newIndex *int
|
||||
|
||||
// indexToDelete, if set, is the index of the sorted rangeItem array
|
||||
// that should be deleted. This should be applied before reading the
|
||||
// index value of indexToEdit. This should not be set at the same time
|
||||
// as newIndex.
|
||||
indexToDelete *int
|
||||
|
||||
// indexToEdit is the index of the in-memory range array that should be
|
||||
// edited. The range at this index will be changed to [start:end]. This
|
||||
// should only be read after indexToDelete index has been deleted.
|
||||
indexToEdit *int
|
||||
}
|
||||
|
||||
// kvChanges encompasses the diff to apply to a KV-store representation of a
|
||||
// range index. A kv-store diff for the addition of a single number to the range
|
||||
// index will include either a brand new key-value pair or the altering of the
|
||||
// value of an existing key. Optionally, the diff may also include the deletion
|
||||
// of an existing key. A deletion will be required if the addition of the new
|
||||
// number results in the merge of two ranges.
|
||||
type kvChanges struct {
|
||||
key uint64
|
||||
value uint64
|
||||
|
||||
// deleteKVKey, if set, is the key of the kv store representation that
|
||||
// should be deleted.
|
||||
deleteKVKey *uint64
|
||||
}
|
||||
|
||||
// getChanges will calculate and return the changes that need to be applied to
|
||||
// both the sorted-rangeItem-array representation and the key-value store
|
||||
// representation of the range index.
|
||||
func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) {
|
||||
// If the set is empty then a new range item is added.
|
||||
if len(a.set) == 0 {
|
||||
// For the array representation, a new range [n:n] is added to
|
||||
// the first index of the array.
|
||||
firstIndex := 0
|
||||
ac := &arrayChanges{
|
||||
newIndex: &firstIndex,
|
||||
start: n,
|
||||
end: n,
|
||||
}
|
||||
|
||||
// For the KV representation, a new [n:n] pair is added.
|
||||
kvc := &kvChanges{
|
||||
key: n,
|
||||
value: n,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
}
|
||||
|
||||
// Find the index of the lower bound range to the new number.
|
||||
indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n)
|
||||
|
||||
switch {
|
||||
// The new number is already covered by the range index. No changes are
|
||||
// required.
|
||||
case alreadyCovered:
|
||||
return nil, nil
|
||||
|
||||
// No lower bound index exists.
|
||||
case indexOfRangeBelow < 0:
|
||||
// Check if the very first range can be merged into this new
|
||||
// one.
|
||||
if n+1 == a.set[0].start {
|
||||
// If so, the two ranges can be merged and so the start
|
||||
// value of the range is n and the end value is the end
|
||||
// of the existing first range.
|
||||
start := n
|
||||
end := a.set[0].end
|
||||
|
||||
// For the array representation, we can just edit the
|
||||
// first entry of the array
|
||||
editIndex := 0
|
||||
ac := &arrayChanges{
|
||||
indexToEdit: &editIndex,
|
||||
start: start,
|
||||
end: end,
|
||||
}
|
||||
|
||||
// For the KV store representation, we add a new kv pair
|
||||
// and delete the range with the key equal to the start
|
||||
// value of the range we are merging.
|
||||
kvKeyToDelete := a.set[0].start
|
||||
kvc := &kvChanges{
|
||||
key: start,
|
||||
value: end,
|
||||
deleteKVKey: &kvKeyToDelete,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
}
|
||||
|
||||
// Otherwise, we add a new index.
|
||||
|
||||
// For the array representation, a new range [n:n] is added to
|
||||
// the first index of the array.
|
||||
newIndex := 0
|
||||
ac := &arrayChanges{
|
||||
newIndex: &newIndex,
|
||||
start: n,
|
||||
end: n,
|
||||
}
|
||||
|
||||
// For the KV representation, a new [n:n] pair is added.
|
||||
kvc := &kvChanges{
|
||||
key: n,
|
||||
value: n,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
|
||||
// A lower range does exist, and it can be extended to include this new
|
||||
// number.
|
||||
case a.set[indexOfRangeBelow].end+1 == n:
|
||||
start := a.set[indexOfRangeBelow].start
|
||||
end := n
|
||||
indexToChange := indexOfRangeBelow
|
||||
|
||||
// If there are no intervals above this one or if there are, but
|
||||
// they can't be merged into this one then we just need to edit
|
||||
// this interval.
|
||||
if indexOfRangeBelow == len(a.set)-1 ||
|
||||
a.set[indexOfRangeBelow+1].start != n+1 {
|
||||
|
||||
// For the array representation, we just edit the index.
|
||||
ac := &arrayChanges{
|
||||
indexToEdit: &indexToChange,
|
||||
start: start,
|
||||
end: end,
|
||||
}
|
||||
|
||||
// For the key-value representation, we just overwrite
|
||||
// the end value at the existing start key.
|
||||
kvc := &kvChanges{
|
||||
key: start,
|
||||
value: end,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
}
|
||||
|
||||
// There is a range above this one that we need to merge into
|
||||
// this one.
|
||||
delIndex := indexOfRangeBelow + 1
|
||||
end = a.set[delIndex].end
|
||||
|
||||
// For the array representation, we delete the range above this
|
||||
// one and edit this range to include the end value of the range
|
||||
// above.
|
||||
ac := &arrayChanges{
|
||||
indexToDelete: &delIndex,
|
||||
indexToEdit: &indexToChange,
|
||||
start: start,
|
||||
end: end,
|
||||
}
|
||||
|
||||
// For the kv representation, we tweak the end value of an
|
||||
// existing key and delete the key of the range we are deleting.
|
||||
deleteKey := a.set[delIndex].start
|
||||
kvc := &kvChanges{
|
||||
key: start,
|
||||
value: end,
|
||||
deleteKVKey: &deleteKey,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
|
||||
// A lower range does exist, but it can't be extended to include this
|
||||
// new number, and so we need to add a new range after the lower bound
|
||||
// range.
|
||||
default:
|
||||
newIndex := indexOfRangeBelow + 1
|
||||
|
||||
// If there are no ranges above this new one or if there are,
|
||||
// but they can't be merged into this new one, then we can just
|
||||
// add the new one as is.
|
||||
if newIndex == len(a.set) || a.set[newIndex].start != n+1 {
|
||||
ac := &arrayChanges{
|
||||
newIndex: &newIndex,
|
||||
start: n,
|
||||
end: n,
|
||||
}
|
||||
|
||||
kvc := &kvChanges{
|
||||
key: n,
|
||||
value: n,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
}
|
||||
|
||||
// Else, we merge the above index.
|
||||
start := n
|
||||
end := a.set[newIndex].end
|
||||
toEdit := newIndex
|
||||
|
||||
// For the array representation, we edit the range above to
|
||||
// include the new start value.
|
||||
ac := &arrayChanges{
|
||||
indexToEdit: &toEdit,
|
||||
start: start,
|
||||
end: end,
|
||||
}
|
||||
|
||||
// For the kv representation, we insert the new start-end key
|
||||
// value pair and delete the key using the old start value.
|
||||
delKey := a.set[newIndex].start
|
||||
kvc := &kvChanges{
|
||||
key: start,
|
||||
value: end,
|
||||
deleteKVKey: &delKey,
|
||||
}
|
||||
|
||||
return ac, kvc
|
||||
}
|
||||
}
|
||||
|
||||
func defaultSerializeUint64(i uint64) ([]byte, error) {
|
||||
var b [8]byte
|
||||
byteOrder.PutUint64(b[:], i)
|
||||
return b[:], nil
|
||||
}
|
@ -80,6 +80,7 @@ type DiskQueueDB[T Serializable] struct {
|
||||
db kvdb.Backend
|
||||
topLevelBkt []byte
|
||||
constructor func() T
|
||||
onItemWrite func(tx kvdb.RwTx, item T) error
|
||||
}
|
||||
|
||||
// A compile-time check to ensure that DiskQueueDB implements the Queue
|
||||
@ -89,12 +90,14 @@ var _ Queue[Serializable] = (*DiskQueueDB[Serializable])(nil)
|
||||
// NewQueueDB constructs a new DiskQueueDB. A queueBktName must be provided so
|
||||
// that the DiskQueueDB can create its own namespace in the bolt db.
|
||||
func NewQueueDB[T Serializable](db kvdb.Backend, queueBktName []byte,
|
||||
constructor func() T) Queue[T] {
|
||||
constructor func() T,
|
||||
onItemWrite func(tx kvdb.RwTx, item T) error) Queue[T] {
|
||||
|
||||
return &DiskQueueDB[T]{
|
||||
db: db,
|
||||
topLevelBkt: queueBktName,
|
||||
constructor: constructor,
|
||||
onItemWrite: onItemWrite,
|
||||
}
|
||||
}
|
||||
|
||||
@ -279,6 +282,13 @@ func (d *DiskQueueDB[T]) addItem(tx kvdb.RwTx, queueName []byte, item T) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if d.onItemWrite != nil {
|
||||
err = d.onItemWrite(tx, item)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// Find the index to use for placing this new item at the back of the
|
||||
// queue.
|
||||
var nextIndex uint64
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -31,6 +32,14 @@ func TestDiskQueue(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
// In order to test that the queue's `onItemWrite` call back (which in
|
||||
// this case will be set to maybeUpdateMaxCommitHeight) is executed as
|
||||
// expected, we need to register a channel so that we can later assert
|
||||
// that it's max height field was updated properly.
|
||||
var chanID lnwire.ChannelID
|
||||
err = db.RegisterChannel(chanID, []byte{})
|
||||
require.NoError(t, err)
|
||||
|
||||
namespace := []byte("test-namespace")
|
||||
queue := db.GetDBQueue(namespace)
|
||||
|
||||
@ -110,4 +119,19 @@ func TestDiskQueue(t *testing.T) {
|
||||
// This should not have changed the order of the tasks, they should
|
||||
// still appear in the correct order.
|
||||
popAndAssert(task1, task2, task3, task4, task5, task6)
|
||||
|
||||
// Finally, we check that the `onItemWrite` call back was executed by
|
||||
// the queue. We do this by checking that the channel's recorded max
|
||||
// commitment height was set correctly. It should be equal to the height
|
||||
// recorded in task6.
|
||||
infos, err := db.FetchChanInfos()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, infos, 1)
|
||||
|
||||
info, ok := infos[chanID]
|
||||
require.True(t, ok)
|
||||
require.True(t, info.MaxHeight.IsSome())
|
||||
info.MaxHeight.WhenSome(func(height uint64) {
|
||||
require.EqualValues(t, task6.CommitHeight, height)
|
||||
})
|
||||
}
|
||||
|
@ -12,6 +12,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
|
||||
"github.com/lightningnetwork/lnd/watchtower/wtdb/migration8"
|
||||
)
|
||||
|
||||
// txMigration is a function which takes a prior outdated version of the
|
||||
@ -67,6 +68,9 @@ var clientDBVersions = []version{
|
||||
{
|
||||
txMigration: migration7.MigrateChannelToSessionIndex,
|
||||
},
|
||||
{
|
||||
txMigration: migration8.MigrateChannelMaxHeights,
|
||||
},
|
||||
}
|
||||
|
||||
// getLatestDBVersion returns the last known database version.
|
||||
|
Loading…
x
Reference in New Issue
Block a user