diff --git a/watchtower/wtdb/migration4/client_db.go b/watchtower/wtdb/migration4/client_db.go new file mode 100644 index 000000000..0dcb499c8 --- /dev/null +++ b/watchtower/wtdb/migration4/client_db.go @@ -0,0 +1,628 @@ +package migration4 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + // cChanDBID is a key used in the cChanDetailsBkt to store the + // db-assigned-id of a channel. + cChanDBID = []byte("client-channel-db-id") + + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionAcks is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded BackupID. + cSessionAcks = []byte("client-session-acks") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing: + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrCorruptChanDetails signals that the clients channel detail's + // on-disk structure deviates from what is expected. + ErrCorruptChanDetails = errors.New("channel details corrupted") + + // ErrChannelNotRegistered signals a channel has not yet been registered + // in the client database. + ErrChannelNotRegistered = errors.New("channel not registered") + + // byteOrder is the default endianness used when serializing integers. + byteOrder = binary.BigEndian + + // errExit is an error used to signal that the sessionIterator should + // exit. + errExit = errors.New("the exit condition has been met") +) + +// DefaultSessionsPerTx is the default number of sessions that should be +// migrated per db transaction. +const DefaultSessionsPerTx = 5000 + +// MigrateAckedUpdates migrates the tower client DB. It takes the individual +// Acked Updates that are stored for each session and re-stores them using the +// RangeIndex representation. +func MigrateAckedUpdates(sessionsPerTx int) func(kvdb.Backend) error { + return func(db kvdb.Backend) error { + log.Infof("Migrating the tower client db to move all Acked " + + "Updates to the new Range Index representation.") + + // Migrate the old acked-updates. + err := migrateAckedUpdates(db, sessionsPerTx) + if err != nil { + return fmt.Errorf("migration failed: %w", err) + } + + log.Infof("Migrating old session acked updates finished, now " + + "checking the migration results...") + + // Before we can safety delete the old buckets, we perform a + // check to make sure the sessions have been migrated as + // expected. + err = kvdb.View(db, validateMigration, func() {}) + if err != nil { + return fmt.Errorf("validate migration failed: %w", err) + } + + // Delete old acked updates. + err = kvdb.Update(db, deleteOldAckedUpdates, func() {}) + if err != nil { + return fmt.Errorf("failed to delete old acked "+ + "updates: %w", err) + } + + return nil + } +} + +// migrateAckedUpdates migrates the acked updates of each session in the +// wtclient db into the new RangeIndex form. This is done over multiple db +// transactions in order to prevent the migration from taking up too much RAM. +// The sessionsPerTx parameter can be used to set the maximum number of sessions +// that should be migrated per transaction. +func migrateAckedUpdates(db kvdb.Backend, sessionsPerTx int) error { + // Get migration progress stats. + total, migrated, err := logMigrationStats(db) + if err != nil { + return err + } + log.Infof("Total sessions=%d, migrated=%d", total, migrated) + + // Exit early if the old session acked updates have already been + // migrated and deleted. + if total == 0 { + log.Info("Migration already finished!") + return nil + } + + var ( + finished bool + startKey []byte + ) + for { + // Process the migration. + err = kvdb.Update(db, func(tx kvdb.RwTx) error { + startKey, finished, err = processMigration( + tx, startKey, sessionsPerTx, + ) + + return err + }, func() {}) + if err != nil { + return err + } + + if finished { + break + } + + // Each time we finished the above process, we'd read the stats + // again to understand the current progress. + total, migrated, err = logMigrationStats(db) + if err != nil { + return err + } + + // Calculate and log the progress if the progress is less than + // one hundred percent. + //nolint:gomnd + progress := float64(migrated) / float64(total) * 100 + if progress >= 100 { //nolint:gomnd + break + } + + log.Infof("Migration progress: %.3f%%, still have: %d", + progress, total-migrated) + } + + return nil +} + +func validateMigration(tx kvdb.RTx) error { + mainSessionsBkt := tx.ReadBucket(cSessionBkt) + if mainSessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + return mainSessionsBkt.ForEach(func(sessID, _ []byte) error { + // Get the bucket for this particular session. + sessionBkt := mainSessionsBkt.NestedReadBucket(sessID) + if sessionBkt == nil { + return ErrClientSessionNotFound + } + + // Get the bucket where any old acked updates would be stored. + oldAcksBucket := sessionBkt.NestedReadBucket(cSessionAcks) + + // Get the bucket where any new acked updates would be stored. + newAcksBucket := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + + switch { + // If both the old and new acked updates buckets are nil, then + // we can safely skip this session. + case oldAcksBucket == nil && newAcksBucket == nil: + return nil + + case oldAcksBucket == nil: + return fmt.Errorf("no old acks but do have new acks") + + case newAcksBucket == nil: + return fmt.Errorf("no new acks but have old acks") + + default: + } + + // Collect acked ranges for this session. + ackedRanges := make(map[uint64]*RangeIndex) + err := newAcksBucket.ForEach(func(dbChanID, _ []byte) error { + rangeIndexBkt := newAcksBucket.NestedReadBucket( + dbChanID, + ) + if rangeIndexBkt == nil { + return fmt.Errorf("no acked updates bucket "+ + "found for channel %x", dbChanID) + } + + // Read acked ranges from new bucket. + ri, err := readRangeIndex(rangeIndexBkt) + if err != nil { + return err + } + + dbChanIDNum, err := readBigSize(dbChanID) + if err != nil { + return err + } + + ackedRanges[dbChanIDNum] = ri + + return nil + }) + if err != nil { + return err + } + + // Now we will iterate through each of the old acked updates and + // make sure that the update appears in the new bucket. + return oldAcksBucket.ForEach(func(_, v []byte) error { + var backupID BackupID + err := backupID.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + dbChanID, _, err := getDBChanID( + chanDetailsBkt, backupID.ChanID, + ) + if err != nil { + return err + } + + index, ok := ackedRanges[dbChanID] + if !ok { + return fmt.Errorf("no index found for this " + + "channel") + } + + if !index.IsInIndex(backupID.CommitHeight) { + return fmt.Errorf("commit height not found " + + "in index") + } + + return nil + }) + }) +} + +func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) { + ranges := make(map[uint64]uint64) + err := rangesBkt.ForEach(func(k, v []byte) error { + start, err := readBigSize(k) + if err != nil { + return err + } + + end, err := readBigSize(v) + if err != nil { + return err + } + + ranges[start] = end + + return nil + }) + if err != nil { + return nil, err + } + + return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize)) +} + +func deleteOldAckedUpdates(tx kvdb.RwTx) error { + mainSessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if mainSessionsBkt == nil { + return ErrUninitializedDB + } + + return mainSessionsBkt.ForEach(func(sessID, _ []byte) error { + // Get the bucket for this particular session. + sessionBkt := mainSessionsBkt.NestedReadWriteBucket( + sessID, + ) + if sessionBkt == nil { + return ErrClientSessionNotFound + } + + // Get the bucket where any old acked updates would be stored. + oldAcksBucket := sessionBkt.NestedReadBucket(cSessionAcks) + if oldAcksBucket == nil { + return nil + } + + // Now that we have read everything that we need to from + // the cSessionAcks sub-bucket, we can delete it. + return sessionBkt.DeleteNestedBucket(cSessionAcks) + }) +} + +// processMigration uses the given transaction to perform a maximum of +// sessionsPerTx session migrations. If startKey is non-nil, it is used to +// determine the first session to start the migration at. The first return +// item is the key of the last session that was migrated successfully and the +// boolean is true if there are no more sessions left to migrate. +func processMigration(tx kvdb.RwTx, startKey []byte, sessionsPerTx int) ([]byte, + bool, error) { + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return nil, false, ErrUninitializedDB + } + + // sessionCount keeps track of the number of sessions that have been + // migrated under the current db transaction. + var sessionCount int + + // migrateSessionCB is a callback function that calls migrateSession + // in order to migrate a single session. Upon success, the sessionCount + // is incremented and is then compared against sessionsPerTx to + // determine if we should continue migrating more sessions in this db + // transaction. + migrateSessionCB := func(sessionBkt kvdb.RwBucket) error { + err := migrateSession(chanDetailsBkt, sessionBkt) + if err != nil { + return err + } + + sessionCount++ + + // If we have migrated sessionsPerTx sessions in this tx, then + // we return errExit in order to signal that this tx should be + // committed and the migration should be continued in a new + // transaction. + if sessionCount >= sessionsPerTx { + return errExit + } + + return nil + } + + // Starting at startKey, iterate over the sessions in the db and migrate + // them until either all are migrated or until the errExit signal is + // received. + lastKey, err := sessionIterator(tx, startKey, migrateSessionCB) + if err != nil && errors.Is(err, errExit) { + return lastKey, false, nil + } else if err != nil { + return nil, false, err + } + + // The migration is complete. + return nil, true, nil +} + +// migrateSession migrates a single session's acked-updates to the new +// RangeIndex form. +func migrateSession(chanDetailsBkt kvdb.RBucket, + sessionBkt kvdb.RwBucket) error { + + // Get the existing cSessionAcks bucket. If there is no such bucket, + // then there are no acked-updates to migrate for this session. + sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) + if sessionAcks == nil { + return nil + } + + // If there is already a new cSessionAckedRangeIndex bucket, then this + // session has already been migrated. + sessionAckRangesBkt := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if sessionAckRangesBkt != nil { + return nil + } + + // Otherwise, we will iterate over each of the acked-updates, and we + // will construct a new RangeIndex for each channel. + m := make(map[ChannelID]*RangeIndex) + if err := sessionAcks.ForEach(func(_, v []byte) error { + var backupID BackupID + err := backupID.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + if _, ok := m[backupID.ChanID]; !ok { + index, err := NewRangeIndex(nil) + if err != nil { + return err + } + + m[backupID.ChanID] = index + } + + return m[backupID.ChanID].Add(backupID.CommitHeight, nil) + }); err != nil { + return err + } + + // Create a new sub-bucket that will be used to store the new RangeIndex + // representation of the acked updates. + ackRangeBkt, err := sessionBkt.CreateBucket(cSessionAckRangeIndex) + if err != nil { + return err + } + + // Iterate over each of the new range indexes that we will add for this + // session. + for chanID, rangeIndex := range m { + // Get db chanID. + chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:]) + if chanDetails == nil { + return ErrCorruptChanDetails + } + + // Create a sub-bucket for this channel using the db-assigned ID + // for the channel. + dbChanID := chanDetails.Get(cChanDBID) + chanAcksBkt, err := ackRangeBkt.CreateBucket(dbChanID) + if err != nil { + return err + } + + // Iterate over the range pairs that we need to add to the DB. + for k, v := range rangeIndex.GetAllRanges() { + start, err := writeBigSize(k) + if err != nil { + return err + } + + end, err := writeBigSize(v) + if err != nil { + return err + } + + err = chanAcksBkt.Put(start, end) + if err != nil { + return err + } + } + } + + return nil +} + +// logMigrationStats reads the buckets to provide stats over current migration +// progress. The returned values are the numbers of total records and already +// migrated records. +func logMigrationStats(db kvdb.Backend) (uint64, uint64, error) { + var ( + err error + total uint64 + unmigrated uint64 + ) + + err = kvdb.View(db, func(tx kvdb.RTx) error { + total, unmigrated, err = getMigrationStats(tx) + + return err + }, func() {}) + + log.Debugf("Total sessions=%d, unmigrated=%d", total, unmigrated) + + return total, total - unmigrated, err +} + +// getMigrationStats iterates over all sessions. It counts the total number of +// sessions as well as the total number of unmigrated sessions. +func getMigrationStats(tx kvdb.RTx) (uint64, uint64, error) { + var ( + total uint64 + unmigrated uint64 + ) + + // Get sessions bucket. + mainSessionsBkt := tx.ReadBucket(cSessionBkt) + if mainSessionsBkt == nil { + return 0, 0, ErrUninitializedDB + } + + // Iterate over each session ID in the bucket. + err := mainSessionsBkt.ForEach(func(sessID, _ []byte) error { + // Get the bucket for this particular session. + sessionBkt := mainSessionsBkt.NestedReadBucket(sessID) + if sessionBkt == nil { + return ErrClientSessionNotFound + } + + total++ + + // Get the cSessionAckRangeIndex bucket. + sessionAcksBkt := sessionBkt.NestedReadBucket(cSessionAcks) + + // Get the cSessionAckRangeIndex bucket. + sessionAckRangesBkt := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + + // If both buckets do not exist, then this session is empty and + // does not need to be migrated. + if sessionAckRangesBkt == nil && sessionAcksBkt == nil { + return nil + } + + // If the sessionAckRangesBkt is not nil, then the session has + // already been migrated. + if sessionAckRangesBkt != nil { + return nil + } + + // Else the session has not yet been migrated. + unmigrated++ + + return nil + }) + if err != nil { + return 0, 0, err + } + + return total, unmigrated, nil +} + +// getDBChanID returns the db-assigned channel ID for the given real channel ID. +// It returns both the uint64 and byte representation. +func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID ChannelID) (uint64, + []byte, error) { + + chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:]) + if chanDetails == nil { + return 0, nil, ErrChannelNotRegistered + } + + idBytes := chanDetails.Get(cChanDBID) + if len(idBytes) == 0 { + return 0, nil, fmt.Errorf("no db-assigned ID found for "+ + "channel ID %s", chanID) + } + + id, err := readBigSize(idBytes) + if err != nil { + return 0, nil, err + } + + return id, idBytes, nil +} + +// callback defines a type that's used by the sessionIterator. +type callback func(bkt kvdb.RwBucket) error + +// sessionIterator is a helper function that iterates over the main sessions +// bucket and performs the callback function on each individual session. If a +// seeker is specified, it will move the cursor to the given position otherwise +// it will start from the first item. +func sessionIterator(tx kvdb.RwTx, seeker []byte, cb callback) ([]byte, error) { + // Get sessions bucket. + mainSessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if mainSessionsBkt == nil { + return nil, ErrUninitializedDB + } + + c := mainSessionsBkt.ReadCursor() + k, _ := c.First() + + // Move the cursor to the specified position if seeker is non-nil. + if seeker != nil { + k, _ = c.Seek(seeker) + } + + // Start the iteration and exit on condition. + for k := k; k != nil; k, _ = c.Next() { + // Get the bucket for this particular session. + bkt := mainSessionsBkt.NestedReadWriteBucket(k) + if bkt == nil { + return nil, ErrClientSessionNotFound + } + + // Call the callback function with the session's bucket. + if err := cb(bkt); err != nil { + // return k, err + lastIndex := make([]byte, len(k)) + copy(lastIndex, k) + return lastIndex, err + } + } + + return nil, nil +} + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// readBigSize converts the given byte slice into a uint64 and assumes that the +// bytes slice is using BigSize encoding. +func readBigSize(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} diff --git a/watchtower/wtdb/migration4/client_db_test.go b/watchtower/wtdb/migration4/client_db_test.go new file mode 100644 index 000000000..2cdb7c4e9 --- /dev/null +++ b/watchtower/wtdb/migration4/client_db_test.go @@ -0,0 +1,329 @@ +package migration4 + +import ( + "bytes" + "testing" + + "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // details is the expected data of the channel details bucket. This + // bucket should not be changed during the migration, but it is used + // to find the db-assigned ID for each channel. + details = map[string]interface{}{ + channelIDString(1): map[string]interface{}{ + string(cChanDBID): uint64ToStr(10), + }, + channelIDString(2): map[string]interface{}{ + string(cChanDBID): uint64ToStr(20), + }, + } + + // preSessions is the expected data in the sessions bucket before the + // migration. + preSessions = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 30, + }), + "2": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 31, + }), + "3": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 32, + }), + "4": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 34, + }), + "5": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 30, + }), + }, + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 33, + }), + }, + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 35, + }), + "2": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 36, + }), + "3": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 28, + }), + "4": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 29, + }), + }, + }, + } + + // preMidStateDB is a possible state that the db could be in if the + // migration started but was interrupted before completing. In this + // state, some sessions still have the old cSessionAcks bucket along + // with the new cSessionAckRangeIndex. This is a valid pre-state and + // a migration on this state should succeed. + preMidStateDB = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 30, + }), + "2": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 31, + }), + "3": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 32, + }), + "4": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 34, + }), + "5": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 30, + }), + }, + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(32), + uint64ToStr(34): uint64ToStr(34), + }, + uint64ToStr(20): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(30), + }, + }, + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 33, + }), + }, + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionAcks): map[string]interface{}{ + "1": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 35, + }), + "2": backupIDToString(&BackupID{ + ChanID: intToChannelID(1), + CommitHeight: 36, + }), + "3": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 28, + }), + "4": backupIDToString(&BackupID{ + ChanID: intToChannelID(2), + CommitHeight: 29, + }), + }, + }, + } + + // preFailCorruptDB should fail the migration due to no session data + // being found for a given session ID. + preFailCorruptDB = map[string]interface{}{ + sessionIDString("2"): "", + } + + // postSessions is the expected data in the sessions bucket after the + // migration. + postSessions = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(32), + uint64ToStr(34): uint64ToStr(34), + }, + uint64ToStr(20): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(30), + }, + }, + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(33): uint64ToStr(33), + }, + }, + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(35): uint64ToStr(36), + }, + uint64ToStr(20): map[string]interface{}{ + uint64ToStr(28): uint64ToStr(29), + }, + }, + }, + } +) + +// TestMigrateAckedUpdates tests that the MigrateAckedUpdates function correctly +// migrates the existing AckedUpdates bucket for each session to the new +// RangeIndex representation. +func TestMigrateAckedUpdates(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: preSessions, + post: postSessions, + }, + { + name: "migration ok after re-starting", + shouldFail: false, + pre: preMidStateDB, + post: postSessions, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailCorruptDB, + }, + { + name: "no sessions details", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + before := before(test.pre) + + // After the migration, we should have an untouched + // summary bucket and a new index bucket. + after := after(test.shouldFail, test.pre, test.post) + + migtest.ApplyMigrationWithDb( + t, before, after, MigrateAckedUpdates(2), + test.shouldFail, + ) + }) + } +} + +// before returns a call-back function that can be used to set up a db's +// cChanDetailsBkt along with the cSessionBkt using the passed preMigDB +// structure. +func before(preMigDB map[string]interface{}) func(backend kvdb.Backend) error { + return func(db kvdb.Backend) error { + return db.Update(func(tx walletdb.ReadWriteTx) error { + err := migtest.RestoreDB( + tx, cChanDetailsBkt, details, + ) + if err != nil { + return err + } + + return migtest.RestoreDB( + tx, cSessionBkt, preMigDB, + ) + }, func() {}) + } +} + +// after returns a call-back function that can be used to verify the state of +// a db post migration. +func after(shouldFail bool, preMigDB, + postMigDB map[string]interface{}) func(backend kvdb.Backend) error { + + return func(db kvdb.Backend) error { + return db.Update(func(tx walletdb.ReadWriteTx) error { + // The channel details bucket should remain untouched. + err := migtest.VerifyDB(tx, cChanDetailsBkt, details) + if err != nil { + return err + } + + // If the migration fails, the sessions bucket should be + // untouched. + if shouldFail { + if err := migtest.VerifyDB( + tx, cSessionBkt, preMigDB, + ); err != nil { + return err + } + + return nil + } + + return migtest.VerifyDB(tx, cSessionBkt, postMigDB) + }, func() {}) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func intToChannelID(id uint64) ChannelID { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return chanID +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeBigSize(id) + if err != nil { + panic(err) + } + + return string(b) +} + +func backupIDToString(backup *BackupID) string { + var b bytes.Buffer + _ = backup.Encode(&b) + return b.String() +} diff --git a/watchtower/wtdb/migration4/codec.go b/watchtower/wtdb/migration4/codec.go new file mode 100644 index 000000000..9205f6528 --- /dev/null +++ b/watchtower/wtdb/migration4/codec.go @@ -0,0 +1,129 @@ +package migration4 + +import ( + "encoding/binary" + "encoding/hex" + "fmt" + "io" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// String returns the string representation of the ChannelID. This is just the +// hex string encoding of the ChannelID itself. +func (c ChannelID) String() string { + return hex.EncodeToString(c[:]) +} + +// BackupID identifies a particular revoked, remote commitment by channel id and +// commitment height. +type BackupID struct { + // ChanID is the channel id of the revoked commitment. + ChanID ChannelID + + // CommitHeight is the commitment height of the revoked commitment. + CommitHeight uint64 +} + +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + +// String returns a human-readable encoding of a BackupID. +func (b BackupID) String() string { + return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight) +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} diff --git a/watchtower/wtdb/migration4/log.go b/watchtower/wtdb/migration4/log.go new file mode 100644 index 000000000..3a609d76d --- /dev/null +++ b/watchtower/wtdb/migration4/log.go @@ -0,0 +1,14 @@ +package migration4 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/migration4/range_index.go b/watchtower/wtdb/migration4/range_index.go new file mode 100644 index 000000000..79793ce24 --- /dev/null +++ b/watchtower/wtdb/migration4/range_index.go @@ -0,0 +1,619 @@ +package migration4 + +import ( + "fmt" + "sync" +) + +// rangeItem represents the start and end values of a range. +type rangeItem struct { + start uint64 + end uint64 +} + +// RangeIndexOption describes the signature of a functional option that can be +// used to modify the behaviour of a RangeIndex. +type RangeIndexOption func(*RangeIndex) + +// WithSerializeUint64Fn is a functional option that can be used to set the +// function to be used to do the serialization of a uint64 into a byte slice. +func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption { + return func(index *RangeIndex) { + index.serializeUint64 = fn + } +} + +// RangeIndex can be used to keep track of which numbers have been added to a +// set. It does so by keeping track of a sorted list of rangeItems. Each +// rangeItem has a start and end value of a range where all values in-between +// have been added to the set. It works well in situations where it is expected +// numbers in the set are not sparse. +type RangeIndex struct { + // set is a sorted list of rangeItem. + set []rangeItem + + // mu is used to ensure safe access to set. + mu sync.Mutex + + // serializeUint64 is the function that can be used to convert a uint64 + // to a byte slice. + serializeUint64 func(uint64) ([]byte, error) +} + +// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be +// passed to the function in the form of a map. +func NewRangeIndex(ranges map[uint64]uint64, + opts ...RangeIndexOption) (*RangeIndex, error) { + + index := &RangeIndex{ + serializeUint64: defaultSerializeUint64, + set: make([]rangeItem, 0), + } + + // Apply any functional options. + for _, o := range opts { + o(index) + } + + for s, e := range ranges { + if err := index.addRange(s, e); err != nil { + return nil, err + } + } + + return index, nil +} + +// addRange can be used to add an entire new range to the set. This method +// should only ever be called by NewRangeIndex to initialise the in-memory +// structure and so the RangeIndex mutex is not held during this method. +func (a *RangeIndex) addRange(start, end uint64) error { + // Check that the given range is valid. + if start > end { + return fmt.Errorf("invalid range. Start height %d is larger "+ + "than end height %d", start, end) + } + + // min is a helper closure that will return the minimum of two uint64s. + min := func(a, b uint64) uint64 { + if a < b { + return a + } + + return b + } + + // max is a helper closure that will return the maximum of two uint64s. + max := func(a, b uint64) uint64 { + if a > b { + return a + } + + return b + } + + // Collect the ranges that fall before and after the new range along + // with the start and end values of the new range. + var before, after []rangeItem + for _, x := range a.set { + // If the new start value can't extend the current ranges end + // value, then the two cannot be merged. The range is added to + // the group of ranges that fall before the new range. + if x.end+1 < start { + before = append(before, x) + continue + } + + // If the current ranges start value does not follow on directly + // from the new end value, then the two cannot be merged. The + // range is added to the group of ranges that fall after the new + // range. + if end+1 < x.start { + after = append(after, x) + continue + } + + // Otherwise, there is an overlap and so the two can be merged. + start = min(start, x.start) + end = max(end, x.end) + } + + // Re-construct the range index set. + a.set = append(append(before, rangeItem{ + start: start, + end: end, + }), after...) + + return nil +} + +// IsInIndex returns true if the given number is in the range set. +func (a *RangeIndex) IsInIndex(n uint64) bool { + a.mu.Lock() + defer a.mu.Unlock() + + _, isCovered := a.lowerBoundIndex(n) + + return isCovered +} + +// NumInSet returns the number of items covered by the range set. +func (a *RangeIndex) NumInSet() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + var numItems uint64 + for _, r := range a.set { + numItems += r.end - r.start + 1 + } + + return numItems +} + +// MaxHeight returns the highest number covered in the range. +func (a *RangeIndex) MaxHeight() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + if len(a.set) == 0 { + return 0 + } + + return a.set[len(a.set)-1].end +} + +// GetAllRanges returns a copy of the range set in the form of a map. +func (a *RangeIndex) GetAllRanges() map[uint64]uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + cp := make(map[uint64]uint64, len(a.set)) + for _, item := range a.set { + cp[item.start] = item.end + } + + return cp +} + +// lowerBoundIndex returns the index of the RangeIndex that is most appropriate +// for the new value, n. In other words, it returns the index of the rangeItem +// set of the range where the start value is the highest start value in the set +// that is still lower than or equal to the given number, n. The returned +// boolean is true if the given number is already covered in the RangeIndex. +// A returned index of -1 indicates that no lower bound range exists in the set. +// Since the most likely case is that the new number will just extend the +// highest range, a check is first done to see if this is the case which will +// make the methods' computational complexity O(1). Otherwise, a binary search +// is done which brings the computational complexity to O(log N). +func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) { + // If the set is empty, then there is no such index and the value + // definitely is not in the set. + if len(a.set) == 0 { + return -1, false + } + + // In most cases, the last index item will be the one we want. So just + // do a quick check on that index first to avoid doing the binary + // search. + lastIndex := len(a.set) - 1 + lastRange := a.set[lastIndex] + if lastRange.start <= n { + return lastIndex, lastRange.end >= n + } + + // Otherwise, do a binary search to find the index of interest. + var ( + low = 0 + high = len(a.set) - 1 + rangeIndex = -1 + ) + for { + mid := (low + high) / 2 //nolint: gomnd + currentRange := a.set[mid] + + switch { + case currentRange.start > n: + // If the start of the range is greater than n, we can + // completely cut out that entire part of the array. + high = mid + + case currentRange.start < n: + // If the range already includes the given height, we + // can stop searching now. + if currentRange.end >= n { + return mid, true + } + + // If the start of the range is smaller than n, we can + // store this as the new best index to return. + rangeIndex = mid + + // If low and mid are already equal, then increment low + // by 1. Exit if this means that low is now greater than + // high. + if low == mid { + low = mid + 1 + if low > high { + return rangeIndex, false + } + } else { + low = mid + } + + continue + + default: + // If the height is equal to the start value of the + // current range that mid is pointing to, then the + // height is already covered. + return mid, true + } + + // Exit if we have checked all the ranges. + if low == high { + break + } + } + + return rangeIndex, false +} + +// KVStore is an interface representing a key-value store. +type KVStore interface { + // Put saves the specified key/value pair to the store. Keys that do not + // already exist are added and keys that already exist are overwritten. + Put(key, value []byte) error + + // Delete removes the specified key from the bucket. Deleting a key that + // does not exist does not return an error. + Delete(key []byte) error +} + +// Add adds a single number to the range set. It first attempts to apply the +// necessary changes to the passed KV store and then only if this succeeds, will +// the changes be applied to the in-memory structure. +func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error { + a.mu.Lock() + defer a.mu.Unlock() + + // Compute the changes that will need to be applied to both the sorted + // rangeItem array representation and the key-value store representation + // of the range index. + arrayChanges, kvStoreChanges := a.getChanges(newHeight) + + // First attempt to apply the KV store changes. Only if this succeeds + // will we apply the changes to our in-memory range index structure. + err := a.applyKVChanges(kv, kvStoreChanges) + if err != nil { + return err + } + + // Since the DB changes were successful, we can now commit the + // changes to our in-memory representation of the range set. + a.applyArrayChanges(arrayChanges) + + return nil +} + +// applyKVChanges applies the given set of kvChanges to a KV store. It is +// assumed that a transaction is being held on the kv store so that if any +// of the actions of the function fails, the changes will be reverted. +func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error { + // Exit early if there are no changes to apply. + if kv == nil || changes == nil { + return nil + } + + // Check if any range pair needs to be deleted. + if changes.deleteKVKey != nil { + del, err := a.serializeUint64(*changes.deleteKVKey) + if err != nil { + return err + } + + if err := kv.Delete(del); err != nil { + return err + } + } + + start, err := a.serializeUint64(changes.key) + if err != nil { + return err + } + + end, err := a.serializeUint64(changes.value) + if err != nil { + return err + } + + return kv.Put(start, end) +} + +// applyArrayChanges applies the given arrayChanges to the in-memory RangeIndex +// itself. This should only be done once the persisted kv store changes have +// already been applied. +func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) { + if changes == nil { + return + } + + if changes.indexToDelete != nil { + a.set = append( + a.set[:*changes.indexToDelete], + a.set[*changes.indexToDelete+1:]..., + ) + } + + if changes.newIndex != nil { + switch { + case *changes.newIndex == 0: + a.set = append([]rangeItem{{ + start: changes.start, + end: changes.end, + }}, a.set...) + + case *changes.newIndex == len(a.set): + a.set = append(a.set, rangeItem{ + start: changes.start, + end: changes.end, + }) + + default: + a.set = append( + a.set[:*changes.newIndex+1], + a.set[*changes.newIndex:]..., + ) + a.set[*changes.newIndex] = rangeItem{ + start: changes.start, + end: changes.end, + } + } + + return + } + + if changes.indexToEdit != nil { + a.set[*changes.indexToEdit] = rangeItem{ + start: changes.start, + end: changes.end, + } + } +} + +// arrayChanges encompasses the diff to apply to the sorted rangeItem array +// representation of a range index. Such a diff will either include adding a +// new range or editing an existing range. If an existing range is edited, then +// the diff might also include deleting an index (this will be the case if the +// editing of the one range results in the merge of another range). +type arrayChanges struct { + start uint64 + end uint64 + + // newIndex, if set, is the index of the in-memory range array where a + // new range, [start:end], should be added. newIndex should never be + // set at the same time as indexToEdit or indexToDelete. + newIndex *int + + // indexToDelete, if set, is the index of the sorted rangeItem array + // that should be deleted. This should be applied before reading the + // index value of indexToEdit. This should not be set at the same time + // as newIndex. + indexToDelete *int + + // indexToEdit is the index of the in-memory range array that should be + // edited. The range at this index will be changed to [start:end]. This + // should only be read after indexToDelete index has been deleted. + indexToEdit *int +} + +// kvChanges encompasses the diff to apply to a KV-store representation of a +// range index. A kv-store diff for the addition of a single number to the range +// index will include either a brand new key-value pair or the altering of the +// value of an existing key. Optionally, the diff may also include the deletion +// of an existing key. A deletion will be required if the addition of the new +// number results in the merge of two ranges. +type kvChanges struct { + key uint64 + value uint64 + + // deleteKVKey, if set, is the key of the kv store representation that + // should be deleted. + deleteKVKey *uint64 +} + +// getChanges will calculate and return the changes that need to be applied to +// both the sorted-rangeItem-array representation and the key-value store +// representation of the range index. +func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) { + // If the set is empty then a new range item is added. + if len(a.set) == 0 { + // For the array representation, a new range [n:n] is added to + // the first index of the array. + firstIndex := 0 + ac := &arrayChanges{ + newIndex: &firstIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Find the index of the lower bound range to the new number. + indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n) + + switch { + // The new number is already covered by the range index. No changes are + // required. + case alreadyCovered: + return nil, nil + + // No lower bound index exists. + case indexOfRangeBelow < 0: + // Check if the very first range can be merged into this new + // one. + if n+1 == a.set[0].start { + // If so, the two ranges can be merged and so the start + // value of the range is n and the end value is the end + // of the existing first range. + start := n + end := a.set[0].end + + // For the array representation, we can just edit the + // first entry of the array + editIndex := 0 + ac := &arrayChanges{ + indexToEdit: &editIndex, + start: start, + end: end, + } + + // For the KV store representation, we add a new kv pair + // and delete the range with the key equal to the start + // value of the range we are merging. + kvKeyToDelete := a.set[0].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &kvKeyToDelete, + } + + return ac, kvc + } + + // Otherwise, we add a new index. + + // For the array representation, a new range [n:n] is added to + // the first index of the array. + newIndex := 0 + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + + // A lower range does exist, and it can be extended to include this new + // number. + case a.set[indexOfRangeBelow].end+1 == n: + start := a.set[indexOfRangeBelow].start + end := n + indexToChange := indexOfRangeBelow + + // If there are no intervals above this one or if there are, but + // they can't be merged into this one then we just need to edit + // this interval. + if indexOfRangeBelow == len(a.set)-1 || + a.set[indexOfRangeBelow+1].start != n+1 { + + // For the array representation, we just edit the index. + ac := &arrayChanges{ + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the key-value representation, we just overwrite + // the end value at the existing start key. + kvc := &kvChanges{ + key: start, + value: end, + } + + return ac, kvc + } + + // There is a range above this one that we need to merge into + // this one. + delIndex := indexOfRangeBelow + 1 + end = a.set[delIndex].end + + // For the array representation, we delete the range above this + // one and edit this range to include the end value of the range + // above. + ac := &arrayChanges{ + indexToDelete: &delIndex, + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the kv representation, we tweak the end value of an + // existing key and delete the key of the range we are deleting. + deleteKey := a.set[delIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &deleteKey, + } + + return ac, kvc + + // A lower range does exist, but it can't be extended to include this + // new number, and so we need to add a new range after the lower bound + // range. + default: + newIndex := indexOfRangeBelow + 1 + + // If there are no ranges above this new one or if there are, + // but they can't be merged into this new one, then we can just + // add the new one as is. + if newIndex == len(a.set) || a.set[newIndex].start != n+1 { + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Else, we merge the above index. + start := n + end := a.set[newIndex].end + toEdit := newIndex + + // For the array representation, we edit the range above to + // include the new start value. + ac := &arrayChanges{ + indexToEdit: &toEdit, + start: start, + end: end, + } + + // For the kv representation, we insert the new start-end key + // value pair and delete the key using the old start value. + delKey := a.set[newIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &delKey, + } + + return ac, kvc + } +} + +func defaultSerializeUint64(i uint64) ([]byte, error) { + var b [8]byte + byteOrder.PutUint64(b[:], i) + return b[:], nil +}