Mirror of https://github.com/lightningnetwork/lnd.git
watchtower/wtdb: add migration code for AckedUpdates
In this commit, the code for migration 4 is added. This migration takes all the existing session acked updates and migrates them to be stored in the RangeIndex form instead. Note that this migration is not activated in this commit; that is done in a follow-up commit in order to keep this one smaller.
This commit is contained in:
parent 870a91a1e8
commit 50ad10666c
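
In other words: instead of one seqnum -> BackupID record per acked update, each session will store one start -> end pair per contiguous run of acked commit heights, keyed by the channel's db-assigned ID. A rough sketch of the shape change (bucket names taken from the code below):

	Before:  session-id => client-session-acks            => seqnum -> BackupID{chan-id, height}
	After:   session-id => client-session-ack-range-index => db-chan-id => start -> end

For example, acked heights {30, 31, 32, 34} for a single channel collapse to just the two pairs {30: 32, 34: 34}.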
watchtower/wtdb/migration4/client_db.go (new file, 628 lines)
@@ -0,0 +1,628 @@
package migration4

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"

	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/tlv"
)

var (
	// cChanDetailsBkt is a top-level bucket storing:
	//   channel-id => cChannelSummary -> encoded ClientChanSummary.
	//              => cChanDBID -> db-assigned-id
	cChanDetailsBkt = []byte("client-channel-detail-bucket")

	// cChanDBID is a key used in the cChanDetailsBkt to store the
	// db-assigned-id of a channel.
	cChanDBID = []byte("client-channel-db-id")

	// cSessionBkt is a top-level bucket storing:
	//   session-id => cSessionBody -> encoded ClientSessionBody
	//              => cSessionCommits => seqnum -> encoded CommittedUpdate
	//              => cSessionAcks => seqnum -> encoded BackupID
	cSessionBkt = []byte("client-session-bucket")

	// cSessionAcks is a sub-bucket of cSessionBkt storing:
	//   seqnum -> encoded BackupID.
	cSessionAcks = []byte("client-session-acks")

	// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing:
	//   chan-id => start -> end
	cSessionAckRangeIndex = []byte("client-session-ack-range-index")

	// ErrUninitializedDB signals that top-level buckets for the database
	// have not been initialized.
	ErrUninitializedDB = errors.New("db not initialized")

	// ErrClientSessionNotFound signals that the requested client session
	// was not found in the database.
	ErrClientSessionNotFound = errors.New("client session not found")

	// ErrCorruptChanDetails signals that the client's channel details'
	// on-disk structure deviates from what is expected.
	ErrCorruptChanDetails = errors.New("channel details corrupted")

	// ErrChannelNotRegistered signals a channel has not yet been registered
	// in the client database.
	ErrChannelNotRegistered = errors.New("channel not registered")

	// byteOrder is the default endianness used when serializing integers.
	byteOrder = binary.BigEndian

	// errExit is an error used to signal that the sessionIterator should
	// exit.
	errExit = errors.New("the exit condition has been met")
)

// DefaultSessionsPerTx is the default number of sessions that should be
// migrated per db transaction.
const DefaultSessionsPerTx = 5000

// MigrateAckedUpdates migrates the tower client DB. It takes the individual
// acked updates that are stored for each session and re-stores them using the
// RangeIndex representation.
func MigrateAckedUpdates(sessionsPerTx int) func(kvdb.Backend) error {
	return func(db kvdb.Backend) error {
		log.Infof("Migrating the tower client db to move all Acked " +
			"Updates to the new Range Index representation.")

		// Migrate the old acked-updates.
		err := migrateAckedUpdates(db, sessionsPerTx)
		if err != nil {
			return fmt.Errorf("migration failed: %w", err)
		}

		log.Infof("Migrating old session acked updates finished, now " +
			"checking the migration results...")

		// Before we can safely delete the old buckets, we perform a
		// check to make sure the sessions have been migrated as
		// expected.
		err = kvdb.View(db, validateMigration, func() {})
		if err != nil {
			return fmt.Errorf("validate migration failed: %w", err)
		}

		// Delete old acked updates.
		err = kvdb.Update(db, deleteOldAckedUpdates, func() {})
		if err != nil {
			return fmt.Errorf("failed to delete old acked "+
				"updates: %w", err)
		}

		return nil
	}
}
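
// Illustration (not part of this commit): a minimal, hypothetical way to
// invoke the returned migration function directly against an opened kvdb
// backend. The actual wiring into the wtdb migration list happens in the
// follow-up commit mentioned in the commit message.
//
//	migrate := MigrateAckedUpdates(DefaultSessionsPerTx)
//	if err := migrate(db); err != nil {
//		// Only fully committed batches persist; re-running the
//		// migration is safe since already-migrated sessions are
//		// skipped (see migrateSession and preMidStateDB in the test
//		// file below).
//	}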

// migrateAckedUpdates migrates the acked updates of each session in the
// wtclient db into the new RangeIndex form. This is done over multiple db
// transactions in order to prevent the migration from taking up too much RAM.
// The sessionsPerTx parameter can be used to set the maximum number of
// sessions that should be migrated per transaction.
func migrateAckedUpdates(db kvdb.Backend, sessionsPerTx int) error {
	// Get migration progress stats.
	total, migrated, err := logMigrationStats(db)
	if err != nil {
		return err
	}
	log.Infof("Total sessions=%d, migrated=%d", total, migrated)

	// Exit early if the old session acked updates have already been
	// migrated and deleted.
	if total == 0 {
		log.Info("Migration already finished!")
		return nil
	}

	var (
		finished bool
		startKey []byte
	)
	for {
		// Process the migration.
		err = kvdb.Update(db, func(tx kvdb.RwTx) error {
			startKey, finished, err = processMigration(
				tx, startKey, sessionsPerTx,
			)

			return err
		}, func() {})
		if err != nil {
			return err
		}

		if finished {
			break
		}

		// Each time the above process finishes, we read the stats
		// again to gauge the current progress.
		total, migrated, err = logMigrationStats(db)
		if err != nil {
			return err
		}

		// Calculate and log the progress if the progress is less than
		// one hundred percent.
		progress := float64(migrated) / float64(total) * 100 //nolint:gomnd
		if progress >= 100 { //nolint:gomnd
			break
		}

		log.Infof("Migration progress: %.3f%%, still have: %d",
			progress, total-migrated)
	}

	return nil
}

func validateMigration(tx kvdb.RTx) error {
	mainSessionsBkt := tx.ReadBucket(cSessionBkt)
	if mainSessionsBkt == nil {
		return ErrUninitializedDB
	}

	chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
	if chanDetailsBkt == nil {
		return ErrUninitializedDB
	}

	return mainSessionsBkt.ForEach(func(sessID, _ []byte) error {
		// Get the bucket for this particular session.
		sessionBkt := mainSessionsBkt.NestedReadBucket(sessID)
		if sessionBkt == nil {
			return ErrClientSessionNotFound
		}

		// Get the bucket where any old acked updates would be stored.
		oldAcksBucket := sessionBkt.NestedReadBucket(cSessionAcks)

		// Get the bucket where any new acked updates would be stored.
		newAcksBucket := sessionBkt.NestedReadBucket(
			cSessionAckRangeIndex,
		)

		switch {
		// If both the old and new acked updates buckets are nil, then
		// we can safely skip this session.
		case oldAcksBucket == nil && newAcksBucket == nil:
			return nil

		case oldAcksBucket == nil:
			return fmt.Errorf("no old acks but do have new acks")

		case newAcksBucket == nil:
			return fmt.Errorf("no new acks but have old acks")

		default:
		}

		// Collect acked ranges for this session.
		ackedRanges := make(map[uint64]*RangeIndex)
		err := newAcksBucket.ForEach(func(dbChanID, _ []byte) error {
			rangeIndexBkt := newAcksBucket.NestedReadBucket(
				dbChanID,
			)
			if rangeIndexBkt == nil {
				return fmt.Errorf("no acked updates bucket "+
					"found for channel %x", dbChanID)
			}

			// Read acked ranges from the new bucket.
			ri, err := readRangeIndex(rangeIndexBkt)
			if err != nil {
				return err
			}

			dbChanIDNum, err := readBigSize(dbChanID)
			if err != nil {
				return err
			}

			ackedRanges[dbChanIDNum] = ri

			return nil
		})
		if err != nil {
			return err
		}

		// Now we iterate through each of the old acked updates and
		// make sure that the update appears in the new bucket.
		return oldAcksBucket.ForEach(func(_, v []byte) error {
			var backupID BackupID
			err := backupID.Decode(bytes.NewReader(v))
			if err != nil {
				return err
			}

			dbChanID, _, err := getDBChanID(
				chanDetailsBkt, backupID.ChanID,
			)
			if err != nil {
				return err
			}

			index, ok := ackedRanges[dbChanID]
			if !ok {
				return fmt.Errorf("no index found for this " +
					"channel")
			}

			if !index.IsInIndex(backupID.CommitHeight) {
				return fmt.Errorf("commit height not found " +
					"in index")
			}

			return nil
		})
	})
}

func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) {
	ranges := make(map[uint64]uint64)
	err := rangesBkt.ForEach(func(k, v []byte) error {
		start, err := readBigSize(k)
		if err != nil {
			return err
		}

		end, err := readBigSize(v)
		if err != nil {
			return err
		}

		ranges[start] = end

		return nil
	})
	if err != nil {
		return nil, err
	}

	return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize))
}

func deleteOldAckedUpdates(tx kvdb.RwTx) error {
	mainSessionsBkt := tx.ReadWriteBucket(cSessionBkt)
	if mainSessionsBkt == nil {
		return ErrUninitializedDB
	}

	return mainSessionsBkt.ForEach(func(sessID, _ []byte) error {
		// Get the bucket for this particular session.
		sessionBkt := mainSessionsBkt.NestedReadWriteBucket(
			sessID,
		)
		if sessionBkt == nil {
			return ErrClientSessionNotFound
		}

		// Get the bucket where any old acked updates would be stored.
		oldAcksBucket := sessionBkt.NestedReadBucket(cSessionAcks)
		if oldAcksBucket == nil {
			return nil
		}

		// Now that we have read everything that we need to from
		// the cSessionAcks sub-bucket, we can delete it.
		return sessionBkt.DeleteNestedBucket(cSessionAcks)
	})
}

// processMigration uses the given transaction to perform a maximum of
// sessionsPerTx session migrations. If startKey is non-nil, it is used to
// determine the first session to start the migration at. The first return
// item is the key of the last session that was migrated successfully, and the
// boolean is true if there are no more sessions left to migrate.
func processMigration(tx kvdb.RwTx, startKey []byte, sessionsPerTx int) ([]byte,
	bool, error) {

	chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
	if chanDetailsBkt == nil {
		return nil, false, ErrUninitializedDB
	}

	// sessionCount keeps track of the number of sessions that have been
	// migrated under the current db transaction.
	var sessionCount int

	// migrateSessionCB is a callback function that calls migrateSession
	// in order to migrate a single session. Upon success, the sessionCount
	// is incremented and is then compared against sessionsPerTx to
	// determine if we should continue migrating more sessions in this db
	// transaction.
	migrateSessionCB := func(sessionBkt kvdb.RwBucket) error {
		err := migrateSession(chanDetailsBkt, sessionBkt)
		if err != nil {
			return err
		}

		sessionCount++

		// If we have migrated sessionsPerTx sessions in this tx, then
		// we return errExit in order to signal that this tx should be
		// committed and the migration should be continued in a new
		// transaction.
		if sessionCount >= sessionsPerTx {
			return errExit
		}

		return nil
	}

	// Starting at startKey, iterate over the sessions in the db and
	// migrate them until either all are migrated or until the errExit
	// signal is received.
	lastKey, err := sessionIterator(tx, startKey, migrateSessionCB)
	if err != nil && errors.Is(err, errExit) {
		return lastKey, false, nil
	} else if err != nil {
		return nil, false, err
	}

	// The migration is complete.
	return nil, true, nil
}

// migrateSession migrates a single session's acked-updates to the new
// RangeIndex form.
func migrateSession(chanDetailsBkt kvdb.RBucket,
	sessionBkt kvdb.RwBucket) error {

	// Get the existing cSessionAcks bucket. If there is no such bucket,
	// then there are no acked-updates to migrate for this session.
	sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
	if sessionAcks == nil {
		return nil
	}

	// If there is already a new cSessionAckRangeIndex bucket, then this
	// session has already been migrated.
	sessionAckRangesBkt := sessionBkt.NestedReadBucket(
		cSessionAckRangeIndex,
	)
	if sessionAckRangesBkt != nil {
		return nil
	}

	// Otherwise, we will iterate over each of the acked-updates, and we
	// will construct a new RangeIndex for each channel.
	m := make(map[ChannelID]*RangeIndex)
	if err := sessionAcks.ForEach(func(_, v []byte) error {
		var backupID BackupID
		err := backupID.Decode(bytes.NewReader(v))
		if err != nil {
			return err
		}

		if _, ok := m[backupID.ChanID]; !ok {
			index, err := NewRangeIndex(nil)
			if err != nil {
				return err
			}

			m[backupID.ChanID] = index
		}

		return m[backupID.ChanID].Add(backupID.CommitHeight, nil)
	}); err != nil {
		return err
	}

	// Create a new sub-bucket that will be used to store the new
	// RangeIndex representation of the acked updates.
	ackRangeBkt, err := sessionBkt.CreateBucket(cSessionAckRangeIndex)
	if err != nil {
		return err
	}

	// Iterate over each of the new range indexes that we will add for
	// this session.
	for chanID, rangeIndex := range m {
		// Get the db-assigned ID for this channel.
		chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:])
		if chanDetails == nil {
			return ErrCorruptChanDetails
		}

		// Create a sub-bucket for this channel using the db-assigned
		// ID for the channel.
		dbChanID := chanDetails.Get(cChanDBID)
		chanAcksBkt, err := ackRangeBkt.CreateBucket(dbChanID)
		if err != nil {
			return err
		}

		// Iterate over the range pairs that we need to add to the DB.
		for k, v := range rangeIndex.GetAllRanges() {
			start, err := writeBigSize(k)
			if err != nil {
				return err
			}

			end, err := writeBigSize(v)
			if err != nil {
				return err
			}

			err = chanAcksBkt.Put(start, end)
			if err != nil {
				return err
			}
		}
	}

	return nil
}

// logMigrationStats reads the buckets to provide stats over current migration
// progress. The returned values are the number of total records and the
// number of already-migrated records.
func logMigrationStats(db kvdb.Backend) (uint64, uint64, error) {
	var (
		err        error
		total      uint64
		unmigrated uint64
	)

	err = kvdb.View(db, func(tx kvdb.RTx) error {
		total, unmigrated, err = getMigrationStats(tx)

		return err
	}, func() {})

	log.Debugf("Total sessions=%d, unmigrated=%d", total, unmigrated)

	return total, total - unmigrated, err
}

// getMigrationStats iterates over all sessions. It counts the total number of
// sessions as well as the total number of unmigrated sessions.
func getMigrationStats(tx kvdb.RTx) (uint64, uint64, error) {
	var (
		total      uint64
		unmigrated uint64
	)

	// Get sessions bucket.
	mainSessionsBkt := tx.ReadBucket(cSessionBkt)
	if mainSessionsBkt == nil {
		return 0, 0, ErrUninitializedDB
	}

	// Iterate over each session ID in the bucket.
	err := mainSessionsBkt.ForEach(func(sessID, _ []byte) error {
		// Get the bucket for this particular session.
		sessionBkt := mainSessionsBkt.NestedReadBucket(sessID)
		if sessionBkt == nil {
			return ErrClientSessionNotFound
		}

		total++

		// Get the cSessionAcks bucket.
		sessionAcksBkt := sessionBkt.NestedReadBucket(cSessionAcks)

		// Get the cSessionAckRangeIndex bucket.
		sessionAckRangesBkt := sessionBkt.NestedReadBucket(
			cSessionAckRangeIndex,
		)

		// If both buckets do not exist, then this session is empty and
		// does not need to be migrated.
		if sessionAckRangesBkt == nil && sessionAcksBkt == nil {
			return nil
		}

		// If the sessionAckRangesBkt is not nil, then the session has
		// already been migrated.
		if sessionAckRangesBkt != nil {
			return nil
		}

		// Else the session has not yet been migrated.
		unmigrated++

		return nil
	})
	if err != nil {
		return 0, 0, err
	}

	return total, unmigrated, nil
}

// getDBChanID returns the db-assigned channel ID for the given real channel
// ID. It returns both the uint64 and byte representations.
func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID ChannelID) (uint64,
	[]byte, error) {

	chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:])
	if chanDetails == nil {
		return 0, nil, ErrChannelNotRegistered
	}

	idBytes := chanDetails.Get(cChanDBID)
	if len(idBytes) == 0 {
		return 0, nil, fmt.Errorf("no db-assigned ID found for "+
			"channel ID %s", chanID)
	}

	id, err := readBigSize(idBytes)
	if err != nil {
		return 0, nil, err
	}

	return id, idBytes, nil
}

// callback defines a type that's used by the sessionIterator.
type callback func(bkt kvdb.RwBucket) error

// sessionIterator is a helper function that iterates over the main sessions
// bucket and performs the callback function on each individual session. If a
// seeker is specified, it will move the cursor to the given position,
// otherwise it will start from the first item.
func sessionIterator(tx kvdb.RwTx, seeker []byte, cb callback) ([]byte, error) {
	// Get sessions bucket.
	mainSessionsBkt := tx.ReadWriteBucket(cSessionBkt)
	if mainSessionsBkt == nil {
		return nil, ErrUninitializedDB
	}

	c := mainSessionsBkt.ReadCursor()
	k, _ := c.First()

	// Move the cursor to the specified position if seeker is non-nil.
	if seeker != nil {
		k, _ = c.Seek(seeker)
	}

	// Start the iteration and exit on condition.
	for k := k; k != nil; k, _ = c.Next() {
		// Get the bucket for this particular session.
		bkt := mainSessionsBkt.NestedReadWriteBucket(k)
		if bkt == nil {
			return nil, ErrClientSessionNotFound
		}

		// Call the callback function with the session's bucket.
		if err := cb(bkt); err != nil {
			// Return a copy of the key, since the slice handed out
			// by the cursor is only guaranteed to remain valid
			// while the transaction is open.
			lastIndex := make([]byte, len(k))
			copy(lastIndex, k)

			return lastIndex, err
		}
	}

	return nil, nil
}

// writeBigSize will encode the given uint64 as a BigSize byte slice.
func writeBigSize(i uint64) ([]byte, error) {
	var b bytes.Buffer
	err := tlv.WriteVarInt(&b, i, &[8]byte{})
	if err != nil {
		return nil, err
	}

	return b.Bytes(), nil
}

// readBigSize converts the given byte slice into a uint64 and assumes that
// the byte slice is using BigSize encoding.
func readBigSize(b []byte) (uint64, error) {
	r := bytes.NewReader(b)
	i, err := tlv.ReadVarInt(r, &[8]byte{})
	if err != nil {
		return 0, err
	}

	return i, nil
}
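
A note on the on-disk key format (editor's summary, not part of the diff): tlv.WriteVarInt/ReadVarInt use the BOLT-1 BigSize encoding, so the start/end heights stored above are variable length:

	30    -> 0x1e                      (values below 0xfd fit in 1 byte)
	300   -> 0xfd 0x01 0x2c            (0xfd prefix + uint16, big-endian)
	70000 -> 0xfe 0x00 0x01 0x11 0x70  (0xfe prefix + uint32, big-endian)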
watchtower/wtdb/migration4/client_db_test.go (new file, 329 lines)
@@ -0,0 +1,329 @@
package migration4

import (
	"bytes"
	"testing"

	"github.com/btcsuite/btcwallet/walletdb"
	"github.com/lightningnetwork/lnd/channeldb/migtest"
	"github.com/lightningnetwork/lnd/kvdb"
)

var (
	// details is the expected data of the channel details bucket. This
	// bucket should not be changed during the migration, but it is used
	// to find the db-assigned ID for each channel.
	details = map[string]interface{}{
		channelIDString(1): map[string]interface{}{
			string(cChanDBID): uint64ToStr(10),
		},
		channelIDString(2): map[string]interface{}{
			string(cChanDBID): uint64ToStr(20),
		},
	}

	// preSessions is the expected data in the sessions bucket before the
	// migration.
	preSessions = map[string]interface{}{
		sessionIDString("1"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 30,
				}),
				"2": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 31,
				}),
				"3": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 32,
				}),
				"4": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 34,
				}),
				"5": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 30,
				}),
			},
		},
		sessionIDString("2"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 33,
				}),
			},
		},
		sessionIDString("3"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 35,
				}),
				"2": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 36,
				}),
				"3": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 28,
				}),
				"4": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 29,
				}),
			},
		},
	}

	// preMidStateDB is a possible state that the db could be in if the
	// migration started but was interrupted before completing. In this
	// state, some sessions still have the old cSessionAcks bucket along
	// with the new cSessionAckRangeIndex. This is a valid pre-state, and
	// a migration on this state should succeed.
	preMidStateDB = map[string]interface{}{
		sessionIDString("1"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 30,
				}),
				"2": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 31,
				}),
				"3": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 32,
				}),
				"4": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 34,
				}),
				"5": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 30,
				}),
			},
			string(cSessionAckRangeIndex): map[string]interface{}{
				uint64ToStr(10): map[string]interface{}{
					uint64ToStr(30): uint64ToStr(32),
					uint64ToStr(34): uint64ToStr(34),
				},
				uint64ToStr(20): map[string]interface{}{
					uint64ToStr(30): uint64ToStr(30),
				},
			},
		},
		sessionIDString("2"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 33,
				}),
			},
		},
		sessionIDString("3"): map[string]interface{}{
			string(cSessionAcks): map[string]interface{}{
				"1": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 35,
				}),
				"2": backupIDToString(&BackupID{
					ChanID:       intToChannelID(1),
					CommitHeight: 36,
				}),
				"3": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 28,
				}),
				"4": backupIDToString(&BackupID{
					ChanID:       intToChannelID(2),
					CommitHeight: 29,
				}),
			},
		},
	}

	// preFailCorruptDB should fail the migration due to no session data
	// being found for a given session ID.
	preFailCorruptDB = map[string]interface{}{
		sessionIDString("2"): "",
	}

	// postSessions is the expected data in the sessions bucket after the
	// migration.
	postSessions = map[string]interface{}{
		sessionIDString("1"): map[string]interface{}{
			string(cSessionAckRangeIndex): map[string]interface{}{
				uint64ToStr(10): map[string]interface{}{
					uint64ToStr(30): uint64ToStr(32),
					uint64ToStr(34): uint64ToStr(34),
				},
				uint64ToStr(20): map[string]interface{}{
					uint64ToStr(30): uint64ToStr(30),
				},
			},
		},
		sessionIDString("2"): map[string]interface{}{
			string(cSessionAckRangeIndex): map[string]interface{}{
				uint64ToStr(10): map[string]interface{}{
					uint64ToStr(33): uint64ToStr(33),
				},
			},
		},
		sessionIDString("3"): map[string]interface{}{
			string(cSessionAckRangeIndex): map[string]interface{}{
				uint64ToStr(10): map[string]interface{}{
					uint64ToStr(35): uint64ToStr(36),
				},
				uint64ToStr(20): map[string]interface{}{
					uint64ToStr(28): uint64ToStr(29),
				},
			},
		},
	}
)

// TestMigrateAckedUpdates tests that the MigrateAckedUpdates function
// correctly migrates the existing AckedUpdates bucket for each session to the
// new RangeIndex representation.
func TestMigrateAckedUpdates(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name       string
		shouldFail bool
		pre        map[string]interface{}
		post       map[string]interface{}
	}{
		{
			name:       "migration ok",
			shouldFail: false,
			pre:        preSessions,
			post:       postSessions,
		},
		{
			name:       "migration ok after re-starting",
			shouldFail: false,
			pre:        preMidStateDB,
			post:       postSessions,
		},
		{
			name:       "fail due to corrupt db",
			shouldFail: true,
			pre:        preFailCorruptDB,
		},
		{
			name:       "no session details",
			shouldFail: false,
			pre:        nil,
			post:       nil,
		},
	}

	for _, test := range tests {
		test := test

		t.Run(test.name, func(t *testing.T) {
			t.Parallel()

			before := before(test.pre)

			// After the migration, we should have an untouched
			// summary bucket and a new index bucket.
			after := after(test.shouldFail, test.pre, test.post)

			migtest.ApplyMigrationWithDb(
				t, before, after, MigrateAckedUpdates(2),
				test.shouldFail,
			)
		})
	}
}

// before returns a call-back function that can be used to set up a db's
// cChanDetailsBkt along with the cSessionBkt using the passed preMigDB
// structure.
func before(preMigDB map[string]interface{}) func(backend kvdb.Backend) error {
	return func(db kvdb.Backend) error {
		return db.Update(func(tx walletdb.ReadWriteTx) error {
			err := migtest.RestoreDB(
				tx, cChanDetailsBkt, details,
			)
			if err != nil {
				return err
			}

			return migtest.RestoreDB(
				tx, cSessionBkt, preMigDB,
			)
		}, func() {})
	}
}

// after returns a call-back function that can be used to verify the state of
// a db post migration.
func after(shouldFail bool, preMigDB,
	postMigDB map[string]interface{}) func(backend kvdb.Backend) error {

	return func(db kvdb.Backend) error {
		return db.Update(func(tx walletdb.ReadWriteTx) error {
			// The channel details bucket should remain untouched.
			err := migtest.VerifyDB(tx, cChanDetailsBkt, details)
			if err != nil {
				return err
			}

			// If the migration fails, the sessions bucket should
			// be untouched.
			if shouldFail {
				if err := migtest.VerifyDB(
					tx, cSessionBkt, preMigDB,
				); err != nil {
					return err
				}

				return nil
			}

			return migtest.VerifyDB(tx, cSessionBkt, postMigDB)
		}, func() {})
	}
}

func sessionIDString(id string) string {
	var sessID SessionID
	copy(sessID[:], id)
	return sessID.String()
}

func intToChannelID(id uint64) ChannelID {
	var chanID ChannelID
	byteOrder.PutUint64(chanID[:], id)
	return chanID
}

func channelIDString(id uint64) string {
	var chanID ChannelID
	byteOrder.PutUint64(chanID[:], id)
	return string(chanID[:])
}

func uint64ToStr(id uint64) string {
	b, err := writeBigSize(id)
	if err != nil {
		panic(err)
	}

	return string(b)
}

func backupIDToString(backup *BackupID) string {
	var b bytes.Buffer
	_ = backup.Encode(&b)
	return b.String()
}
watchtower/wtdb/migration4/codec.go (new file, 129 lines)
@@ -0,0 +1,129 @@
package migration4

import (
	"encoding/binary"
	"encoding/hex"
	"fmt"
	"io"
)

// SessionIDSize is 33 bytes; it is a serialized, compressed public key.
const SessionIDSize = 33

// SessionID is created from the remote public key of a client, and serves as
// a unique identifier and authentication for sending state updates.
type SessionID [SessionIDSize]byte

// String returns a hex encoding of the session id.
func (s SessionID) String() string {
	return hex.EncodeToString(s[:])
}

// ChannelID is a series of 32 bytes that uniquely identifies all channels
// within the network. The ChannelID is computed using the outpoint of the
// funding transaction (the txid, and output index). Given a funding output
// the ChannelID can be calculated by XOR'ing the big-endian serialization of
// the txid with the big-endian serialization of the output index, truncated
// to 2 bytes.
type ChannelID [32]byte

// String returns the string representation of the ChannelID. This is just
// the hex string encoding of the ChannelID itself.
func (c ChannelID) String() string {
	return hex.EncodeToString(c[:])
}

// BackupID identifies a particular revoked, remote commitment by channel id
// and commitment height.
type BackupID struct {
	// ChanID is the channel id of the revoked commitment.
	ChanID ChannelID

	// CommitHeight is the commitment height of the revoked commitment.
	CommitHeight uint64
}

// Encode writes the BackupID to the passed io.Writer.
func (b *BackupID) Encode(w io.Writer) error {
	return WriteElements(w,
		b.ChanID,
		b.CommitHeight,
	)
}

// Decode reads a BackupID from the passed io.Reader.
func (b *BackupID) Decode(r io.Reader) error {
	return ReadElements(r,
		&b.ChanID,
		&b.CommitHeight,
	)
}

// String returns a human-readable encoding of a BackupID.
func (b BackupID) String() string {
	return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight)
}

// WriteElements serializes a variadic list of elements into the given
// io.Writer.
func WriteElements(w io.Writer, elements ...interface{}) error {
	for _, element := range elements {
		if err := WriteElement(w, element); err != nil {
			return err
		}
	}

	return nil
}

// ReadElements deserializes the provided io.Reader into a variadic list of
// target elements.
func ReadElements(r io.Reader, elements ...interface{}) error {
	for _, element := range elements {
		if err := ReadElement(r, element); err != nil {
			return err
		}
	}

	return nil
}

// WriteElement serializes a single element into the provided io.Writer.
func WriteElement(w io.Writer, element interface{}) error {
	switch e := element.(type) {
	case ChannelID:
		if _, err := w.Write(e[:]); err != nil {
			return err
		}

	case uint64:
		if err := binary.Write(w, byteOrder, e); err != nil {
			return err
		}

	default:
		return fmt.Errorf("unexpected type")
	}

	return nil
}

// ReadElement deserializes a single element from the provided io.Reader.
func ReadElement(r io.Reader, element interface{}) error {
	switch e := element.(type) {
	case *ChannelID:
		if _, err := io.ReadFull(r, e[:]); err != nil {
			return err
		}

	case *uint64:
		if err := binary.Read(r, byteOrder, e); err != nil {
			return err
		}

	default:
		return fmt.Errorf("unexpected type")
	}

	return nil
}
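
As a quick illustration (not part of the diff), a BackupID round-trips through these helpers as a fixed 40-byte record: the 32-byte channel ID followed by the big-endian uint64 commit height:

	var b bytes.Buffer
	id := BackupID{ChanID: ChannelID{0x01}, CommitHeight: 30}
	_ = id.Encode(&b)  // b.Len() == 40
	var got BackupID
	_ = got.Decode(&b) // got == id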
watchtower/wtdb/migration4/log.go (new file, 14 lines)
@@ -0,0 +1,14 @@
package migration4

import (
	"github.com/btcsuite/btclog"
)

// log is a logger that is initialized as disabled. This means the package
// will not perform any logging by default until a logger is set.
var log = btclog.Disabled

// UseLogger uses a specified Logger to output package logging info.
func UseLogger(logger btclog.Logger) {
	log = logger
}
watchtower/wtdb/migration4/range_index.go (new file, 619 lines)
@@ -0,0 +1,619 @@
package migration4

import (
	"fmt"
	"sync"
)

// rangeItem represents the start and end values of a range.
type rangeItem struct {
	start uint64
	end   uint64
}

// RangeIndexOption describes the signature of a functional option that can be
// used to modify the behaviour of a RangeIndex.
type RangeIndexOption func(*RangeIndex)

// WithSerializeUint64Fn is a functional option that can be used to set the
// function to be used to do the serialization of a uint64 into a byte slice.
func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption {
	return func(index *RangeIndex) {
		index.serializeUint64 = fn
	}
}

// RangeIndex can be used to keep track of which numbers have been added to a
// set. It does so by keeping track of a sorted list of rangeItems. Each
// rangeItem has a start and end value of a range where all values in-between
// have been added to the set. It works well in situations where the numbers
// in the set are expected not to be sparse.
type RangeIndex struct {
	// set is a sorted list of rangeItem.
	set []rangeItem

	// mu is used to ensure safe access to set.
	mu sync.Mutex

	// serializeUint64 is the function that can be used to convert a
	// uint64 to a byte slice.
	serializeUint64 func(uint64) ([]byte, error)
}

// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be
// passed to the function in the form of a map.
func NewRangeIndex(ranges map[uint64]uint64,
	opts ...RangeIndexOption) (*RangeIndex, error) {

	index := &RangeIndex{
		serializeUint64: defaultSerializeUint64,
		set:             make([]rangeItem, 0),
	}

	// Apply any functional options.
	for _, o := range opts {
		o(index)
	}

	for s, e := range ranges {
		if err := index.addRange(s, e); err != nil {
			return nil, err
		}
	}

	return index, nil
}

// addRange can be used to add an entire new range to the set. This method
// should only ever be called by NewRangeIndex to initialise the in-memory
// structure, and so the RangeIndex mutex is not held during this method.
func (a *RangeIndex) addRange(start, end uint64) error {
	// Check that the given range is valid.
	if start > end {
		return fmt.Errorf("invalid range. Start height %d is larger "+
			"than end height %d", start, end)
	}

	// min is a helper closure that will return the minimum of two
	// uint64s.
	min := func(a, b uint64) uint64 {
		if a < b {
			return a
		}

		return b
	}

	// max is a helper closure that will return the maximum of two
	// uint64s.
	max := func(a, b uint64) uint64 {
		if a > b {
			return a
		}

		return b
	}

	// Collect the ranges that fall before and after the new range along
	// with the start and end values of the new range.
	var before, after []rangeItem
	for _, x := range a.set {
		// If the new start value can't extend the current range's end
		// value, then the two cannot be merged. The range is added to
		// the group of ranges that fall before the new range.
		if x.end+1 < start {
			before = append(before, x)
			continue
		}

		// If the current range's start value does not follow on
		// directly from the new end value, then the two cannot be
		// merged. The range is added to the group of ranges that fall
		// after the new range.
		if end+1 < x.start {
			after = append(after, x)
			continue
		}

		// Otherwise, there is an overlap and so the two can be
		// merged.
		start = min(start, x.start)
		end = max(end, x.end)
	}

	// Re-construct the range index set.
	a.set = append(append(before, rangeItem{
		start: start,
		end:   end,
	}), after...)

	return nil
}

// IsInIndex returns true if the given number is in the range set.
func (a *RangeIndex) IsInIndex(n uint64) bool {
	a.mu.Lock()
	defer a.mu.Unlock()

	_, isCovered := a.lowerBoundIndex(n)

	return isCovered
}

// NumInSet returns the number of items covered by the range set.
func (a *RangeIndex) NumInSet() uint64 {
	a.mu.Lock()
	defer a.mu.Unlock()

	var numItems uint64
	for _, r := range a.set {
		numItems += r.end - r.start + 1
	}

	return numItems
}

// MaxHeight returns the highest number covered in the range.
func (a *RangeIndex) MaxHeight() uint64 {
	a.mu.Lock()
	defer a.mu.Unlock()

	if len(a.set) == 0 {
		return 0
	}

	return a.set[len(a.set)-1].end
}

// GetAllRanges returns a copy of the range set in the form of a map.
func (a *RangeIndex) GetAllRanges() map[uint64]uint64 {
	a.mu.Lock()
	defer a.mu.Unlock()

	cp := make(map[uint64]uint64, len(a.set))
	for _, item := range a.set {
		cp[item.start] = item.end
	}

	return cp
}

// lowerBoundIndex returns the index of the rangeItem in the set whose start
// value is the highest start value that is still lower than or equal to the
// given number, n. The returned boolean is true if the given number is
// already covered in the RangeIndex. A returned index of -1 indicates that no
// lower bound range exists in the set. Since the most likely case is that the
// new number will just extend the highest range, a check is first done to see
// if this is the case, which makes the method's computational complexity
// O(1). Otherwise, a binary search is done, which brings the computational
// complexity to O(log N).
func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) {
	// If the set is empty, then there is no such index and the value
	// definitely is not in the set.
	if len(a.set) == 0 {
		return -1, false
	}

	// In most cases, the last index item will be the one we want. So just
	// do a quick check on that index first to avoid doing the binary
	// search.
	lastIndex := len(a.set) - 1
	lastRange := a.set[lastIndex]
	if lastRange.start <= n {
		return lastIndex, lastRange.end >= n
	}

	// Otherwise, do a binary search to find the index of interest.
	var (
		low        = 0
		high       = len(a.set) - 1
		rangeIndex = -1
	)
	for {
		mid := (low + high) / 2 //nolint: gomnd
		currentRange := a.set[mid]

		switch {
		case currentRange.start > n:
			// If the start of the range is greater than n, we can
			// completely cut out that entire part of the array.
			high = mid

		case currentRange.start < n:
			// If the range already includes the given height, we
			// can stop searching now.
			if currentRange.end >= n {
				return mid, true
			}

			// If the start of the range is smaller than n, we can
			// store this as the new best index to return.
			rangeIndex = mid

			// If low and mid are already equal, then increment
			// low by 1. Exit if this means that low is now
			// greater than high.
			if low == mid {
				low = mid + 1
				if low > high {
					return rangeIndex, false
				}
			} else {
				low = mid
			}

			continue

		default:
			// If the height is equal to the start value of the
			// current range that mid is pointing to, then the
			// height is already covered.
			return mid, true
		}

		// Exit if we have checked all the ranges.
		if low == high {
			break
		}
	}

	return rangeIndex, false
}
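
// Illustration (not part of this commit): with set = [{1,5}, {10,12}],
//
//	lowerBoundIndex(3)  == (0, true)   // covered by {1,5}
//	lowerBoundIndex(7)  == (0, false)  // {1,5} is the closest range below
//	lowerBoundIndex(11) == (1, true)   // covered by {10,12}
//	lowerBoundIndex(0)  == (-1, false) // no range starts at or below 0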

// KVStore is an interface representing a key-value store.
type KVStore interface {
	// Put saves the specified key/value pair to the store. Keys that do
	// not already exist are added and keys that already exist are
	// overwritten.
	Put(key, value []byte) error

	// Delete removes the specified key from the bucket. Deleting a key
	// that does not exist does not return an error.
	Delete(key []byte) error
}

// Add adds a single number to the range set. It first attempts to apply the
// necessary changes to the passed KV store, and only if this succeeds will
// the changes be applied to the in-memory structure.
func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	// Compute the changes that will need to be applied to both the
	// sorted rangeItem array representation and the key-value store
	// representation of the range index.
	arrayChanges, kvStoreChanges := a.getChanges(newHeight)

	// First attempt to apply the KV store changes. Only if this succeeds
	// will we apply the changes to our in-memory range index structure.
	err := a.applyKVChanges(kv, kvStoreChanges)
	if err != nil {
		return err
	}

	// Since the DB changes were successful, we can now commit the changes
	// to our in-memory representation of the range set.
	a.applyArrayChanges(arrayChanges)

	return nil
}

// applyKVChanges applies the given set of kvChanges to a KV store. It is
// assumed that a transaction is being held on the kv store so that if any of
// the actions of the function fails, the changes will be reverted.
func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error {
	// Exit early if there are no changes to apply.
	if kv == nil || changes == nil {
		return nil
	}

	// Check if any range pair needs to be deleted.
	if changes.deleteKVKey != nil {
		del, err := a.serializeUint64(*changes.deleteKVKey)
		if err != nil {
			return err
		}

		if err := kv.Delete(del); err != nil {
			return err
		}
	}

	start, err := a.serializeUint64(changes.key)
	if err != nil {
		return err
	}

	end, err := a.serializeUint64(changes.value)
	if err != nil {
		return err
	}

	return kv.Put(start, end)
}

// applyArrayChanges applies the given arrayChanges to the in-memory
// RangeIndex itself. This should only be done once the persisted kv store
// changes have already been applied.
func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) {
	if changes == nil {
		return
	}

	if changes.indexToDelete != nil {
		a.set = append(
			a.set[:*changes.indexToDelete],
			a.set[*changes.indexToDelete+1:]...,
		)
	}

	if changes.newIndex != nil {
		switch {
		case *changes.newIndex == 0:
			a.set = append([]rangeItem{{
				start: changes.start,
				end:   changes.end,
			}}, a.set...)

		case *changes.newIndex == len(a.set):
			a.set = append(a.set, rangeItem{
				start: changes.start,
				end:   changes.end,
			})

		default:
			a.set = append(
				a.set[:*changes.newIndex+1],
				a.set[*changes.newIndex:]...,
			)
			a.set[*changes.newIndex] = rangeItem{
				start: changes.start,
				end:   changes.end,
			}
		}

		return
	}

	if changes.indexToEdit != nil {
		a.set[*changes.indexToEdit] = rangeItem{
			start: changes.start,
			end:   changes.end,
		}
	}
}

// arrayChanges encompasses the diff to apply to the sorted rangeItem array
// representation of a range index. Such a diff will either include adding a
// new range or editing an existing range. If an existing range is edited,
// then the diff might also include deleting an index (this will be the case
// if editing one range results in it merging with another range).
type arrayChanges struct {
	start uint64
	end   uint64

	// newIndex, if set, is the index of the in-memory range array where
	// a new range, [start:end], should be added. newIndex should never
	// be set at the same time as indexToEdit or indexToDelete.
	newIndex *int

	// indexToDelete, if set, is the index of the sorted rangeItem array
	// that should be deleted. This should be applied before reading the
	// index value of indexToEdit. This should not be set at the same
	// time as newIndex.
	indexToDelete *int

	// indexToEdit is the index of the in-memory range array that should
	// be edited. The range at this index will be changed to [start:end].
	// This should only be read after the indexToDelete index has been
	// deleted.
	indexToEdit *int
}

// kvChanges encompasses the diff to apply to a KV-store representation of a
// range index. A kv-store diff for the addition of a single number to the
// range index will include either a brand new key-value pair or the altering
// of the value of an existing key. Optionally, the diff may also include the
// deletion of an existing key. A deletion will be required if the addition
// of the new number results in the merge of two ranges.
type kvChanges struct {
	key   uint64
	value uint64

	// deleteKVKey, if set, is the key of the kv store representation
	// that should be deleted.
	deleteKVKey *uint64
}

// getChanges will calculate and return the changes that need to be applied
// to both the sorted-rangeItem-array representation and the key-value store
// representation of the range index.
func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) {
	// If the set is empty then a new range item is added.
	if len(a.set) == 0 {
		// For the array representation, a new range [n:n] is added
		// to the first index of the array.
		firstIndex := 0
		ac := &arrayChanges{
			newIndex: &firstIndex,
			start:    n,
			end:      n,
		}

		// For the KV representation, a new [n:n] pair is added.
		kvc := &kvChanges{
			key:   n,
			value: n,
		}

		return ac, kvc
	}

	// Find the index of the lower bound range to the new number.
	indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n)

	switch {
	// The new number is already covered by the range index. No changes
	// are required.
	case alreadyCovered:
		return nil, nil

	// No lower bound index exists.
	case indexOfRangeBelow < 0:
		// Check if the very first range can be merged into this new
		// one.
		if n+1 == a.set[0].start {
			// If so, the two ranges can be merged and so the
			// start value of the range is n and the end value is
			// the end of the existing first range.
			start := n
			end := a.set[0].end

			// For the array representation, we can just edit the
			// first entry of the array.
			editIndex := 0
			ac := &arrayChanges{
				indexToEdit: &editIndex,
				start:       start,
				end:         end,
			}

			// For the KV store representation, we add a new kv
			// pair and delete the range with the key equal to
			// the start value of the range we are merging.
			kvKeyToDelete := a.set[0].start
			kvc := &kvChanges{
				key:         start,
				value:       end,
				deleteKVKey: &kvKeyToDelete,
			}

			return ac, kvc
		}

		// Otherwise, we add a new index.

		// For the array representation, a new range [n:n] is added
		// to the first index of the array.
		newIndex := 0
		ac := &arrayChanges{
			newIndex: &newIndex,
			start:    n,
			end:      n,
		}

		// For the KV representation, a new [n:n] pair is added.
		kvc := &kvChanges{
			key:   n,
			value: n,
		}

		return ac, kvc

	// A lower range does exist, and it can be extended to include this
	// new number.
	case a.set[indexOfRangeBelow].end+1 == n:
		start := a.set[indexOfRangeBelow].start
		end := n
		indexToChange := indexOfRangeBelow

		// If there are no intervals above this one, or if there are
		// but they can't be merged into this one, then we just need
		// to edit this interval.
		if indexOfRangeBelow == len(a.set)-1 ||
			a.set[indexOfRangeBelow+1].start != n+1 {

			// For the array representation, we just edit the
			// index.
			ac := &arrayChanges{
				indexToEdit: &indexToChange,
				start:       start,
				end:         end,
			}

			// For the key-value representation, we just overwrite
			// the end value at the existing start key.
			kvc := &kvChanges{
				key:   start,
				value: end,
			}

			return ac, kvc
		}

		// There is a range above this one that we need to merge into
		// this one.
		delIndex := indexOfRangeBelow + 1
		end = a.set[delIndex].end

		// For the array representation, we delete the range above
		// this one and edit this range to include the end value of
		// the range above.
		ac := &arrayChanges{
			indexToDelete: &delIndex,
			indexToEdit:   &indexToChange,
			start:         start,
			end:           end,
		}

		// For the kv representation, we tweak the end value of an
		// existing key and delete the key of the range we are
		// deleting.
		deleteKey := a.set[delIndex].start
		kvc := &kvChanges{
			key:         start,
			value:       end,
			deleteKVKey: &deleteKey,
		}

		return ac, kvc

	// A lower range does exist, but it can't be extended to include this
	// new number, and so we need to add a new range after the lower
	// bound range.
	default:
		newIndex := indexOfRangeBelow + 1

		// If there are no ranges above this new one, or if there are
		// but they can't be merged into this new one, then we can
		// just add the new one as is.
		if newIndex == len(a.set) || a.set[newIndex].start != n+1 {
			ac := &arrayChanges{
				newIndex: &newIndex,
				start:    n,
				end:      n,
			}

			kvc := &kvChanges{
				key:   n,
				value: n,
			}

			return ac, kvc
		}

		// Else, we merge the above index.
		start := n
		end := a.set[newIndex].end
		toEdit := newIndex

		// For the array representation, we edit the range above to
		// include the new start value.
		ac := &arrayChanges{
			indexToEdit: &toEdit,
			start:       start,
			end:         end,
		}

		// For the kv representation, we insert the new start-end key
		// value pair and delete the key using the old start value.
		delKey := a.set[newIndex].start
		kvc := &kvChanges{
			key:         start,
			value:       end,
			deleteKVKey: &delKey,
		}

		return ac, kvc
	}
}
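
// Illustration (not part of this commit): with set = [{1,5}, {7,9}],
// getChanges(6) hits the merge case. The lower bound range {1,5} can be
// extended by 6, and the range above starts at 7 == 6+1, so the two fuse:
//
//	arrayChanges: indexToEdit=0, indexToDelete=1, start=1, end=9
//	kvChanges:    put 1 -> 9, delete key 7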

func defaultSerializeUint64(i uint64) ([]byte, error) {
	var b [8]byte
	byteOrder.PutUint64(b[:], i)
	return b[:], nil
}
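
Finally, a minimal sketch (not part of the diff) of the RangeIndex/KVStore interplay, using a toy map-backed store. The heights mirror session 1 in the tests above, and the snippet assumes it lives in (or imports) this migration4 package:

	// memKV is a toy KVStore backed by a map, for illustration only.
	type memKV struct{ m map[string]string }

	func (s *memKV) Put(k, v []byte) error { s.m[string(k)] = string(v); return nil }
	func (s *memKV) Delete(k []byte) error { delete(s.m, string(k)); return nil }

	func example() error {
		kv := &memKV{m: make(map[string]string)}

		index, err := NewRangeIndex(nil, WithSerializeUint64Fn(writeBigSize))
		if err != nil {
			return err
		}

		// Add each acked height; Add persists to kv first and only
		// then updates the in-memory set.
		for _, h := range []uint64{30, 31, 32, 34} {
			if err := index.Add(h, kv); err != nil {
				return err
			}
		}

		// index.GetAllRanges() == map[uint64]uint64{30: 32, 34: 34},
		// and kv holds the same two pairs in BigSize-encoded form.
		return nil
	}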