diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 639030631..ed453383f 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // log is a logger that is initialized with no output filters. This @@ -40,6 +41,7 @@ func UseLogger(logger btclog.Logger) { migration5.UseLogger(logger) migration6.UseLogger(logger) migration7.UseLogger(logger) + migration8.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration8/codec.go b/watchtower/wtdb/migration8/codec.go new file mode 100644 index 000000000..9c8dca1a3 --- /dev/null +++ b/watchtower/wtdb/migration8/codec.go @@ -0,0 +1,234 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "encoding/hex" + "fmt" + "io" + + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/tlv" +) + +// BreachHintSize is the length of the identifier used to detect remote +// commitment broadcasts. +const BreachHintSize = 16 + +// BreachHint is the first 16-bytes of SHA256(txid), which is used to identify +// the breach transaction. +type BreachHint [BreachHintSize]byte + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +// readBigSize converts the given byte slice into a uint64 and assumes that the +// bytes slice is using BigSize encoding. +func readBigSize(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} + +// CommittedUpdate holds a state update sent by a client along with its +// allocated sequence number and the exact remote commitment the encrypted +// justice transaction can rectify. +type CommittedUpdate struct { + // SeqNum is the unique sequence number allocated by the session to this + // update. + SeqNum uint16 + + CommittedUpdateBody +} + +// BackupID identifies a particular revoked, remote commitment by channel id and +// commitment height. +type BackupID struct { + // ChanID is the channel id of the revoked commitment. + ChanID ChannelID + + // CommitHeight is the commitment height of the revoked commitment. + CommitHeight uint64 +} + +// Encode writes the BackupID from the passed io.Writer. +func (b *BackupID) Encode(w io.Writer) error { + return WriteElements(w, + b.ChanID, + b.CommitHeight, + ) +} + +// Decode reads a BackupID from the passed io.Reader. +func (b *BackupID) Decode(r io.Reader) error { + return ReadElements(r, + &b.ChanID, + &b.CommitHeight, + ) +} + +// String returns a human-readable encoding of a BackupID. +func (b BackupID) String() string { + return fmt.Sprintf("backup(%v, %d)", b.ChanID, b.CommitHeight) +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + switch e := element.(type) { + case ChannelID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case uint64: + if err := binary.Write(w, byteOrder, e); err != nil { + return err + } + + case BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case []byte: + if err := wire.WriteVarBytes(w, 0, e); err != nil { + return err + } + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + switch e := element.(type) { + case *ChannelID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *uint64: + if err := binary.Read(r, byteOrder, e); err != nil { + return err + } + + case *BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *[]byte: + bytes, err := wire.ReadVarBytes(r, 0, 66000, "[]byte") + if err != nil { + return err + } + + *e = bytes + + default: + return fmt.Errorf("unexpected type") + } + + return nil +} + +// CommittedUpdateBody represents the primary components of a CommittedUpdate. +// On disk, this is stored under the sequence number, which acts as its key. +type CommittedUpdateBody struct { + // BackupID identifies the breached commitment that the encrypted blob + // can spend from. + BackupID BackupID + + // Hint is the 16-byte prefix of the revoked commitment transaction ID. + Hint BreachHint + + // EncryptedBlob is a ciphertext containing the sweep information for + // exacting justice if the commitment transaction matching the breach + // hint is broadcast. + EncryptedBlob []byte +} + +// Encode writes the CommittedUpdateBody to the passed io.Writer. +func (u *CommittedUpdateBody) Encode(w io.Writer) error { + err := u.BackupID.Encode(w) + if err != nil { + return err + } + + return WriteElements(w, + u.Hint, + u.EncryptedBlob, + ) +} + +// Decode reads a CommittedUpdateBody from the passed io.Reader. +func (u *CommittedUpdateBody) Decode(r io.Reader) error { + err := u.BackupID.Decode(r) + if err != nil { + return err + } + + return ReadElements(r, + &u.Hint, + &u.EncryptedBlob, + ) +} + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} diff --git a/watchtower/wtdb/migration8/log.go b/watchtower/wtdb/migration8/log.go new file mode 100644 index 000000000..ab35682c5 --- /dev/null +++ b/watchtower/wtdb/migration8/log.go @@ -0,0 +1,14 @@ +package migration8 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/migration8/migration.go b/watchtower/wtdb/migration8/migration.go new file mode 100644 index 000000000..2e9d041e3 --- /dev/null +++ b/watchtower/wtdb/migration8/migration.go @@ -0,0 +1,223 @@ +package migration8 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAckRangeIndex => db-chan-id => start -> end + // => cSessionRogueUpdateCount -> count + cSessionBkt = []byte("client-session-bucket") + + // cChanIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> channel-ID + cChanIDIndexBkt = []byte("client-channel-id-index") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // cSessionBody is a sub-bucket of cSessionBkt storing: + // seqnum -> encoded CommittedUpdate. + cSessionCommits = []byte("client-session-commits") + + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height + // => cChanMaxCommitmentHeight -> commitment-height + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + cChanMaxCommitmentHeight = []byte( + "client-channel-max-commitment-height", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + byteOrder = binary.BigEndian +) + +// MigrateChannelMaxHeights migrates the tower client db by collecting all the +// max commitment heights that have been backed up for each channel and then +// storing those heights alongside the channel info. +func MigrateChannelMaxHeights(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client DB for quick channel max " + + "commitment height lookup") + + heights, err := collectChanMaxHeights(tx) + if err != nil { + return err + } + + return writeChanMaxHeights(tx, heights) +} + +// writeChanMaxHeights iterates over the given channel ID to height map and +// writes an entry under the cChanMaxCommitmentHeight key for each channel. +func writeChanMaxHeights(tx kvdb.RwTx, heights map[ChannelID]uint64) error { + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + for chanID, maxHeight := range heights { + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + + // If the details bucket for this channel ID does not exist, + // it is probably a channel that has been closed and deleted + // already. So we can skip this height. + if chanDetails == nil { + continue + } + + b, err := writeBigSize(maxHeight) + if err != nil { + return err + } + + err = chanDetails.Put(cChanMaxCommitmentHeight, b) + if err != nil { + return err + } + } + + return nil +} + +// collectChanMaxHeights iterates over all the sessions in the DB. For each +// session, it iterates over all the Acked updates and the committed updates +// to collect the maximum commitment height for each channel. +func collectChanMaxHeights(tx kvdb.RwTx) (map[ChannelID]uint64, error) { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return nil, ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return nil, ErrUninitializedDB + } + + heights := make(map[ChannelID]uint64) + + // For each update we consider, we will only update the heights map if + // the commitment height for the channel is larger than the current + // max height stored for the channel. + cb := func(chanID ChannelID, commitHeight uint64) { + if commitHeight > heights[chanID] { + heights[chanID] = commitHeight + } + } + + err := sessionsBkt.ForEach(func(sessIDBytes, _ []byte) error { + sessBkt := sessionsBkt.NestedReadBucket(sessIDBytes) + if sessBkt == nil { + return fmt.Errorf("bucket not found for session %x", + sessIDBytes) + } + + err := forEachCommittedUpdate(sessBkt, cb) + if err != nil { + return err + } + + return forEachAckedUpdate(sessBkt, chanIDIndexBkt, cb) + }) + if err != nil { + return nil, err + } + + return heights, nil +} + +// forEachCommittedUpdate iterates over all the given session's committed +// updates and calls the call-back for each. +func forEachCommittedUpdate(sessBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionCommits := sessBkt.NestedReadBucket(cSessionCommits) + if sessionCommits == nil { + return nil + } + + return sessionCommits.ForEach(func(k, v []byte) error { + var update CommittedUpdate + err := update.Decode(bytes.NewReader(v)) + if err != nil { + return err + } + + cb(update.BackupID.ChanID, update.BackupID.CommitHeight) + + return nil + }) +} + +// forEachAckedUpdate iterates over all the given session's acked update range +// indices and calls the call-back for each. +func forEachAckedUpdate(sessBkt, chanIDIndexBkt kvdb.RBucket, + cb func(chanID ChannelID, commitHeight uint64)) error { + + sessionAcksRanges := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if sessionAcksRanges == nil { + return nil + } + + return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error { + rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID) + if rangeBkt == nil { + return nil + } + + index, err := readRangeIndex(rangeBkt) + if err != nil { + return err + } + + chanIDBytes := chanIDIndexBkt.Get(dbChanID) + var chanID ChannelID + copy(chanID[:], chanIDBytes) + + cb(chanID, index.MaxHeight()) + + return nil + }) +} + +// readRangeIndex reads a persisted RangeIndex from the passed bucket and into +// a new in-memory RangeIndex. +func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) { + ranges := make(map[uint64]uint64) + err := rangesBkt.ForEach(func(k, v []byte) error { + start, err := readBigSize(k) + if err != nil { + return err + } + + end, err := readBigSize(v) + if err != nil { + return err + } + + ranges[start] = end + + return nil + }) + if err != nil { + return nil, err + } + + return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize)) +} diff --git a/watchtower/wtdb/migration8/migration_test.go b/watchtower/wtdb/migration8/migration_test.go new file mode 100644 index 000000000..336069bfd --- /dev/null +++ b/watchtower/wtdb/migration8/migration_test.go @@ -0,0 +1,214 @@ +package migration8 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/stretchr/testify/require" +) + +const ( + chan1ID = 10 + chan2ID = 20 + chan3ID = 30 + chan4ID = 40 + + chan1DBID = 111 + chan2DBID = 222 + chan3DBID = 333 +) + +var ( + // preDetails is the expected data of the channel details bucket before + // the migration. + preDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{}, + channelIDString(chan2ID): map[string]interface{}{}, + channelIDString(chan3ID): map[string]interface{}{}, + } + + // channelIDIndex is the data in the channelID index that is used to + // find the mapping between the db-assigned channel ID and the real + // channel ID. + channelIDIndex = map[string]interface{}{ + uint64ToStr(chan1DBID): channelIDString(chan1ID), + uint64ToStr(chan2DBID): channelIDString(chan2ID), + uint64ToStr(chan3DBID): channelIDString(chan3ID), + } + + // postDetails is the expected data in the channel details bucket after + // the migration. + postDetails = map[string]interface{}{ + channelIDString(chan1ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(105), + }, + channelIDString(chan2ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(205), + }, + channelIDString(chan3ID): map[string]interface{}{ + string(cChanMaxCommitmentHeight): uint64ToStr(304), + }, + } +) + +// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex +// function correctly builds the new channel-to-sessionID index to the tower +// client DB. +func TestMigrateChannelToSessionIndex(t *testing.T) { + t.Parallel() + + update1 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan1ID), + CommitHeight: 105, + }, + }, + } + var update1B bytes.Buffer + require.NoError(t, update1.Encode(&update1B)) + + update3 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan3ID), + CommitHeight: 304, + }, + }, + } + var update3B bytes.Buffer + require.NoError(t, update3.Encode(&update3B)) + + update4 := &CommittedUpdate{ + SeqNum: 1, + CommittedUpdateBody: CommittedUpdateBody{ + BackupID: BackupID{ + ChanID: intToChannelID(chan4ID), + CommitHeight: 400, + }, + }, + } + var update4B bytes.Buffer + require.NoError(t, update4.Encode(&update4B)) + + // sessions is the expected data in the sessions bucket before and + // after the migration. + sessions := map[string]interface{}{ + // A session with both acked and committed updates. + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 1 a max height + // of 104. + uint64ToStr(chan1DBID): map[string]interface{}{ + uint64ToStr(100): uint64ToStr(101), + uint64ToStr(104): uint64ToStr(104), + }, + // This range index gives channel 2 a max height + // of 200. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(200): uint64ToStr(200), + }, + }, + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 1 a max + // height of 105 and so it overrides the heights + // from the range index. + uint64ToStr(1): update1B.String(), + }, + }, + // A session with only acked updates. + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + // This range index gives channel 2 a max height + // of 205. + uint64ToStr(chan2DBID): map[string]interface{}{ + uint64ToStr(201): uint64ToStr(205), + }, + }, + }, + // A session with only committed updates. + sessionIDString("3"): map[string]interface{}{ + string(cSessionCommits): map[string]interface{}{ + // This committed update gives channel 3 a max + // height of 304. + uint64ToStr(1): update3B.String(), + }, + }, + // This session only contains heights for channel 4 which has + // been closed and so this should have no effect. + sessionIDString("4"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(444): map[string]interface{}{ + uint64ToStr(400): uint64ToStr(402), + uint64ToStr(403): uint64ToStr(405), + }, + }, + string(cSessionCommits): map[string]interface{}{ + uint64ToStr(1): update4B.String(), + }, + }, + // A session with no updates. + sessionIDString("5"): map[string]interface{}{}, + } + + // Before the migration we have a channel details + // bucket, a sessions bucket, a session ID index bucket + // and a channel ID index bucket. + before := func(tx kvdb.RwTx) error { + err := migtest.RestoreDB(tx, cChanDetailsBkt, preDetails) + if err != nil { + return err + } + + err = migtest.RestoreDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.RestoreDB(tx, cChanIDIndexBkt, channelIDIndex) + } + + after := func(tx kvdb.RwTx) error { + err := migtest.VerifyDB(tx, cSessionBkt, sessions) + if err != nil { + return err + } + + return migtest.VerifyDB(tx, cChanDetailsBkt, postDetails) + } + + migtest.ApplyMigration( + t, before, after, MigrateChannelMaxHeights, false, + ) +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeBigSize(id) + if err != nil { + panic(err) + } + + return string(b) +} + +func intToChannelID(id uint64) ChannelID { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return chanID +} diff --git a/watchtower/wtdb/migration8/range_index.go b/watchtower/wtdb/migration8/range_index.go new file mode 100644 index 000000000..94f0e2030 --- /dev/null +++ b/watchtower/wtdb/migration8/range_index.go @@ -0,0 +1,619 @@ +package migration8 + +import ( + "fmt" + "sync" +) + +// rangeItem represents the start and end values of a range. +type rangeItem struct { + start uint64 + end uint64 +} + +// RangeIndexOption describes the signature of a functional option that can be +// used to modify the behaviour of a RangeIndex. +type RangeIndexOption func(*RangeIndex) + +// WithSerializeUint64Fn is a functional option that can be used to set the +// function to be used to do the serialization of a uint64 into a byte slice. +func WithSerializeUint64Fn(fn func(uint64) ([]byte, error)) RangeIndexOption { + return func(index *RangeIndex) { + index.serializeUint64 = fn + } +} + +// RangeIndex can be used to keep track of which numbers have been added to a +// set. It does so by keeping track of a sorted list of rangeItems. Each +// rangeItem has a start and end value of a range where all values in-between +// have been added to the set. It works well in situations where it is expected +// numbers in the set are not sparse. +type RangeIndex struct { + // set is a sorted list of rangeItem. + set []rangeItem + + // mu is used to ensure safe access to set. + mu sync.Mutex + + // serializeUint64 is the function that can be used to convert a uint64 + // to a byte slice. + serializeUint64 func(uint64) ([]byte, error) +} + +// NewRangeIndex constructs a new RangeIndex. An initial set of ranges may be +// passed to the function in the form of a map. +func NewRangeIndex(ranges map[uint64]uint64, + opts ...RangeIndexOption) (*RangeIndex, error) { + + index := &RangeIndex{ + serializeUint64: defaultSerializeUint64, + set: make([]rangeItem, 0), + } + + // Apply any functional options. + for _, o := range opts { + o(index) + } + + for s, e := range ranges { + if err := index.addRange(s, e); err != nil { + return nil, err + } + } + + return index, nil +} + +// addRange can be used to add an entire new range to the set. This method +// should only ever be called by NewRangeIndex to initialise the in-memory +// structure and so the RangeIndex mutex is not held during this method. +func (a *RangeIndex) addRange(start, end uint64) error { + // Check that the given range is valid. + if start > end { + return fmt.Errorf("invalid range. Start height %d is larger "+ + "than end height %d", start, end) + } + + // min is a helper closure that will return the minimum of two uint64s. + min := func(a, b uint64) uint64 { + if a < b { + return a + } + + return b + } + + // max is a helper closure that will return the maximum of two uint64s. + max := func(a, b uint64) uint64 { + if a > b { + return a + } + + return b + } + + // Collect the ranges that fall before and after the new range along + // with the start and end values of the new range. + var before, after []rangeItem + for _, x := range a.set { + // If the new start value can't extend the current ranges end + // value, then the two cannot be merged. The range is added to + // the group of ranges that fall before the new range. + if x.end+1 < start { + before = append(before, x) + continue + } + + // If the current ranges start value does not follow on directly + // from the new end value, then the two cannot be merged. The + // range is added to the group of ranges that fall after the new + // range. + if end+1 < x.start { + after = append(after, x) + continue + } + + // Otherwise, there is an overlap and so the two can be merged. + start = min(start, x.start) + end = max(end, x.end) + } + + // Re-construct the range index set. + a.set = append(append(before, rangeItem{ + start: start, + end: end, + }), after...) + + return nil +} + +// IsInIndex returns true if the given number is in the range set. +func (a *RangeIndex) IsInIndex(n uint64) bool { + a.mu.Lock() + defer a.mu.Unlock() + + _, isCovered := a.lowerBoundIndex(n) + + return isCovered +} + +// NumInSet returns the number of items covered by the range set. +func (a *RangeIndex) NumInSet() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + var numItems uint64 + for _, r := range a.set { + numItems += r.end - r.start + 1 + } + + return numItems +} + +// MaxHeight returns the highest number covered in the range. +func (a *RangeIndex) MaxHeight() uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + if len(a.set) == 0 { + return 0 + } + + return a.set[len(a.set)-1].end +} + +// GetAllRanges returns a copy of the range set in the form of a map. +func (a *RangeIndex) GetAllRanges() map[uint64]uint64 { + a.mu.Lock() + defer a.mu.Unlock() + + cp := make(map[uint64]uint64, len(a.set)) + for _, item := range a.set { + cp[item.start] = item.end + } + + return cp +} + +// lowerBoundIndex returns the index of the RangeIndex that is most appropriate +// for the new value, n. In other words, it returns the index of the rangeItem +// set of the range where the start value is the highest start value in the set +// that is still lower than or equal to the given number, n. The returned +// boolean is true if the given number is already covered in the RangeIndex. +// A returned index of -1 indicates that no lower bound range exists in the set. +// Since the most likely case is that the new number will just extend the +// highest range, a check is first done to see if this is the case which will +// make the methods' computational complexity O(1). Otherwise, a binary search +// is done which brings the computational complexity to O(log N). +func (a *RangeIndex) lowerBoundIndex(n uint64) (int, bool) { + // If the set is empty, then there is no such index and the value + // definitely is not in the set. + if len(a.set) == 0 { + return -1, false + } + + // In most cases, the last index item will be the one we want. So just + // do a quick check on that index first to avoid doing the binary + // search. + lastIndex := len(a.set) - 1 + lastRange := a.set[lastIndex] + if lastRange.start <= n { + return lastIndex, lastRange.end >= n + } + + // Otherwise, do a binary search to find the index of interest. + var ( + low = 0 + high = len(a.set) - 1 + rangeIndex = -1 + ) + for { + mid := (low + high) / 2 + currentRange := a.set[mid] + + switch { + case currentRange.start > n: + // If the start of the range is greater than n, we can + // completely cut out that entire part of the array. + high = mid + + case currentRange.start < n: + // If the range already includes the given height, we + // can stop searching now. + if currentRange.end >= n { + return mid, true + } + + // If the start of the range is smaller than n, we can + // store this as the new best index to return. + rangeIndex = mid + + // If low and mid are already equal, then increment low + // by 1. Exit if this means that low is now greater than + // high. + if low == mid { + low = mid + 1 + if low > high { + return rangeIndex, false + } + } else { + low = mid + } + + continue + + default: + // If the height is equal to the start value of the + // current range that mid is pointing to, then the + // height is already covered. + return mid, true + } + + // Exit if we have checked all the ranges. + if low == high { + break + } + } + + return rangeIndex, false +} + +// KVStore is an interface representing a key-value store. +type KVStore interface { + // Put saves the specified key/value pair to the store. Keys that do not + // already exist are added and keys that already exist are overwritten. + Put(key, value []byte) error + + // Delete removes the specified key from the bucket. Deleting a key that + // does not exist does not return an error. + Delete(key []byte) error +} + +// Add adds a single number to the range set. It first attempts to apply the +// necessary changes to the passed KV store and then only if this succeeds, will +// the changes be applied to the in-memory structure. +func (a *RangeIndex) Add(newHeight uint64, kv KVStore) error { + a.mu.Lock() + defer a.mu.Unlock() + + // Compute the changes that will need to be applied to both the sorted + // rangeItem array representation and the key-value store representation + // of the range index. + arrayChanges, kvStoreChanges := a.getChanges(newHeight) + + // First attempt to apply the KV store changes. Only if this succeeds + // will we apply the changes to our in-memory range index structure. + err := a.applyKVChanges(kv, kvStoreChanges) + if err != nil { + return err + } + + // Since the DB changes were successful, we can now commit the + // changes to our in-memory representation of the range set. + a.applyArrayChanges(arrayChanges) + + return nil +} + +// applyKVChanges applies the given set of kvChanges to a KV store. It is +// assumed that a transaction is being held on the kv store so that if any +// of the actions of the function fails, the changes will be reverted. +func (a *RangeIndex) applyKVChanges(kv KVStore, changes *kvChanges) error { + // Exit early if there are no changes to apply. + if kv == nil || changes == nil { + return nil + } + + // Check if any range pair needs to be deleted. + if changes.deleteKVKey != nil { + del, err := a.serializeUint64(*changes.deleteKVKey) + if err != nil { + return err + } + + if err := kv.Delete(del); err != nil { + return err + } + } + + start, err := a.serializeUint64(changes.key) + if err != nil { + return err + } + + end, err := a.serializeUint64(changes.value) + if err != nil { + return err + } + + return kv.Put(start, end) +} + +// applyArrayChanges applies the given arrayChanges to the in-memory RangeIndex +// itself. This should only be done once the persisted kv store changes have +// already been applied. +func (a *RangeIndex) applyArrayChanges(changes *arrayChanges) { + if changes == nil { + return + } + + if changes.indexToDelete != nil { + a.set = append( + a.set[:*changes.indexToDelete], + a.set[*changes.indexToDelete+1:]..., + ) + } + + if changes.newIndex != nil { + switch { + case *changes.newIndex == 0: + a.set = append([]rangeItem{{ + start: changes.start, + end: changes.end, + }}, a.set...) + + case *changes.newIndex == len(a.set): + a.set = append(a.set, rangeItem{ + start: changes.start, + end: changes.end, + }) + + default: + a.set = append( + a.set[:*changes.newIndex+1], + a.set[*changes.newIndex:]..., + ) + a.set[*changes.newIndex] = rangeItem{ + start: changes.start, + end: changes.end, + } + } + + return + } + + if changes.indexToEdit != nil { + a.set[*changes.indexToEdit] = rangeItem{ + start: changes.start, + end: changes.end, + } + } +} + +// arrayChanges encompasses the diff to apply to the sorted rangeItem array +// representation of a range index. Such a diff will either include adding a +// new range or editing an existing range. If an existing range is edited, then +// the diff might also include deleting an index (this will be the case if the +// editing of the one range results in the merge of another range). +type arrayChanges struct { + start uint64 + end uint64 + + // newIndex, if set, is the index of the in-memory range array where a + // new range, [start:end], should be added. newIndex should never be + // set at the same time as indexToEdit or indexToDelete. + newIndex *int + + // indexToDelete, if set, is the index of the sorted rangeItem array + // that should be deleted. This should be applied before reading the + // index value of indexToEdit. This should not be set at the same time + // as newIndex. + indexToDelete *int + + // indexToEdit is the index of the in-memory range array that should be + // edited. The range at this index will be changed to [start:end]. This + // should only be read after indexToDelete index has been deleted. + indexToEdit *int +} + +// kvChanges encompasses the diff to apply to a KV-store representation of a +// range index. A kv-store diff for the addition of a single number to the range +// index will include either a brand new key-value pair or the altering of the +// value of an existing key. Optionally, the diff may also include the deletion +// of an existing key. A deletion will be required if the addition of the new +// number results in the merge of two ranges. +type kvChanges struct { + key uint64 + value uint64 + + // deleteKVKey, if set, is the key of the kv store representation that + // should be deleted. + deleteKVKey *uint64 +} + +// getChanges will calculate and return the changes that need to be applied to +// both the sorted-rangeItem-array representation and the key-value store +// representation of the range index. +func (a *RangeIndex) getChanges(n uint64) (*arrayChanges, *kvChanges) { + // If the set is empty then a new range item is added. + if len(a.set) == 0 { + // For the array representation, a new range [n:n] is added to + // the first index of the array. + firstIndex := 0 + ac := &arrayChanges{ + newIndex: &firstIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Find the index of the lower bound range to the new number. + indexOfRangeBelow, alreadyCovered := a.lowerBoundIndex(n) + + switch { + // The new number is already covered by the range index. No changes are + // required. + case alreadyCovered: + return nil, nil + + // No lower bound index exists. + case indexOfRangeBelow < 0: + // Check if the very first range can be merged into this new + // one. + if n+1 == a.set[0].start { + // If so, the two ranges can be merged and so the start + // value of the range is n and the end value is the end + // of the existing first range. + start := n + end := a.set[0].end + + // For the array representation, we can just edit the + // first entry of the array + editIndex := 0 + ac := &arrayChanges{ + indexToEdit: &editIndex, + start: start, + end: end, + } + + // For the KV store representation, we add a new kv pair + // and delete the range with the key equal to the start + // value of the range we are merging. + kvKeyToDelete := a.set[0].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &kvKeyToDelete, + } + + return ac, kvc + } + + // Otherwise, we add a new index. + + // For the array representation, a new range [n:n] is added to + // the first index of the array. + newIndex := 0 + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + // For the KV representation, a new [n:n] pair is added. + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + + // A lower range does exist, and it can be extended to include this new + // number. + case a.set[indexOfRangeBelow].end+1 == n: + start := a.set[indexOfRangeBelow].start + end := n + indexToChange := indexOfRangeBelow + + // If there are no intervals above this one or if there are, but + // they can't be merged into this one then we just need to edit + // this interval. + if indexOfRangeBelow == len(a.set)-1 || + a.set[indexOfRangeBelow+1].start != n+1 { + + // For the array representation, we just edit the index. + ac := &arrayChanges{ + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the key-value representation, we just overwrite + // the end value at the existing start key. + kvc := &kvChanges{ + key: start, + value: end, + } + + return ac, kvc + } + + // There is a range above this one that we need to merge into + // this one. + delIndex := indexOfRangeBelow + 1 + end = a.set[delIndex].end + + // For the array representation, we delete the range above this + // one and edit this range to include the end value of the range + // above. + ac := &arrayChanges{ + indexToDelete: &delIndex, + indexToEdit: &indexToChange, + start: start, + end: end, + } + + // For the kv representation, we tweak the end value of an + // existing key and delete the key of the range we are deleting. + deleteKey := a.set[delIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &deleteKey, + } + + return ac, kvc + + // A lower range does exist, but it can't be extended to include this + // new number, and so we need to add a new range after the lower bound + // range. + default: + newIndex := indexOfRangeBelow + 1 + + // If there are no ranges above this new one or if there are, + // but they can't be merged into this new one, then we can just + // add the new one as is. + if newIndex == len(a.set) || a.set[newIndex].start != n+1 { + ac := &arrayChanges{ + newIndex: &newIndex, + start: n, + end: n, + } + + kvc := &kvChanges{ + key: n, + value: n, + } + + return ac, kvc + } + + // Else, we merge the above index. + start := n + end := a.set[newIndex].end + toEdit := newIndex + + // For the array representation, we edit the range above to + // include the new start value. + ac := &arrayChanges{ + indexToEdit: &toEdit, + start: start, + end: end, + } + + // For the kv representation, we insert the new start-end key + // value pair and delete the key using the old start value. + delKey := a.set[newIndex].start + kvc := &kvChanges{ + key: start, + value: end, + deleteKVKey: &delKey, + } + + return ac, kvc + } +} + +func defaultSerializeUint64(i uint64) ([]byte, error) { + var b [8]byte + byteOrder.PutUint64(b[:], i) + return b[:], nil +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index b44ed80eb..dd9c55472 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration8" ) // txMigration is a function which takes a prior outdated version of the @@ -67,6 +68,9 @@ var clientDBVersions = []version{ { txMigration: migration7.MigrateChannelToSessionIndex, }, + { + txMigration: migration8.MigrateChannelMaxHeights, + }, } // getLatestDBVersion returns the last known database version.