diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 9d8383cb5..537c8cc73 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -48,6 +48,12 @@ var ( // tower-pubkey -> tower-id. cTowerIndexBkt = []byte("client-tower-index-bucket") + // cTowerToSessionIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerToSessionIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -196,6 +202,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cSessionBkt, cTowerBkt, cTowerIndexBkt, + cTowerToSessionIndexBkt, } for _, bucket := range buckets { @@ -260,6 +267,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { return ErrUninitializedDB } + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check if the tower index already knows of this pubkey. towerIDBytes := towerIndex.Get(towerPubKey[:]) if len(towerIDBytes) == 8 { @@ -321,6 +335,13 @@ func (c *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*Tower, error) { if err != nil { return err } + + // Create a new bucket for this tower in the + // tower-to-sessions index. + _, err = towerToSessionIndex.CreateBucket(towerIDBytes) + if err != nil { + return err + } } // Store the new or updated tower under its tower id. @@ -349,11 +370,19 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if towers == nil { return ErrUninitializedDB } + towerIndex := tx.ReadWriteBucket(cTowerIndexBkt) if towerIndex == nil { return ErrUninitializedDB } + towersToSessionsIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towersToSessionsIndex == nil { + return ErrUninitializedDB + } + // Don't return an error if the watchtower doesn't exist to act // as a NOP. pubKeyBytes := pubKey.SerializeCompressed() @@ -402,7 +431,14 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { if err := towerIndex.Delete(pubKeyBytes); err != nil { return err } - return towers.Delete(towerIDBytes) + + if err := towers.Delete(towerIDBytes); err != nil { + return err + } + + return towersToSessionsIndex.DeleteNestedBucket( + towerIDBytes, + ) } // We'll mark its sessions as inactive as long as they don't @@ -581,6 +617,13 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { return ErrUninitializedDB } + towerToSessionIndex := tx.ReadWriteBucket( + cTowerToSessionIndexBkt, + ) + if towerToSessionIndex == nil { + return ErrUninitializedDB + } + // Check that client session with this session id doesn't // already exist. existingSessionBytes := sessions.NestedReadWriteBucket( @@ -625,6 +668,19 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { } } + // Add the new entry to the towerID-to-SessionID index. + indexBkt := towerToSessionIndex.NestedReadWriteBucket( + towerID.Bytes(), + ) + if indexBkt == nil { + return ErrTowerNotFound + } + + err = indexBkt.Put(session.ID[:], []byte{1}) + if err != nil { + return err + } + // Finally, write the client session's body in the sessions // bucket. return putClientSessionBody(sessions, session) diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 0e14ea996..6ddb6c35f 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // log is a logger that is initialized with no output filters. This @@ -26,6 +27,7 @@ func DisableLog() { // using btclog. func UseLogger(logger btclog.Logger) { log = logger + migration1.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration1/client_db.go b/watchtower/wtdb/migration1/client_db.go new file mode 100644 index 000000000..d09ef6ef7 --- /dev/null +++ b/watchtower/wtdb/migration1/client_db.go @@ -0,0 +1,145 @@ +package migration1 + +import ( + "bytes" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // cTowerIDToSessionIDIndexBkt is a top-level bucket storing: + // tower-id -> session-id -> 1 + cTowerIDToSessionIDIndexBkt = []byte( + "client-tower-to-session-index-bucket", + ) + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrClientSessionNotFound signals that the requested client session + // was not found in the database. + ErrClientSessionNotFound = errors.New("client session not found") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") +) + +// MigrateTowerToSessionIndex constructs a new towerID-to-sessionID for the +// watchtower client DB. +func MigrateTowerToSessionIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client db to add a " + + "towerID-to-sessionID index") + + // First, we collect all the entries we want to add to the index. + entries, err := getIndexEntries(tx) + if err != nil { + return err + } + + // Then we create a new top-level bucket for the index. + indexBkt, err := tx.CreateTopLevelBucket(cTowerIDToSessionIDIndexBkt) + if err != nil { + return err + } + + // Finally, we add all the collected entries to the index. + for towerID, sessions := range entries { + // Create a sub-bucket using the tower ID. + towerBkt, err := indexBkt.CreateBucketIfNotExists( + towerID.Bytes(), + ) + if err != nil { + return err + } + + for sessionID := range sessions { + err := addIndex(towerBkt, sessionID) + if err != nil { + return err + } + } + } + + return nil +} + +// addIndex adds a new towerID-sessionID pair to the given bucket. The +// session ID is used as a key within the bucket and a value of []byte{1} is +// used for each session ID key. +func addIndex(towerBkt kvdb.RwBucket, sessionID SessionID) error { + session := towerBkt.Get(sessionID[:]) + if session != nil { + return fmt.Errorf("session %x duplicated", sessionID) + } + + return towerBkt.Put(sessionID[:], []byte{1}) +} + +// getIndexEntries collects all the towerID-sessionID entries that need to be +// added to the new index. +func getIndexEntries(tx kvdb.RwTx) (map[TowerID]map[SessionID]bool, error) { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return nil, ErrUninitializedDB + } + + index := make(map[TowerID]map[SessionID]bool) + err := sessions.ForEach(func(k, _ []byte) error { + session, err := getClientSession(sessions, k) + if err != nil { + return err + } + + if index[session.TowerID] == nil { + index[session.TowerID] = make(map[SessionID]bool) + } + + index[session.TowerID][session.ID] = true + return nil + }) + if err != nil { + return nil, err + } + + return index, nil +} + +// getClientSession fetches the session with the given ID from the db. +func getClientSession(sessions kvdb.RBucket, idBytes []byte) (*ClientSession, + error) { + + sessionBkt := sessions.NestedReadBucket(idBytes) + if sessionBkt == nil { + return nil, ErrClientSessionNotFound + } + + // Should never have a sessionBkt without also having its body. + sessionBody := sessionBkt.Get(cSessionBody) + if sessionBody == nil { + return nil, ErrCorruptClientSession + } + + var session ClientSession + copy(session.ID[:], idBytes) + + err := session.Decode(bytes.NewReader(sessionBody)) + if err != nil { + return nil, err + } + + return &session, nil +} diff --git a/watchtower/wtdb/migration1/client_db_test.go b/watchtower/wtdb/migration1/client_db_test.go new file mode 100644 index 000000000..acae177ad --- /dev/null +++ b/watchtower/wtdb/migration1/client_db_test.go @@ -0,0 +1,155 @@ +package migration1 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + s1 = &ClientSessionBody{ + TowerID: TowerID(1), + } + s2 = &ClientSessionBody{ + TowerID: TowerID(3), + } + s3 = &ClientSessionBody{ + TowerID: TowerID(6), + } + + // pre is the expected data in the DB before the migration. + pre = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s3), + }, + sessionIDString("3"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("4"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("5"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s2), + }, + } + + // preFailNoSessionBody should fail the migration due to there being a + // session without an associated session body. + preFailNoSessionBody = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionBody): clientSessionString(s1), + }, + sessionIDString("2"): map[string]interface{}{}, + } + + // post is the expected data after migration. + post = map[string]interface{}{ + towerIDString(1): map[string]interface{}{ + sessionIDString("1"): string([]byte{1}), + sessionIDString("3"): string([]byte{1}), + sessionIDString("4"): string([]byte{1}), + }, + towerIDString(3): map[string]interface{}{ + sessionIDString("5"): string([]byte{1}), + }, + towerIDString(6): map[string]interface{}{ + sessionIDString("2"): string([]byte{1}), + }, + } +) + +// TestMigrateTowerToSessionIndex tests that the TestMigrateTowerToSessionIndex +// function correctly adds a new towerID-to-sessionID index to the tower client +// db. +func TestMigrateTowerToSessionIndex(t *testing.T) { + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + post map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + post: post, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailNoSessionBody, + post: nil, + }, + { + name: "no sessions", + shouldFail: false, + pre: nil, + post: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Before the migration we have a sessions bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, cSessionBkt, test.pre, + ) + } + + // After the migration, we should have an untouched + // sessions bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + if err := migtest.VerifyDB( + tx, cSessionBkt, test.pre, + ); err != nil { + return err + } + + // If we expect our migration to fail, we don't + // expect an index bucket. + if test.shouldFail { + return nil + } + + return migtest.VerifyDB( + tx, cTowerIDToSessionIDIndexBkt, + test.post, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateTowerToSessionIndex, + test.shouldFail, + ) + }) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return string(sessID[:]) +} + +func clientSessionString(s *ClientSessionBody) string { + var b bytes.Buffer + err := s.Encode(&b) + if err != nil { + panic(err) + } + + return b.String() +} + +func towerIDString(id int) string { + towerID := TowerID(id) + return string(towerID.Bytes()) +} diff --git a/watchtower/wtdb/migration1/codec.go b/watchtower/wtdb/migration1/codec.go new file mode 100644 index 000000000..8c5a2299c --- /dev/null +++ b/watchtower/wtdb/migration1/codec.go @@ -0,0 +1,241 @@ +package migration1 + +import ( + "encoding/binary" + "io" + + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/watchtower/blob" + "github.com/lightningnetwork/lnd/watchtower/wtpolicy" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// UnknownElementType is an alias for channeldb.UnknownElementType. +type UnknownElementType = channeldb.UnknownElementType + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// TowerID is a unique 64-bit identifier allocated to each unique watchtower. +// This allows the client to conserve on-disk space by not needing to always +// reference towers by their pubkey. +type TowerID uint64 + +// Bytes encodes a TowerID into an 8-byte slice in big-endian byte order. +func (id TowerID) Bytes() []byte { + var buf [8]byte + binary.BigEndian.PutUint64(buf[:], uint64(id)) + return buf[:] +} + +// ClientSession encapsulates a SessionInfo returned from a successful +// session negotiation, and also records the tower and ephemeral secret used for +// communicating with the tower. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID SessionID + ClientSessionBody +} + +// CSessionStatus is a bit-field representing the possible statuses of +// ClientSessions. +type CSessionStatus uint8 + +type ClientSessionBody struct { + // SeqNum is the next unallocated sequence number that can be sent to + // the tower. + SeqNum uint16 + + // TowerLastApplied the last last-applied the tower has echoed back. + TowerLastApplied uint16 + + // TowerID is the unique, db-assigned identifier that references the + // Tower with which the session is negotiated. + TowerID TowerID + + // KeyIndex is the index of key locator used to derive the client's + // session key so that it can authenticate with the tower to update its + // session. In order to rederive the private key, the key locator should + // use the keychain.KeyFamilyTowerSession key family. + KeyIndex uint32 + + // Policy holds the negotiated session parameters. + Policy wtpolicy.Policy + + // Status indicates the current state of the ClientSession. + Status CSessionStatus + + // RewardPkScript is the pkscript that the tower's reward will be + // deposited to if a sweep transaction confirms and the sessions + // specifies a reward output. + RewardPkScript []byte +} + +// Encode writes a ClientSessionBody to the passed io.Writer. +func (s *ClientSessionBody) Encode(w io.Writer) error { + return WriteElements(w, + s.SeqNum, + s.TowerLastApplied, + uint64(s.TowerID), + s.KeyIndex, + uint8(s.Status), + s.Policy, + s.RewardPkScript, + ) +} + +// Decode reads a ClientSessionBody from the passed io.Reader. +func (s *ClientSessionBody) Decode(r io.Reader) error { + var ( + towerID uint64 + status uint8 + ) + err := ReadElements(r, + &s.SeqNum, + &s.TowerLastApplied, + &towerID, + &s.KeyIndex, + &status, + &s.Policy, + &s.RewardPkScript, + ) + if err != nil { + return err + } + + s.TowerID = TowerID(towerID) + s.Status = CSessionStatus(status) + + return nil +} + +// WriteElements serializes a variadic list of elements into the given +// io.Writer. +func WriteElements(w io.Writer, elements ...interface{}) error { + for _, element := range elements { + if err := WriteElement(w, element); err != nil { + return err + } + } + + return nil +} + +// WriteElement serializes a single element into the provided io.Writer. +func WriteElement(w io.Writer, element interface{}) error { + err := channeldb.WriteElement(w, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case SessionID: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case blob.BreachHint: + if _, err := w.Write(e[:]); err != nil { + return err + } + + case wtpolicy.Policy: + return channeldb.WriteElements(w, + uint16(e.BlobType), + e.MaxUpdates, + e.RewardBase, + e.RewardRate, + uint64(e.SweepFeeRate), + ) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "WriteElement", element, + ) + } + + return nil +} + +// ReadElements deserializes the provided io.Reader into a variadic list of +// target elements. +func ReadElements(r io.Reader, elements ...interface{}) error { + for _, element := range elements { + if err := ReadElement(r, element); err != nil { + return err + } + } + + return nil +} + +// ReadElement deserializes a single element from the provided io.Reader. +func ReadElement(r io.Reader, element interface{}) error { + err := channeldb.ReadElement(r, element) + switch { + // Known to channeldb codec. + case err == nil: + return nil + + // Fail if error is not UnknownElementType. + default: + if _, ok := err.(UnknownElementType); !ok { + return err + } + } + + // Process any wtdb-specific extensions to the codec. + switch e := element.(type) { + case *SessionID: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *blob.BreachHint: + if _, err := io.ReadFull(r, e[:]); err != nil { + return err + } + + case *wtpolicy.Policy: + var ( + blobType uint16 + sweepFeeRate uint64 + ) + err := channeldb.ReadElements(r, + &blobType, + &e.MaxUpdates, + &e.RewardBase, + &e.RewardRate, + &sweepFeeRate, + ) + if err != nil { + return err + } + + e.BlobType = blob.Type(blobType) + e.SweepFeeRate = chainfee.SatPerKWeight(sweepFeeRate) + + // Type is still unknown to wtdb extensions, fail. + default: + return channeldb.NewUnknownElementType( + "ReadElement", element, + ) + } + + return nil +} diff --git a/watchtower/wtdb/migration1/log.go b/watchtower/wtdb/migration1/log.go new file mode 100644 index 000000000..1dc105280 --- /dev/null +++ b/watchtower/wtdb/migration1/log.go @@ -0,0 +1,14 @@ +package migration1 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index 229b8a9dd..4785b0ae2 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -3,6 +3,7 @@ package wtdb import ( "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" ) // migration is a function which takes a prior outdated version of the database @@ -24,7 +25,11 @@ var towerDBVersions = []version{} // clientDBVersions stores all versions and migrations of the client database. // This list will be used when opening the database to determine if any // migrations must be applied. -var clientDBVersions = []version{} +var clientDBVersions = []version{ + { + migration: migration1.MigrateTowerToSessionIndex, + }, +} // getLatestDBVersion returns the last known database version. func getLatestDBVersion(versions []version) uint32 {