diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go
index 2079a0c6f..014e503db 100644
--- a/watchtower/wtdb/client_db.go
+++ b/watchtower/wtdb/client_db.go
@@ -23,8 +23,13 @@ var (
 	// cChanDetailsBkt is a top-level bucket storing:
 	//   channel-id => cChannelSummary -> encoded ClientChanSummary.
 	//              => cChanDBID -> db-assigned-id
+	//              => cChanSessions => db-session-id -> 1
 	cChanDetailsBkt = []byte("client-channel-detail-bucket")
 
+	// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
+	//   db-session-id -> 1
+	cChanSessions = []byte("client-channel-sessions")
+
 	// cChanDBID is a key used in the cChanDetailsBkt to store the
 	// db-assigned-id of a channel.
 	cChanDBID = []byte("client-channel-db-id")
@@ -1454,7 +1459,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
 			return ErrUninitializedDB
 		}
 
-		chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
+		chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
 		if chanDetailsBkt == nil {
 			return ErrUninitializedDB
 		}
@@ -1538,6 +1543,23 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
 			return err
 		}
 
+		dbSessionID, _, err := getDBSessionID(sessions, *id)
+		if err != nil {
+			return err
+		}
+
+		chanDetails := chanDetailsBkt.NestedReadWriteBucket(
+			committedUpdate.BackupID.ChanID[:],
+		)
+		if chanDetails == nil {
+			return ErrChannelNotRegistered
+		}
+
+		err = putChannelToSessionMapping(chanDetails, dbSessionID)
+		if err != nil {
+			return err
+		}
+
 		// Get the range index for the given session-channel pair.
 		index, err := c.getRangeIndex(tx, *id, chanID)
 		if err != nil {
@@ -1548,6 +1570,26 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
 	}, func() {})
 }
 
+// putChannelToSessionMapping adds the given session ID to a channel's
+// cChanSessions bucket.
+func putChannelToSessionMapping(chanDetails kvdb.RwBucket,
+	dbSessID uint64) error {
+
+	chanSessIDsBkt, err := chanDetails.CreateBucketIfNotExists(
+		cChanSessions,
+	)
+	if err != nil {
+		return err
+	}
+
+	b, err := writeBigSize(dbSessID)
+	if err != nil {
+		return err
+	}
+
+	return chanSessIDsBkt.Put(b, []byte{1})
+}
+
 // getClientSessionBody loads the body of a ClientSession from the sessions
 // bucket corresponding to the serialized session id. This does not deserialize
 // the CommittedUpdates, AckUpdates or the Tower associated with the session.
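Reviewer note: the new cChanSessions sub-bucket is only written by this patch (in AckUpdate above); nothing reads it yet. The sketch below shows how the index could later be consumed. It is illustrative only and not part of this diff: the helper name fetchSessionIDsForChannel is hypothetical, and readBigSize is assumed to be the existing counterpart of the writeBigSize helper used in putChannelToSessionMapping.

// fetchSessionIDsForChannel is a hypothetical helper (not part of this patch)
// showing how the cChanSessions sub-bucket could be read back to list the
// db-assigned IDs of every session that has acked updates for a channel.
func fetchSessionIDsForChannel(chanDetails kvdb.RBucket) ([]uint64, error) {
	chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions)
	if chanSessIDsBkt == nil {
		// No acked updates have been recorded for this channel yet.
		return nil, nil
	}

	var ids []uint64
	err := chanSessIDsBkt.ForEach(func(k, _ []byte) error {
		// Keys are BigSize encoded db-assigned session IDs, mirroring
		// the writeBigSize call in putChannelToSessionMapping.
		id, err := readBigSize(k)
		if err != nil {
			return err
		}
		ids = append(ids, id)

		return nil
	})
	if err != nil {
		return nil, err
	}

	return ids, nil
}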
diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go
index 638542abb..639030631 100644
--- a/watchtower/wtdb/log.go
+++ b/watchtower/wtdb/log.go
@@ -9,6 +9,7 @@ import (
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
+	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
 )
 
 // log is a logger that is initialized with no output filters. This
@@ -38,6 +39,7 @@ func UseLogger(logger btclog.Logger) {
 	migration4.UseLogger(logger)
 	migration5.UseLogger(logger)
 	migration6.UseLogger(logger)
+	migration7.UseLogger(logger)
 }
 
 // logClosure is used to provide a closure over expensive logging operations so
diff --git a/watchtower/wtdb/migration7/client_db.go b/watchtower/wtdb/migration7/client_db.go
new file mode 100644
index 000000000..0c3c4be40
--- /dev/null
+++ b/watchtower/wtdb/migration7/client_db.go
@@ -0,0 +1,202 @@
+package migration7
+
+import (
+	"bytes"
+	"encoding/binary"
+	"errors"
+	"fmt"
+
+	"github.com/lightningnetwork/lnd/kvdb"
+	"github.com/lightningnetwork/lnd/tlv"
+)
+
+var (
+	// cSessionBkt is a top-level bucket storing:
+	//   session-id => cSessionBody -> encoded ClientSessionBody
+	//              => cSessionDBID -> db-assigned-id
+	//              => cSessionCommits => seqnum -> encoded CommittedUpdate
+	//              => cSessionAckRangeIndex => chan-id => acked-index-range
+	cSessionBkt = []byte("client-session-bucket")
+
+	// cChanDetailsBkt is a top-level bucket storing:
+	//   channel-id => cChannelSummary -> encoded ClientChanSummary.
+	//              => cChanDBID -> db-assigned-id
+	//              => cChanSessions => db-session-id -> 1
+	cChanDetailsBkt = []byte("client-channel-detail-bucket")
+
+	// cChannelSummary is a sub-bucket of cChanDetailsBkt which stores the
+	// encoded body of ClientChanSummary.
+	cChannelSummary = []byte("client-channel-summary")
+
+	// cChanSessions is a sub-bucket of cChanDetailsBkt which stores:
+	//   db-session-id -> 1
+	cChanSessions = []byte("client-channel-sessions")
+
+	// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing:
+	//   chan-id => start -> end
+	cSessionAckRangeIndex = []byte("client-session-ack-range-index")
+
+	// cSessionDBID is a key used in the cSessionBkt to store the
+	// db-assigned-id of a session.
+	cSessionDBID = []byte("client-session-db-id")
+
+	// cChanIDIndexBkt is a top-level bucket storing:
+	//   db-assigned-id -> channel-ID
+	cChanIDIndexBkt = []byte("client-channel-id-index")
+
+	// ErrUninitializedDB signals that top-level buckets for the database
+	// have not been initialized.
+	ErrUninitializedDB = errors.New("db not initialized")
+
+	// ErrCorruptClientSession signals that the client session's on-disk
+	// structure deviates from what is expected.
+	ErrCorruptClientSession = errors.New("client session corrupted")
+
+	// byteOrder is the default endianness used when serializing integers.
+	byteOrder = binary.BigEndian
+)
+
+// MigrateChannelToSessionIndex migrates the tower client DB to add an index
+// from channel-to-session. This will make it easier in the future to check
+// which sessions have updates for which channels.
+func MigrateChannelToSessionIndex(tx kvdb.RwTx) error {
+	log.Infof("Migrating the tower client DB to build a new " +
+		"channel-to-session index")
+
+	sessionsBkt := tx.ReadBucket(cSessionBkt)
+	if sessionsBkt == nil {
+		return ErrUninitializedDB
+	}
+
+	chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt)
+	if chanDetailsBkt == nil {
+		return ErrUninitializedDB
+	}
+
+	chanIDsBkt := tx.ReadBucket(cChanIDIndexBkt)
+	if chanIDsBkt == nil {
+		return ErrUninitializedDB
+	}
+
+	// First gather all the new channel-to-session pairs that we want to
+	// add.
+	index, err := collectIndex(sessionsBkt)
+	if err != nil {
+		return err
+	}
+
+	// Then persist those pairs to the db.
+	return persistIndex(chanDetailsBkt, chanIDsBkt, index)
+}
+
+// collectIndex iterates through all the sessions and uses the keys in the
+// cSessionAckRangeIndex bucket to collect all the channels that the session
+// has updates for. The function returns a map from channel ID to the set of
+// session IDs (using the db-assigned IDs for both).
+func collectIndex(sessionsBkt kvdb.RBucket) (map[uint64]map[uint64]bool,
+	error) {
+
+	index := make(map[uint64]map[uint64]bool)
+	err := sessionsBkt.ForEach(func(sessID, _ []byte) error {
+		sessionBkt := sessionsBkt.NestedReadBucket(sessID)
+		if sessionBkt == nil {
+			return ErrCorruptClientSession
+		}
+
+		ackedRanges := sessionBkt.NestedReadBucket(
+			cSessionAckRangeIndex,
+		)
+		if ackedRanges == nil {
+			return nil
+		}
+
+		sessDBIDBytes := sessionBkt.Get(cSessionDBID)
+		if sessDBIDBytes == nil {
+			return ErrCorruptClientSession
+		}
+
+		sessDBID, err := readUint64(sessDBIDBytes)
+		if err != nil {
+			return err
+		}
+
+		return ackedRanges.ForEach(func(dbChanIDBytes, _ []byte) error {
+			dbChanID, err := readUint64(dbChanIDBytes)
+			if err != nil {
+				return err
+			}
+
+			if _, ok := index[dbChanID]; !ok {
+				index[dbChanID] = make(map[uint64]bool)
+			}
+
+			index[dbChanID][sessDBID] = true
+
+			return nil
+		})
+	})
+	if err != nil {
+		return nil, err
+	}
+
+	return index, nil
+}
+
+// persistIndex adds the channel-to-session mappings to each channel's details
+// bucket.
+func persistIndex(chanDetailsBkt kvdb.RwBucket, chanIDsBkt kvdb.RBucket,
+	index map[uint64]map[uint64]bool) error {
+
+	for dbChanID, sessIDs := range index {
+		dbChanIDBytes, err := writeUint64(dbChanID)
+		if err != nil {
+			return err
+		}
+
+		realChanID := chanIDsBkt.Get(dbChanIDBytes)
+
+		chanBkt := chanDetailsBkt.NestedReadWriteBucket(realChanID)
+		if chanBkt == nil {
+			return fmt.Errorf("channel not found")
+		}
+
+		sessIDsBkt, err := chanBkt.CreateBucket(cChanSessions)
+		if err != nil {
+			return err
+		}
+
+		for id := range sessIDs {
+			sessID, err := writeUint64(id)
+			if err != nil {
+				return err
+			}
+
+			err = sessIDsBkt.Put(sessID, []byte{1})
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+func writeUint64(i uint64) ([]byte, error) {
+	var b bytes.Buffer
+	err := tlv.WriteVarInt(&b, i, &[8]byte{})
+	if err != nil {
+		return nil, err
+	}
+
+	return b.Bytes(), nil
+}
+
+func readUint64(b []byte) (uint64, error) {
+	r := bytes.NewReader(b)
+	i, err := tlv.ReadVarInt(r, &[8]byte{})
+	if err != nil {
+		return 0, err
+	}
+
+	return i, nil
+}
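Note: the writeUint64/readUint64 helpers above serialize the db-assigned IDs using the BigSize varint encoding from the tlv package, so the keys this migration reads and writes are variable-length varints rather than fixed-width integers. A small, self-contained sketch of that round trip (illustrative only, not part of the patch):

package main

import (
	"bytes"
	"fmt"

	"github.com/lightningnetwork/lnd/tlv"
)

func main() {
	// Encode a db-assigned ID the same way writeUint64 does: as a BigSize
	// varint. Values below 0xfd encode to a single byte.
	var buf bytes.Buffer
	if err := tlv.WriteVarInt(&buf, 66, &[8]byte{}); err != nil {
		panic(err)
	}
	fmt.Printf("encoded: %x\n", buf.Bytes()) // encoded: 42

	// Decode it again, mirroring readUint64.
	id, err := tlv.ReadVarInt(bytes.NewReader(buf.Bytes()), &[8]byte{})
	if err != nil {
		panic(err)
	}
	fmt.Println("decoded:", id) // decoded: 66
}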
diff --git a/watchtower/wtdb/migration7/client_db_test.go b/watchtower/wtdb/migration7/client_db_test.go
new file mode 100644
index 000000000..4f90edc47
--- /dev/null
+++ b/watchtower/wtdb/migration7/client_db_test.go
@@ -0,0 +1,191 @@
+package migration7
+
+import (
+	"testing"
+
+	"github.com/lightningnetwork/lnd/channeldb/migtest"
+	"github.com/lightningnetwork/lnd/kvdb"
+)
+
+var (
+	// preDetails is the expected data of the channel details bucket before
+	// the migration.
+	preDetails = map[string]interface{}{
+		channelIDString(100): map[string]interface{}{
+			string(cChannelSummary): string([]byte{1, 2, 3}),
+		},
+		channelIDString(222): map[string]interface{}{
+			string(cChannelSummary): string([]byte{4, 5, 6}),
+		},
+	}
+
+	// preFailCorruptDB should fail the migration since the channels that
+	// the sessions refer to cannot be found in the channel details bucket.
+	preFailCorruptDB = map[string]interface{}{
+		channelIDString(30): map[string]interface{}{},
+	}
+
+	// channelIDIndex is the data in the channelID index that is used to
+	// find the mapping between the db-assigned channel ID and the real
+	// channel ID.
+	channelIDIndex = map[string]interface{}{
+		uint64ToStr(10): channelIDString(100),
+		uint64ToStr(20): channelIDString(222),
+	}
+
+	// sessions is the expected data in the sessions bucket before and
+	// after the migration.
+	sessions = map[string]interface{}{
+		sessionIDString("1"): map[string]interface{}{
+			string(cSessionAckRangeIndex): map[string]interface{}{
+				uint64ToStr(10): map[string]interface{}{
+					uint64ToStr(30): uint64ToStr(32),
+					uint64ToStr(34): uint64ToStr(34),
+				},
+				uint64ToStr(20): map[string]interface{}{
+					uint64ToStr(30): uint64ToStr(30),
+				},
+			},
+			string(cSessionDBID): uint64ToStr(66),
+		},
+		sessionIDString("2"): map[string]interface{}{
+			string(cSessionAckRangeIndex): map[string]interface{}{
+				uint64ToStr(10): map[string]interface{}{
+					uint64ToStr(33): uint64ToStr(33),
+				},
+			},
+			string(cSessionDBID): uint64ToStr(77),
+		},
+	}
+
+	// postDetails is the expected data in the channel details bucket after
+	// the migration.
+	postDetails = map[string]interface{}{
+		channelIDString(100): map[string]interface{}{
+			string(cChannelSummary): string([]byte{1, 2, 3}),
+			string(cChanSessions): map[string]interface{}{
+				uint64ToStr(66): string([]byte{1}),
+				uint64ToStr(77): string([]byte{1}),
+			},
+		},
+		channelIDString(222): map[string]interface{}{
+			string(cChannelSummary): string([]byte{4, 5, 6}),
+			string(cChanSessions): map[string]interface{}{
+				uint64ToStr(66): string([]byte{1}),
+			},
+		},
+	}
+)
+
+// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex
+// function correctly builds the new channel-to-session index in the tower
+// client DB.
+func TestMigrateChannelToSessionIndex(t *testing.T) {
+	t.Parallel()
+
+	tests := []struct {
+		name         string
+		shouldFail   bool
+		preDetails   map[string]interface{}
+		preSessions  map[string]interface{}
+		preChanIndex map[string]interface{}
+		postDetails  map[string]interface{}
+	}{
+		{
+			name:         "migration ok",
+			shouldFail:   false,
+			preDetails:   preDetails,
+			preSessions:  sessions,
+			preChanIndex: channelIDIndex,
+			postDetails:  postDetails,
+		},
+		{
+			name:        "fail due to corrupt db",
+			shouldFail:  true,
+			preDetails:  preFailCorruptDB,
+			preSessions: sessions,
+		},
+		{
+			name:       "no sessions",
+			shouldFail: false,
+		},
+	}
+
+	for _, test := range tests {
+		test := test
+
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			// Before the migration we have a channel details
+			// bucket, a sessions bucket and a channel ID index
+			// bucket.
+			before := func(tx kvdb.RwTx) error {
+				err := migtest.RestoreDB(
+					tx, cChanDetailsBkt, test.preDetails,
+				)
+				if err != nil {
+					return err
+				}
+
+				err = migtest.RestoreDB(
+					tx, cSessionBkt, test.preSessions,
+				)
+				if err != nil {
+					return err
+				}
+
+				return migtest.RestoreDB(
+					tx, cChanIDIndexBkt, test.preChanIndex,
+				)
+			}
+
+			after := func(tx kvdb.RwTx) error {
+				// If the migration fails, the details bucket
+				// should be untouched.
+				if test.shouldFail {
+					if err := migtest.VerifyDB(
+						tx, cChanDetailsBkt,
+						test.preDetails,
+					); err != nil {
+						return err
+					}
+
+					return nil
+				}
+
+				// Else, we expect the details bucket to now
+				// include the new channel-to-session index.
+				return migtest.VerifyDB(
+					tx, cChanDetailsBkt, test.postDetails,
+				)
+			}
+
+			migtest.ApplyMigration(
+				t, before, after, MigrateChannelToSessionIndex,
+				test.shouldFail,
+			)
+		})
+	}
+}
+
+func sessionIDString(id string) string {
+	var sessID SessionID
+	copy(sessID[:], id)
+	return sessID.String()
+}
+
+func channelIDString(id uint64) string {
+	var chanID ChannelID
+	byteOrder.PutUint64(chanID[:], id)
+	return string(chanID[:])
+}
+
+func uint64ToStr(id uint64) string {
+	b, err := writeUint64(id)
+	if err != nil {
+		panic(err)
+	}
+
+	return string(b)
+}
diff --git a/watchtower/wtdb/migration7/codec.go b/watchtower/wtdb/migration7/codec.go
new file mode 100644
index 000000000..e94cfe67d
--- /dev/null
+++ b/watchtower/wtdb/migration7/codec.go
@@ -0,0 +1,29 @@
+package migration7
+
+import "encoding/hex"
+
+// SessionIDSize is 33-bytes; it is a serialized, compressed public key.
+const SessionIDSize = 33
+
+// SessionID is created from the remote public key of a client, and serves as a
+// unique identifier and authentication for sending state updates.
+type SessionID [SessionIDSize]byte
+
+// String returns a hex encoding of the session id.
+func (s SessionID) String() string {
+	return hex.EncodeToString(s[:])
+}
+
+// ChannelID is a series of 32-bytes that uniquely identifies all channels
+// within the network. The ChannelID is computed using the outpoint of the
+// funding transaction (the txid, and output index). Given a funding output the
+// ChannelID can be calculated by XOR'ing the big-endian serialization of the
+// txid and the big-endian serialization of the output index, truncated to
+// 2 bytes.
+type ChannelID [32]byte
+
+// String returns the string representation of the ChannelID. This is just the
+// hex string encoding of the ChannelID itself.
+func (c ChannelID) String() string {
+	return hex.EncodeToString(c[:])
+}
diff --git a/watchtower/wtdb/migration7/log.go b/watchtower/wtdb/migration7/log.go
new file mode 100644
index 000000000..39f28b6c0
--- /dev/null
+++ b/watchtower/wtdb/migration7/log.go
@@ -0,0 +1,14 @@
+package migration7
+
+import (
+	"github.com/btcsuite/btclog"
+)
+
+// log is a logger that is initialized as disabled. This means the package will
+// not perform any logging by default until a logger is set.
+var log = btclog.Disabled
+
+// UseLogger uses a specified Logger to output package logging info.
+func UseLogger(logger btclog.Logger) {
+	log = logger
+}
diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go
index 1a186044a..b44ed80eb 100644
--- a/watchtower/wtdb/version.go
+++ b/watchtower/wtdb/version.go
@@ -11,6 +11,7 @@ import (
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration5"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration6"
+	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration7"
 )
 
 // txMigration is a function which takes a prior outdated version of the
@@ -63,6 +64,9 @@ var clientDBVersions = []version{
 	{
 		txMigration: migration6.MigrateSessionIDIndex,
 	},
+	{
+		txMigration: migration7.MigrateChannelToSessionIndex,
+	},
 }
 
 // getLatestDBVersion returns the last known database version.
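Note on the test fixtures: session "1" is stored with db-assigned ID 66 and has acked ranges for db-assigned channel IDs 10 and 20, while session "2" (db ID 77) only has a range for 10. With the channel ID index mapping 10 to channel 100 and 20 to channel 222, collectIndex should produce the map sketched below, which persistIndex then writes under the real channel IDs, yielding the postDetails layout above. This is just a restatement of the fixture data for readability, not additional behaviour.

// Expected intermediate result of collectIndex for the fixtures above:
// db-assigned channel ID -> set of db-assigned session IDs.
var expectedIndex = map[uint64]map[uint64]bool{
	10: {66: true, 77: true}, // persisted under real channel ID 100
	20: {66: true},           // persisted under real channel ID 222
}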