From b16df45076c7fcf2657137b018c8b8f2af280bbc Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 9 Mar 2023 10:26:06 +0200 Subject: [PATCH 01/19] watchtower: add sessionID index In this commit, a new session-ID index is added to the tower client db with the help of a migration. This index holds a mapping from a db-assigned-ID (a uint64 encoded using BigSize encoding) to real session ID (33 bytes). This mapping will help us save space in future when persisting references to sessions. --- watchtower/wtdb/client_db.go | 82 ++++++++++- watchtower/wtdb/log.go | 2 + watchtower/wtdb/migration6/client_db.go | 114 ++++++++++++++ watchtower/wtdb/migration6/client_db_test.go | 147 +++++++++++++++++++ watchtower/wtdb/migration6/codec.go | 17 +++ watchtower/wtdb/migration6/log.go | 14 ++ watchtower/wtdb/version.go | 4 + 7 files changed, 373 insertions(+), 7 deletions(-) create mode 100644 watchtower/wtdb/migration6/client_db.go create mode 100644 watchtower/wtdb/migration6/client_db_test.go create mode 100644 watchtower/wtdb/migration6/codec.go create mode 100644 watchtower/wtdb/migration6/log.go diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 3cbe0c6f1..2079a0c6f 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -35,10 +35,15 @@ var ( // cSessionBkt is a top-level bucket storing: // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id // => cSessionCommits => seqnum -> encoded CommittedUpdate // => cSessionAckRangeIndex => db-chan-id => start -> end cSessionBkt = []byte("client-session-bucket") + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-id of a session. + cSessionDBID = []byte("client-session-db-id") + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of // the ClientSession. cSessionBody = []byte("client-session-body") @@ -55,6 +60,10 @@ var ( // db-assigned-id -> channel-ID cChanIDIndexBkt = []byte("client-channel-id-index") + // cSessionIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> session-id + cSessionIDIndexBkt = []byte("client-session-id-index") + // cTowerBkt is a top-level bucket storing: // tower-id -> encoded Tower. cTowerBkt = []byte("client-tower-bucket") @@ -241,6 +250,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cTowerIndexBkt, cTowerToSessionIndexBkt, cChanIDIndexBkt, + cSessionIDIndexBkt, } for _, bucket := range buckets { @@ -723,24 +733,58 @@ func (c *ClientDB) CreateClientSession(session *ClientSession) error { } } - // Add the new entry to the towerID-to-SessionID index. - indexBkt := towerToSessionIndex.NestedReadWriteBucket( - towerID.Bytes(), - ) - if indexBkt == nil { - return ErrTowerNotFound + // Get the session-ID index bucket. + dbIDIndex := tx.ReadWriteBucket(cSessionIDIndexBkt) + if dbIDIndex == nil { + return ErrUninitializedDB } - err = indexBkt.Put(session.ID[:], []byte{1}) + // Get a new, unique, ID for this session from the session-ID + // index bucket. + nextSeq, err := dbIDIndex.NextSequence() if err != nil { return err } + // Add the new entry to the dbID-to-SessionID index. + newIndex, err := writeBigSize(nextSeq) + if err != nil { + return err + } + + err = dbIDIndex.Put(newIndex, session.ID[:]) + if err != nil { + return err + } + + // Also add the db-assigned-id to the session bucket under the + // cSessionDBID key. sessionBkt, err := sessions.CreateBucket(session.ID[:]) if err != nil { return err } + err = sessionBkt.Put(cSessionDBID, newIndex) + if err != nil { + return err + } + + // TODO(elle): migrate the towerID-to-SessionID to use the + // new db-assigned sessionID's rather. + + // Add the new entry to the towerID-to-SessionID index. + towerSessions := towerToSessionIndex.NestedReadWriteBucket( + towerID.Bytes(), + ) + if towerSessions == nil { + return ErrTowerNotFound + } + + err = towerSessions.Put(session.ID[:], []byte{1}) + if err != nil { + return err + } + // Finally, write the client session's body in the sessions // bucket. return putClientSessionBody(sessionBkt, session) @@ -1882,6 +1926,30 @@ func getDBChanID(chanDetailsBkt kvdb.RBucket, chanID lnwire.ChannelID) (uint64, return id, idBytes, nil } +// getDBSessionID returns the db-assigned session ID for the given real session +// ID. It returns both the uint64 and byte representation. +func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, + []byte, error) { + + sessionBkt := sessionsBkt.NestedReadBucket(sessionID[:]) + if sessionBkt == nil { + return 0, nil, ErrClientSessionNotFound + } + + idBytes := sessionBkt.Get(cSessionDBID) + if len(idBytes) == 0 { + return 0, nil, fmt.Errorf("no db-assigned ID found for "+ + "session ID %s", sessionID) + } + + id, err := readBigSize(idBytes) + if err != nil { + return 0, nil, err + } + + return id, idBytes, nil +} + // writeBigSize will encode the given uint64 as a BigSize byte slice. func writeBigSize(i uint64) ([]byte, error) { var b bytes.Buffer diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index f7952ba4e..638542abb 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -8,6 +8,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" ) // log is a logger that is initialized with no output filters. This @@ -36,6 +37,7 @@ func UseLogger(logger btclog.Logger) { migration3.UseLogger(logger) migration4.UseLogger(logger) migration5.UseLogger(logger) + migration6.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration6/client_db.go b/watchtower/wtdb/migration6/client_db.go new file mode 100644 index 000000000..8d5ffbc29 --- /dev/null +++ b/watchtower/wtdb/migration6/client_db.go @@ -0,0 +1,114 @@ +package migration6 + +import ( + "bytes" + "encoding/binary" + "errors" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAcks => seqnum -> encoded BackupID + cSessionBkt = []byte("client-session-bucket") + + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-id of a session. + cSessionDBID = []byte("client-session-db-id") + + // cSessionIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> session-id + cSessionIDIndexBkt = []byte("client-session-id-index") + + // cSessionBody is a sub-bucket of cSessionBkt storing only the body of + // the ClientSession. + cSessionBody = []byte("client-session-body") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") + + byteOrder = binary.BigEndian +) + +// MigrateSessionIDIndex adds a new session ID index to the tower client db. +// This index is a mapping from db-assigned ID (a uint64 encoded using BigSize) +// to real session ID (33 bytes). This mapping will allow us to persist session +// pointers with fewer bytes in the future. +func MigrateSessionIDIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client db to add a new session ID " + + "index which stores a mapping from db-assigned ID to real " + + "session ID") + + // Create a new top-level bucket for the index. + indexBkt, err := tx.CreateTopLevelBucket(cSessionIDIndexBkt) + if err != nil { + return err + } + + // Get the existing top-level sessions bucket. + sessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + // Iterate over the sessions bucket where each key is a session-ID. + return sessionsBkt.ForEach(func(sessionID, _ []byte) error { + // Ask the DB for a new, unique, id for the index bucket. + nextSeq, err := indexBkt.NextSequence() + if err != nil { + return err + } + + newIndex, err := writeBigSize(nextSeq) + if err != nil { + return err + } + + // Add the new db-assigned-ID to real-session-ID pair to the + // new index bucket. + err = indexBkt.Put(newIndex, sessionID) + if err != nil { + return err + } + + // Get the sub-bucket for this specific session ID. + sessionBkt := sessionsBkt.NestedReadWriteBucket(sessionID) + if sessionBkt == nil { + return ErrCorruptClientSession + } + + // Here we ensure that the session bucket includes a session + // body. The only reason we do this is so that we can simulate + // a migration fail in a test to ensure that a migration fail + // results in an untouched db. + sessionBodyBytes := sessionBkt.Get(cSessionBody) + if sessionBodyBytes == nil { + return ErrCorruptClientSession + } + + // Add the db-assigned ID of the session to the session under + // the cSessionDBID key. + return sessionBkt.Put(cSessionDBID, newIndex) + }) +} + +// writeBigSize will encode the given uint64 as a BigSize byte slice. +func writeBigSize(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} diff --git a/watchtower/wtdb/migration6/client_db_test.go b/watchtower/wtdb/migration6/client_db_test.go new file mode 100644 index 000000000..c4928e2f9 --- /dev/null +++ b/watchtower/wtdb/migration6/client_db_test.go @@ -0,0 +1,147 @@ +package migration6 + +import ( + "bytes" + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // pre is the expected data in the sessions bucket before the migration. + pre = map[string]interface{}{ + sessionIDToString(100): map[string]interface{}{ + string(cSessionBody): string([]byte{1, 2, 3}), + }, + sessionIDToString(222): map[string]interface{}{ + string(cSessionBody): string([]byte{4, 5, 6}), + }, + } + + // preFailCorruptDB should fail the migration due to no session body + // being found for a given session ID. + preFailCorruptDB = map[string]interface{}{ + sessionIDToString(100): "", + } + + // post is the expected session index after migration. + postIndex = map[string]interface{}{ + indexToString(1): sessionIDToString(100), + indexToString(2): sessionIDToString(222), + } + + // postSessions is the expected data in the sessions bucket after the + // migration. + postSessions = map[string]interface{}{ + sessionIDToString(100): map[string]interface{}{ + string(cSessionBody): string([]byte{1, 2, 3}), + string(cSessionDBID): indexToString(1), + }, + sessionIDToString(222): map[string]interface{}{ + string(cSessionBody): string([]byte{4, 5, 6}), + string(cSessionDBID): indexToString(2), + }, + } +) + +// TestMigrateSessionIDIndex tests that the MigrateSessionIDIndex function +// correctly adds a new session-id index to the DB and also correctly updates +// the existing session bucket. +func TestMigrateSessionIDIndex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shouldFail bool + pre map[string]interface{} + postSessions map[string]interface{} + postIndex map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + pre: pre, + postSessions: postSessions, + postIndex: postIndex, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + pre: preFailCorruptDB, + }, + { + name: "no channel details", + shouldFail: false, + pre: nil, + postSessions: nil, + postIndex: nil, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Before the migration we have a details bucket. + before := func(tx kvdb.RwTx) error { + return migtest.RestoreDB( + tx, cSessionBkt, test.pre, + ) + } + + // After the migration, we should have an untouched + // summary bucket and a new index bucket. + after := func(tx kvdb.RwTx) error { + // If the migration fails, the details bucket + // should be untouched. + if test.shouldFail { + if err := migtest.VerifyDB( + tx, cSessionBkt, test.pre, + ); err != nil { + return err + } + + return nil + } + + // Else, we expect an updated summary bucket + // and a new index bucket. + err := migtest.VerifyDB( + tx, cSessionBkt, test.postSessions, + ) + if err != nil { + return err + } + + return migtest.VerifyDB( + tx, cSessionIDIndexBkt, test.postIndex, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateSessionIDIndex, + test.shouldFail, + ) + }) + } +} + +func indexToString(id uint64) string { + var newIndex bytes.Buffer + err := tlv.WriteVarInt(&newIndex, id, &[8]byte{}) + if err != nil { + panic(err) + } + + return newIndex.String() +} + +func sessionIDToString(id uint64) string { + var chanID SessionID + byteOrder.PutUint64(chanID[:], id) + return chanID.String() +} diff --git a/watchtower/wtdb/migration6/codec.go b/watchtower/wtdb/migration6/codec.go new file mode 100644 index 000000000..11edbf299 --- /dev/null +++ b/watchtower/wtdb/migration6/codec.go @@ -0,0 +1,17 @@ +package migration6 + +import ( + "encoding/hex" +) + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} diff --git a/watchtower/wtdb/migration6/log.go b/watchtower/wtdb/migration6/log.go new file mode 100644 index 000000000..e43e7d27e --- /dev/null +++ b/watchtower/wtdb/migration6/log.go @@ -0,0 +1,14 @@ +package migration6 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index dbcad3715..1a186044a 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -10,6 +10,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" ) // txMigration is a function which takes a prior outdated version of the @@ -59,6 +60,9 @@ var clientDBVersions = []version{ { txMigration: migration5.MigrateCompleteTowerToSessionIndex, }, + { + txMigration: migration6.MigrateSessionIDIndex, + }, } // getLatestDBVersion returns the last known database version. From ee0353dd24e7819d12df6423375f53272fdd9b6e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 9 Mar 2023 10:30:12 +0200 Subject: [PATCH 02/19] watchtower: build channel to sessionIDs index In this commit, a migration is added that adds an index from channel to sessionIDs (using the DB-assigned session IDs). This will make it easier in future to know which sessions have updates for which channels. --- watchtower/wtdb/client_db.go | 44 +++- watchtower/wtdb/log.go | 2 + watchtower/wtdb/migration7/client_db.go | 202 +++++++++++++++++++ watchtower/wtdb/migration7/client_db_test.go | 191 ++++++++++++++++++ watchtower/wtdb/migration7/codec.go | 29 +++ watchtower/wtdb/migration7/log.go | 14 ++ watchtower/wtdb/version.go | 4 + 7 files changed, 485 insertions(+), 1 deletion(-) create mode 100644 watchtower/wtdb/migration7/client_db.go create mode 100644 watchtower/wtdb/migration7/client_db_test.go create mode 100644 watchtower/wtdb/migration7/codec.go create mode 100644 watchtower/wtdb/migration7/log.go diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 2079a0c6f..014e503db 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -23,8 +23,13 @@ var ( // cChanDetailsBkt is a top-level bucket storing: // channel-id => cChannelSummary -> encoded ClientChanSummary. // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 cChanDetailsBkt = []byte("client-channel-detail-bucket") + // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: + // db-session-id -> 1 + cChanSessions = []byte("client-channel-sessions") + // cChanDBID is a key used in the cChanDetailsBkt to store the // db-assigned-id of a channel. cChanDBID = []byte("client-channel-db-id") @@ -1454,7 +1459,7 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return ErrUninitializedDB } - chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) if chanDetailsBkt == nil { return ErrUninitializedDB } @@ -1538,6 +1543,23 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } + dbSessionID, _, err := getDBSessionID(sessions, *id) + if err != nil { + return err + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket( + committedUpdate.BackupID.ChanID[:], + ) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + err = putChannelToSessionMapping(chanDetails, dbSessionID) + if err != nil { + return err + } + // Get the range index for the given session-channel pair. index, err := c.getRangeIndex(tx, *id, chanID) if err != nil { @@ -1548,6 +1570,26 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, }, func() {}) } +// putChannelToSessionMapping adds the given session ID to a channel's +// cChanSessions bucket. +func putChannelToSessionMapping(chanDetails kvdb.RwBucket, + dbSessID uint64) error { + + chanSessIDsBkt, err := chanDetails.CreateBucketIfNotExists( + cChanSessions, + ) + if err != nil { + return err + } + + b, err := writeBigSize(dbSessID) + if err != nil { + return err + } + + return chanSessIDsBkt.Put(b, []byte{1}) +} + // getClientSessionBody loads the body of a ClientSession from the sessions // bucket corresponding to the serialized session id. This does not deserialize // the CommittedUpdates, AckUpdates or the Tower associated with the session. diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 638542abb..639030631 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -9,6 +9,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" ) // log is a logger that is initialized with no output filters. This @@ -38,6 +39,7 @@ func UseLogger(logger btclog.Logger) { migration4.UseLogger(logger) migration5.UseLogger(logger) migration6.UseLogger(logger) + migration7.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/migration7/client_db.go b/watchtower/wtdb/migration7/client_db.go new file mode 100644 index 000000000..0c3c4be40 --- /dev/null +++ b/watchtower/wtdb/migration7/client_db.go @@ -0,0 +1,202 @@ +package migration7 + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + + "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/tlv" +) + +var ( + // cSessionBkt is a top-level bucket storing: + // session-id => cSessionBody -> encoded ClientSessionBody + // => cSessionDBID -> db-assigned-id + // => cSessionCommits => seqnum -> encoded CommittedUpdate + // => cSessionAckRangeIndex => chan-id => acked-index-range + cSessionBkt = []byte("client-session-bucket") + + // cChanDetailsBkt is a top-level bucket storing: + // channel-id => cChannelSummary -> encoded ClientChanSummary. + // => cChanDBID -> db-assigned-id + // => cChanSessions => db-session-id -> 1 + cChanDetailsBkt = []byte("client-channel-detail-bucket") + + // cChannelSummary is a sub-bucket of cChanDetailsBkt which stores the + // encoded body of ClientChanSummary. + cChannelSummary = []byte("client-channel-summary") + + // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: + // session-id -> 1 + cChanSessions = []byte("client-channel-sessions") + + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing: + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") + + // cSessionDBID is a key used in the cSessionBkt to store the + // db-assigned-d of a session. + cSessionDBID = []byte("client-session-db-id") + + // cChanIDIndexBkt is a top-level bucket storing: + // db-assigned-id -> channel-ID + cChanIDIndexBkt = []byte("client-channel-id-index") + + // ErrUninitializedDB signals that top-level buckets for the database + // have not been initialized. + ErrUninitializedDB = errors.New("db not initialized") + + // ErrCorruptClientSession signals that the client session's on-disk + // structure deviates from what is expected. + ErrCorruptClientSession = errors.New("client session corrupted") + + // byteOrder is the default endianness used when serializing integers. + byteOrder = binary.BigEndian +) + +// MigrateChannelToSessionIndex migrates the tower client DB to add an index +// from channel-to-session. This will make it easier in future to check which +// sessions have updates for which channels. +func MigrateChannelToSessionIndex(tx kvdb.RwTx) error { + log.Infof("Migrating the tower client DB to build a new " + + "channel-to-session index") + + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + chanIDsBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDsBkt == nil { + return ErrUninitializedDB + } + + // First gather all the new channel-to-session pairs that we want to + // add. + index, err := collectIndex(sessionsBkt) + if err != nil { + return err + } + + // Then persist those pairs to the db. + return persistIndex(chanDetailsBkt, chanIDsBkt, index) +} + +// collectIndex iterates through all the sessions and uses the keys in the +// cSessionAckRangeIndex bucket to collect all the channels that the session +// has updates for. The function returns a map from channel ID to session ID +// (using the db-assigned IDs for both). +func collectIndex(sessionsBkt kvdb.RBucket) (map[uint64]map[uint64]bool, + error) { + + index := make(map[uint64]map[uint64]bool) + err := sessionsBkt.ForEach(func(sessID, _ []byte) error { + sessionBkt := sessionsBkt.NestedReadBucket(sessID) + if sessionBkt == nil { + return ErrCorruptClientSession + } + + ackedRanges := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if ackedRanges == nil { + return nil + } + + sessDBIDBytes := sessionBkt.Get(cSessionDBID) + if sessDBIDBytes == nil { + return ErrCorruptClientSession + } + + sessDBID, err := readUint64(sessDBIDBytes) + if err != nil { + return err + } + + return ackedRanges.ForEach(func(dbChanIDBytes, _ []byte) error { + dbChanID, err := readUint64(dbChanIDBytes) + if err != nil { + return err + } + + if _, ok := index[dbChanID]; !ok { + index[dbChanID] = make(map[uint64]bool) + } + + index[dbChanID][sessDBID] = true + + return nil + }) + }) + if err != nil { + return nil, err + } + + return index, nil +} + +// persistIndex adds the channel-to-session mapping in each channel's details +// bucket. +func persistIndex(chanDetailsBkt kvdb.RwBucket, chanIDsBkt kvdb.RBucket, + index map[uint64]map[uint64]bool) error { + + for dbChanID, sessIDs := range index { + dbChanIDBytes, err := writeUint64(dbChanID) + if err != nil { + return err + } + + realChanID := chanIDsBkt.Get(dbChanIDBytes) + + chanBkt := chanDetailsBkt.NestedReadWriteBucket(realChanID) + if chanBkt == nil { + return fmt.Errorf("channel not found") + } + + sessIDsBkt, err := chanBkt.CreateBucket(cChanSessions) + if err != nil { + return err + } + + for id := range sessIDs { + sessID, err := writeUint64(id) + if err != nil { + return err + } + + err = sessIDsBkt.Put(sessID, []byte{1}) + if err != nil { + return err + } + } + } + + return nil +} + +func writeUint64(i uint64) ([]byte, error) { + var b bytes.Buffer + err := tlv.WriteVarInt(&b, i, &[8]byte{}) + if err != nil { + return nil, err + } + + return b.Bytes(), nil +} + +func readUint64(b []byte) (uint64, error) { + r := bytes.NewReader(b) + i, err := tlv.ReadVarInt(r, &[8]byte{}) + if err != nil { + return 0, err + } + + return i, nil +} diff --git a/watchtower/wtdb/migration7/client_db_test.go b/watchtower/wtdb/migration7/client_db_test.go new file mode 100644 index 000000000..4f90edc47 --- /dev/null +++ b/watchtower/wtdb/migration7/client_db_test.go @@ -0,0 +1,191 @@ +package migration7 + +import ( + "testing" + + "github.com/lightningnetwork/lnd/channeldb/migtest" + "github.com/lightningnetwork/lnd/kvdb" +) + +var ( + // preDetails is the expected data of the channel details bucket before + // the migration. + preDetails = map[string]interface{}{ + channelIDString(100): map[string]interface{}{ + string(cChannelSummary): string([]byte{1, 2, 3}), + }, + channelIDString(222): map[string]interface{}{ + string(cChannelSummary): string([]byte{4, 5, 6}), + }, + } + + // preFailCorruptDB should fail the migration due to no channel summary + // being found for a given channel ID. + preFailCorruptDB = map[string]interface{}{ + channelIDString(30): map[string]interface{}{}, + } + + // channelIDIndex is the data in the channelID index that is used to + // find the mapping between the db-assigned channel ID and the real + // channel ID. + channelIDIndex = map[string]interface{}{ + uint64ToStr(10): channelIDString(100), + uint64ToStr(20): channelIDString(222), + } + + // sessions is the expected data in the sessions bucket before and + // after the migration. + sessions = map[string]interface{}{ + sessionIDString("1"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(32), + uint64ToStr(34): uint64ToStr(34), + }, + uint64ToStr(20): map[string]interface{}{ + uint64ToStr(30): uint64ToStr(30), + }, + }, + string(cSessionDBID): uint64ToStr(66), + }, + sessionIDString("2"): map[string]interface{}{ + string(cSessionAckRangeIndex): map[string]interface{}{ + uint64ToStr(10): map[string]interface{}{ + uint64ToStr(33): uint64ToStr(33), + }, + }, + string(cSessionDBID): uint64ToStr(77), + }, + } + + // postDetails is the expected data in the channel details bucket after + // the migration. + postDetails = map[string]interface{}{ + channelIDString(100): map[string]interface{}{ + string(cChannelSummary): string([]byte{1, 2, 3}), + string(cChanSessions): map[string]interface{}{ + uint64ToStr(66): string([]byte{1}), + uint64ToStr(77): string([]byte{1}), + }, + }, + channelIDString(222): map[string]interface{}{ + string(cChannelSummary): string([]byte{4, 5, 6}), + string(cChanSessions): map[string]interface{}{ + uint64ToStr(66): string([]byte{1}), + }, + }, + } +) + +// TestMigrateChannelToSessionIndex tests that the MigrateChannelToSessionIndex +// function correctly builds the new channel-to-sessionID index to the tower +// client DB. +func TestMigrateChannelToSessionIndex(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + shouldFail bool + preDetails map[string]interface{} + preSessions map[string]interface{} + preChanIndex map[string]interface{} + postDetails map[string]interface{} + }{ + { + name: "migration ok", + shouldFail: false, + preDetails: preDetails, + preSessions: sessions, + preChanIndex: channelIDIndex, + postDetails: postDetails, + }, + { + name: "fail due to corrupt db", + shouldFail: true, + preDetails: preFailCorruptDB, + preSessions: sessions, + }, + { + name: "no sessions", + shouldFail: false, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + // Before the migration we have a channel details + // bucket, a sessions bucket, a session ID index bucket + // and a channel ID index bucket. + before := func(tx kvdb.RwTx) error { + err := migtest.RestoreDB( + tx, cChanDetailsBkt, test.preDetails, + ) + if err != nil { + return err + } + + err = migtest.RestoreDB( + tx, cSessionBkt, test.preSessions, + ) + if err != nil { + return err + } + + return migtest.RestoreDB( + tx, cChanIDIndexBkt, test.preChanIndex, + ) + } + + after := func(tx kvdb.RwTx) error { + // If the migration fails, the details bucket + // should be untouched. + if test.shouldFail { + if err := migtest.VerifyDB( + tx, cChanDetailsBkt, + test.preDetails, + ); err != nil { + return err + } + + return nil + } + + // Else, we expect an updated details bucket + // and a new index bucket. + return migtest.VerifyDB( + tx, cChanDetailsBkt, test.postDetails, + ) + } + + migtest.ApplyMigration( + t, before, after, MigrateChannelToSessionIndex, + test.shouldFail, + ) + }) + } +} + +func sessionIDString(id string) string { + var sessID SessionID + copy(sessID[:], id) + return sessID.String() +} + +func channelIDString(id uint64) string { + var chanID ChannelID + byteOrder.PutUint64(chanID[:], id) + return string(chanID[:]) +} + +func uint64ToStr(id uint64) string { + b, err := writeUint64(id) + if err != nil { + panic(err) + } + + return string(b) +} diff --git a/watchtower/wtdb/migration7/codec.go b/watchtower/wtdb/migration7/codec.go new file mode 100644 index 000000000..e94cfe67d --- /dev/null +++ b/watchtower/wtdb/migration7/codec.go @@ -0,0 +1,29 @@ +package migration7 + +import "encoding/hex" + +// SessionIDSize is 33-bytes; it is a serialized, compressed public key. +const SessionIDSize = 33 + +// SessionID is created from the remote public key of a client, and serves as a +// unique identifier and authentication for sending state updates. +type SessionID [SessionIDSize]byte + +// String returns a hex encoding of the session id. +func (s SessionID) String() string { + return hex.EncodeToString(s[:]) +} + +// ChannelID is a series of 32-bytes that uniquely identifies all channels +// within the network. The ChannelID is computed using the outpoint of the +// funding transaction (the txid, and output index). Given a funding output the +// ChannelID can be calculated by XOR'ing the big-endian serialization of the +// txid and the big-endian serialization of the output index, truncated to +// 2 bytes. +type ChannelID [32]byte + +// String returns the string representation of the ChannelID. This is just the +// hex string encoding of the ChannelID itself. +func (c ChannelID) String() string { + return hex.EncodeToString(c[:]) +} diff --git a/watchtower/wtdb/migration7/log.go b/watchtower/wtdb/migration7/log.go new file mode 100644 index 000000000..39f28b6c0 --- /dev/null +++ b/watchtower/wtdb/migration7/log.go @@ -0,0 +1,14 @@ +package migration7 + +import ( + "github.com/btcsuite/btclog" +) + +// log is a logger that is initialized as disabled. This means the package will +// not perform any logging by default until a logger is set. +var log = btclog.Disabled + +// UseLogger uses a specified Logger to output package logging info. +func UseLogger(logger btclog.Logger) { + log = logger +} diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index 1a186044a..b44ed80eb 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration5" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration6" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration7" ) // txMigration is a function which takes a prior outdated version of the @@ -63,6 +64,9 @@ var clientDBVersions = []version{ { txMigration: migration6.MigrateSessionIDIndex, }, + { + txMigration: migration7.MigrateChannelToSessionIndex, + }, } // getLatestDBVersion returns the last known database version. From 5283e2c341e40f36599c2d4a432e01c69a2fc1be Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 13:19:48 +0200 Subject: [PATCH 03/19] watchtower/wtdb: remove unnecessary tower load --- watchtower/wtdb/client_db.go | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 014e503db..4801ca1ab 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1023,20 +1023,14 @@ func (c *ClientDB) ListClientSessions(id *TowerID, return ErrUninitializedDB } - towers := tx.ReadBucket(cTowerBkt) - if towers == nil { - return ErrUninitializedDB - } - chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) if chanIDIndexBkt == nil { return ErrUninitializedDB } - var err error - // If no tower ID is specified, then fetch all the sessions // known to the db. + var err error if id == nil { clientSessions, err = c.listClientAllSessions( sessions, chanIDIndexBkt, filterFn, opts..., From a3050ed21377e1acf27b52a5c7f20abbcdd87c39 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:18:31 +0200 Subject: [PATCH 04/19] watchtower: add GetClientSession func to DB This commit adds a new `GetClientSession` method to the tower client DB which can be used to fetch a session by its ID from the DB. --- watchtower/wtclient/interface.go | 5 +++++ watchtower/wtdb/client_db.go | 30 ++++++++++++++++++++++++++++++ watchtower/wtmock/client_db.go | 30 ++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 9751988ba..32ebe93ad 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -64,6 +64,11 @@ type DB interface { ...wtdb.ClientSessionListOption) ( map[wtdb.SessionID]*wtdb.ClientSession, error) + // GetClientSession loads the ClientSession with the given ID from the + // DB. + GetClientSession(wtdb.SessionID, + ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) + // FetchSessionCommittedUpdates retrieves the current set of un-acked // updates of the given session. FetchSessionCommittedUpdates(id *wtdb.SessionID) ( diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 4801ca1ab..eaf188470 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1009,6 +1009,36 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID, return byteOrder.Uint32(keyIndexBytes), nil } +// GetClientSession loads the ClientSession with the given ID from the DB. +func (c *ClientDB) GetClientSession(id SessionID, + opts ...ClientSessionListOption) (*ClientSession, error) { + + var sess *ClientSession + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + session, err := c.getClientSession( + sessionsBkt, chanIDIndexBkt, id[:], nil, opts..., + ) + if err != nil { + return err + } + + sess = session + return nil + }, func() {}) + + return sess, err +} + // ListClientSessions returns the set of all client sessions known to the db. An // optional tower ID can be used to filter out any client sessions in the // response that do not correspond to this tower. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 071ee7782..9d38c2da2 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -554,6 +554,36 @@ func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { return summaries, nil } +// GetClientSession loads the ClientSession with the given ID from the DB. +func (m *ClientDB) GetClientSession(id wtdb.SessionID, + opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { + + cfg := wtdb.NewClientSessionCfg() + for _, o := range opts { + o(cfg) + } + + session, ok := m.activeSessions[id] + if !ok { + return nil, wtdb.ErrClientSessionNotFound + } + + if cfg.PerMaxHeight != nil { + for chanID, index := range m.ackedUpdates[session.ID] { + cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) + } + } + + if cfg.PerCommittedUpdate != nil { + for _, update := range m.committedUpdates[session.ID] { + update := update + cfg.PerCommittedUpdate(&session, &update) + } + } + + return &session, nil +} + // RegisterChannel registers a channel for use within the client database. For // now, all that is stored in the channel summary is the sweep pkscript that // we'd like any tower sweeps to pay into. In the future, this will be extended From 571966440c28f640c6bf2ab1385844a770b03f6c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:21:18 +0200 Subject: [PATCH 05/19] watchtower: add MarkChannelClosed db method This commit adds a `MarkChannelClosed` method to the tower client DB. This function can be called when a channel is closed and it will check the channel's associated sessions to see if any of them are "closable". Any closable sessions are added to a new `cClosableSessionsBkt` bucket so that they can be evaluated in future. Note that only the logic for this function is added in this commit and it is not yet called. --- watchtower/wtclient/interface.go | 9 ++ watchtower/wtdb/client_db.go | 261 ++++++++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 163 ++++++++++++++++++- watchtower/wtmock/client_db.go | 134 +++++++++++++-- 4 files changed, 553 insertions(+), 14 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 32ebe93ad..e5fc5d22b 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -86,6 +86,15 @@ type DB interface { // their channel summaries. FetchChanSummaries() (wtdb.ChannelSummaries, error) + // MarkChannelClosed will mark a registered channel as closed by setting + // its closed-height as the given block height. It returns a list of + // session IDs for sessions that are now considered closable due to the + // close of this channel. The details for this channel will be deleted + // from the DB if there are no more sessions in the DB that contain + // updates for this channel. + MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) ( + []wtdb.SessionID, error) + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index eaf188470..d88fd631e 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -24,6 +24,7 @@ var ( // channel-id => cChannelSummary -> encoded ClientChanSummary. // => cChanDBID -> db-assigned-id // => cChanSessions => db-session-id -> 1 + // => cChanClosedHeight -> block-height cChanDetailsBkt = []byte("client-channel-detail-bucket") // cChanSessions is a sub-bucket of cChanDetailsBkt which stores: @@ -34,6 +35,12 @@ var ( // db-assigned-id of a channel. cChanDBID = []byte("client-channel-db-id") + // cChanClosedHeight is a key used in the cChanDetailsBkt to store the + // block height at which the channel's closing transaction was mined in. + // If this there is no associated value for this key, then the channel + // has not yet been marked as closed. + cChanClosedHeight = []byte("client-channel-closed-height") + // cChannelSummary is a key used in cChanDetailsBkt to store the encoded // body of ClientChanSummary. cChannelSummary = []byte("client-channel-summary") @@ -83,6 +90,10 @@ var ( "client-tower-to-session-index-bucket", ) + // cClosableSessionsBkt is a top-level bucket storing: + // db-session-id -> last-channel-close-height + cClosableSessionsBkt = []byte("client-closable-sessions-bucket") + // ErrTowerNotFound signals that the target tower was not found in the // database. ErrTowerNotFound = errors.New("tower not found") @@ -156,6 +167,14 @@ var ( // ErrSessionFailedFilterFn indicates that a particular session did // not pass the filter func provided by the caller. ErrSessionFailedFilterFn = errors.New("session failed filter func") + + // errSessionHasOpenChannels is an error used to indicate that a + // session has updates for channels that are still open. + errSessionHasOpenChannels = errors.New("session has open channels") + + // errSessionHasUnackedUpdates is an error used to indicate that a + // session has un-acked updates. + errSessionHasUnackedUpdates = errors.New("session has un-acked updates") ) // NewBoltBackendCreator returns a function that creates a new bbolt backend for @@ -256,6 +275,7 @@ func initClientDBBuckets(tx kvdb.RwTx) error { cTowerToSessionIndexBkt, cChanIDIndexBkt, cSessionIDIndexBkt, + cClosableSessionsBkt, } for _, bucket := range buckets { @@ -1365,6 +1385,209 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, return nil } +// MarkChannelClosed will mark a registered channel as closed by setting its +// closed-height as the given block height. It returns a list of session IDs for +// sessions that are now considered closable due to the close of this channel. +// The details for this channel will be deleted from the DB if there are no more +// sessions in the DB that contain updates for this channel. +func (c *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]SessionID, error) { + + var closableSessions []SessionID + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + closableSessBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableSessBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket(chanID[:]) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel + // details can be deleted. + chanSessIDsBkt := chanDetails.NestedReadBucket(cChanSessions) + if chanSessIDsBkt == nil { + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + } + + // Otherwise, mark the channel as closed. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + + err := chanDetails.Put(cChanClosedHeight, height[:]) + if err != nil { + return err + } + + // Now iterate through all the sessions of the channel to check + // if any of them are closeable. + return chanSessIDsBkt.ForEach(func(sessDBID, _ []byte) error { + sessDBIDInt, err := readBigSize(sessDBID) + if err != nil { + return err + } + + // Use the session-ID index to get the real session ID. + sID, err := getRealSessionID( + sessIDIndexBkt, sessDBIDInt, + ) + if err != nil { + return err + } + + isClosable, err := isSessionClosable( + sessionsBkt, chanDetailsBkt, chanIDIndexBkt, + sID, + ) + if err != nil { + return err + } + + if !isClosable { + return nil + } + + // Add session to "closableSessions" list and add the + // block height that this last channel was closed in. + // This will be used in future to determine when we + // should delete the session. + var height [4]byte + byteOrder.PutUint32(height[:], blockHeight) + err = closableSessBkt.Put(sessDBID, height[:]) + if err != nil { + return err + } + + closableSessions = append(closableSessions, *sID) + + return nil + }) + }, func() { + closableSessions = nil + }) + if err != nil { + return nil, err + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if all the following points are true: +// 1) It has no un-acked updates. +// 2) It is exhausted (ie it can't accept any more updates) +// 3) All the channels that it has acked updates for are closed. +func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket, + id *SessionID) (bool, error) { + + sessBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessBkt == nil { + return false, ErrSessionNotFound + } + + commitsBkt := sessBkt.NestedReadBucket(cSessionCommits) + if commitsBkt == nil { + // If the session has no cSessionCommits bucket then we can be + // sure that no updates have ever been committed to the session + // and so it is not yet exhausted. + return false, nil + } + + // If the session has any un-acked updates, then it is not yet closable. + err := commitsBkt.ForEach(func(_, _ []byte) error { + return errSessionHasUnackedUpdates + }) + if errors.Is(err, errSessionHasUnackedUpdates) { + return false, nil + } else if err != nil { + return false, err + } + + session, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return false, err + } + + // We have already checked that the session has no more committed + // updates. So now we can check if the session is exhausted. + if session.SeqNum < session.Policy.MaxUpdates { + // If the session is not yet exhausted, it is not yet closable. + return false, nil + } + + // If the session has no acked-updates, then something is wrong since + // the above check ensures that this session has been exhausted meaning + // that it should have MaxUpdates acked updates. + ackedRangeBkt := sessBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackedRangeBkt == nil { + return false, fmt.Errorf("no acked-updates found for "+ + "exhausted session %s", id) + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + err = ackedRangeBkt.ForEach(func(dbChanID, _ []byte) error { + dbChanIDInt, err := readBigSize(dbChanID) + if err != nil { + return err + } + + chanID, err := getRealChannelID(chanIDIndexBkt, dbChanIDInt) + if err != nil { + return err + } + + // Get the channel details bucket for the channel. + chanDetails := chanDetailsBkt.NestedReadBucket(chanID[:]) + if chanDetails == nil { + return fmt.Errorf("no channel details found for "+ + "channel %s referenced by session %s", chanID, + id) + } + + // If a closed height has been set, then the channel is closed. + closedHeight := chanDetails.Get(cChanClosedHeight) + if len(closedHeight) > 0 { + return nil + } + + // Otherwise, the channel is not yet closed meaning that the + // session is not yet closable. We break the ForEach by + // returning an error to indicate this. + return errSessionHasOpenChannels + }) + if errors.Is(err, errSessionHasOpenChannels) { + return false, nil + } else if err != nil { + return false, err + } + + return true, nil +} + // CommitUpdate persists the CommittedUpdate provided in the slot for (session, // seqNum). This allows the client to retransmit this update on startup. func (c *ClientDB) CommitUpdate(id *SessionID, @@ -2016,6 +2239,44 @@ func getDBSessionID(sessionsBkt kvdb.RBucket, sessionID SessionID) (uint64, return id, idBytes, nil } +func getRealSessionID(sessIDIndexBkt kvdb.RBucket, dbID uint64) (*SessionID, + error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + sessIDBytes := sessIDIndexBkt.Get(dbIDBytes) + if len(sessIDBytes) != SessionIDSize { + return nil, fmt.Errorf("session ID not found") + } + + var sessID SessionID + copy(sessID[:], sessIDBytes) + + return &sessID, nil +} + +func getRealChannelID(chanIDIndexBkt kvdb.RBucket, + dbID uint64) (*lnwire.ChannelID, error) { + + dbIDBytes, err := writeBigSize(dbID) + if err != nil { + return nil, err + } + + chanIDBytes := chanIDIndexBkt.Get(dbIDBytes) + if len(chanIDBytes) != 32 { //nolint:gomnd + return nil, fmt.Errorf("channel ID not found") + } + + var chanIDS lnwire.ChannelID + copy(chanIDS[:], chanIDBytes) + + return &chanIDS, nil +} + // writeBigSize will encode the given uint64 as a BigSize byte slice. func writeBigSize(i uint64) ([]byte, error) { var b bytes.Buffer diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index cd77ec77e..4f5f80749 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -3,6 +3,7 @@ package wtdb_test import ( crand "crypto/rand" "io" + "math/rand" "net" "testing" @@ -17,6 +18,8 @@ import ( "github.com/stretchr/testify/require" ) +const blobType = blob.TypeAltruistCommit + // pseudoAddr is a fake network address to be used for testing purposes. var pseudoAddr = &net.TCPAddr{IP: []byte{0x01, 0x00, 0x00, 0x00}, Port: 9911} @@ -193,6 +196,17 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, + blockHeight uint32, expErr error) []wtdb.SessionID { + + h.t.Helper() + + closableSessions, err := h.db.MarkChannelClosed(id, blockHeight) + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -605,6 +619,105 @@ func testCommitUpdate(h *clientDBHarness) { }, nil) } +// testMarkChannelClosed asserts the behaviour of MarkChannelClosed. +func testMarkChannelClosed(h *clientDBHarness) { + tower := h.newTower() + + // Create channel 1. + chanID1 := randChannelID(h.t) + + // Since we have not yet registered the channel, we expect an error + // when attempting to mark it as closed. + h.markChannelClosed(chanID1, 1, wtdb.ErrChannelNotRegistered) + + // Now register the channel. + h.registerChan(chanID1, nil, nil) + + // Since there are still no sessions that would have updates for the + // channel, marking it as closed now should succeed. + h.markChannelClosed(chanID1, 1, nil) + + // Register channel 2. + chanID2 := randChannelID(h.t) + h.registerChan(chanID2, nil, nil) + + // Create session1 with MaxUpdates set to 5. + session1 := h.randSession(h.t, tower.ID, 5) + h.insertSession(session1, nil) + + // Add an update for channel 2 in session 1 and ack it too. + update := randCommittedUpdateForChannel(h.t, chanID2, 1) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + require.Zero(h.t, lastApplied) + h.ackUpdate(&session1.ID, 1, 1, nil) + + // Marking channel 2 now should not result in any closable sessions + // since session 1 is not yet exhausted. + sl := h.markChannelClosed(chanID2, 1, nil) + require.Empty(h.t, sl) + + // Create channel 3 and 4. + chanID3 := randChannelID(h.t) + h.registerChan(chanID3, nil, nil) + + chanID4 := randChannelID(h.t) + h.registerChan(chanID4, nil, nil) + + // Add an update for channel 4 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID4, 2) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 1, lastApplied) + h.ackUpdate(&session1.ID, 2, 2, nil) + + // Add an update for channel 3 in session 1. But dont ack it yet. + update = randCommittedUpdateForChannel(h.t, chanID2, 3) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 2, lastApplied) + + // Mark channel 4 as closed & assert that session 1 is not seen as + // closable since it still has committed updates. + sl = h.markChannelClosed(chanID4, 1, nil) + require.Empty(h.t, sl) + + // Now ack the update we added above. + h.ackUpdate(&session1.ID, 3, 3, nil) + + // Mark channel 3 as closed & assert that session 1 is still not seen as + // closable since it is not yet exhausted. + sl = h.markChannelClosed(chanID3, 1, nil) + require.Empty(h.t, sl) + + // Create channel 5 and 6. + chanID5 := randChannelID(h.t) + h.registerChan(chanID5, nil, nil) + + chanID6 := randChannelID(h.t) + h.registerChan(chanID6, nil, nil) + + // Add an update for channel 5 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID5, 4) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 3, lastApplied) + h.ackUpdate(&session1.ID, 4, 4, nil) + + // Add an update for channel 6 and ack it. + update = randCommittedUpdateForChannel(h.t, chanID6, 5) + lastApplied = h.commitUpdate(&session1.ID, update, nil) + require.EqualValues(h.t, 4, lastApplied) + h.ackUpdate(&session1.ID, 5, 5, nil) + + // The session is no exhausted. + // If we now close channel 5, session 1 should still not be closable + // since it has an update for channel 6 which is still open. + sl = h.markChannelClosed(chanID5, 1, nil) + require.Empty(h.t, sl) + + // Finally, if we close channel 6, session 1 _should_ be in the closable + // list. + sl = h.markChannelClosed(chanID6, 1, nil) + require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID}) +} + // testAckUpdate asserts the behavior of AckUpdate. func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit @@ -821,6 +934,10 @@ func TestClientDB(t *testing.T) { name: "ack update", run: testAckUpdate, }, + { + name: "mark channel closed", + run: testMarkChannelClosed, + }, } for _, database := range dbs { @@ -841,12 +958,32 @@ func TestClientDB(t *testing.T) { // randCommittedUpdate generates a random committed update. func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { + t.Helper() + + chanID := randChannelID(t) + + return randCommittedUpdateForChannel(t, chanID, seqNum) +} + +func randChannelID(t *testing.T) lnwire.ChannelID { + t.Helper() + var chanID lnwire.ChannelID _, err := io.ReadFull(crand.Reader, chanID[:]) require.NoError(t, err) + return chanID +} + +// randCommittedUpdateForChannel generates a random committed update for the +// given channel ID. +func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID, + seqNum uint16) *wtdb.CommittedUpdate { + + t.Helper() + var hint blob.BreachHint - _, err = io.ReadFull(crand.Reader, hint[:]) + _, err := io.ReadFull(crand.Reader, hint[:]) require.NoError(t, err) encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) @@ -865,3 +1002,27 @@ func randCommittedUpdate(t *testing.T, seqNum uint16) *wtdb.CommittedUpdate { }, } } + +func (h *clientDBHarness) randSession(t *testing.T, + towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession { + + t.Helper() + + var id wtdb.SessionID + rand.Read(id[:]) + + return &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: towerID, + Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, + MaxUpdates: maxUpdates, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + KeyIndex: h.nextKeyIndex(towerID, blobType), + }, + ID: id, + } +} diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 9d38c2da2..2820d74cd 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -25,19 +25,26 @@ type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore +type channel struct { + summary *wtdb.ClientChanSummary + closedHeight uint32 + sessions map[wtdb.SessionID]bool +} + // ClientDB is a mock, in-memory database or testing the watchtower client // behavior. type ClientDB struct { nextTowerID uint64 // to be used atomically mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + channels map[lnwire.ChannelID]*channel activeSessions map[wtdb.SessionID]wtdb.ClientSession ackedUpdates rangeIndexArrayMap persistedAckedUpdates rangeIndexKVStore committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate towerIndex map[towerPK]wtdb.TowerID towers map[wtdb.TowerID]*wtdb.Tower + closableSessions map[wtdb.SessionID]uint32 nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -47,9 +54,7 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make( - map[lnwire.ChannelID]wtdb.ClientChanSummary, - ), + channels: make(map[lnwire.ChannelID]*channel), activeSessions: make( map[wtdb.SessionID]wtdb.ClientSession, ), @@ -58,10 +63,11 @@ func NewClientDB() *ClientDB { committedUpdates: make( map[wtdb.SessionID][]wtdb.CommittedUpdate, ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), + closableSessions: make(map[wtdb.SessionID]uint32), } } @@ -503,6 +509,13 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, continue } + // Add sessionID to channel. + channel, ok := m.channels[update.BackupID.ChanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + channel.sessions[*id] = true + // Remove the committed update from disk and mark the update as // acked. The tower last applied value is also recorded to send // along with the next update. @@ -545,15 +558,107 @@ func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { defer m.mu.Unlock() summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, summary := range m.summaries { + for chanID, channel := range m.channels { summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(summary.SweepPkScript), + SweepPkScript: cloneBytes( + channel.summary.SweepPkScript, + ), } } return summaries, nil } +// MarkChannelClosed will mark a registered channel as closed by setting +// its closed-height as the given block height. It returns a list of +// session IDs for sessions that are now considered closable due to the +// close of this channel. +func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, + blockHeight uint32) ([]wtdb.SessionID, error) { + + m.mu.Lock() + defer m.mu.Unlock() + + channel, ok := m.channels[chanID] + if !ok { + return nil, wtdb.ErrChannelNotRegistered + } + + // If there are no sessions for this channel, the channel details can be + // deleted. + if len(channel.sessions) == 0 { + delete(m.channels, chanID) + return nil, nil + } + + // Mark the channel as closed. + channel.closedHeight = blockHeight + + // Now iterate through all the sessions of the channel to check if any + // of them are closeable. + var closableSessions []wtdb.SessionID + for sessID := range channel.sessions { + isClosable, err := m.isSessionClosable(sessID) + if err != nil { + return nil, err + } + + if !isClosable { + continue + } + + closableSessions = append(closableSessions, sessID) + + // Add session to "closableSessions" list and add the block + // height that this last channel was closed in. This will be + // used in future to determine when we should delete the + // session. + m.closableSessions[sessID] = blockHeight + } + + return closableSessions, nil +} + +// isSessionClosable returns true if a session is considered closable. A session +// is considered closable only if: +// 1) It has no un-acked updates +// 2) It is exhausted (ie it cant accept any more updates) +// 3) All the channels that it has acked-updates for are closed. +func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { + // The session is not closable if it has un-acked updates. + if len(m.committedUpdates[id]) > 0 { + return false, nil + } + + sess, ok := m.activeSessions[id] + if !ok { + return false, wtdb.ErrClientSessionNotFound + } + + // The session is not closable if it is not yet exhausted. + if sess.SeqNum != sess.Policy.MaxUpdates { + return false, nil + } + + // Iterate over each of the channels that the session has acked-updates + // for. If any of those channels are not closed, then the session is + // not yet closable. + for chanID := range m.ackedUpdates[id] { + channel, ok := m.channels[chanID] + if !ok { + continue + } + + // Channel is not yet closed, and so we can not yet delete the + // session. + if channel.closedHeight == 0 { + return false, nil + } + } + + return true, nil +} + // GetClientSession loads the ClientSession with the given ID from the DB. func (m *ClientDB) GetClientSession(id wtdb.SessionID, opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { @@ -595,12 +700,15 @@ func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, m.mu.Lock() defer m.mu.Unlock() - if _, ok := m.summaries[chanID]; ok { + if _, ok := m.channels[chanID]; ok { return wtdb.ErrChannelAlreadyRegistered } - m.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), + m.channels[chanID] = &channel{ + summary: &wtdb.ClientChanSummary{ + SweepPkScript: cloneBytes(sweepPkScript), + }, + sessions: make(map[wtdb.SessionID]bool), } return nil From 3577c829d316685da24d7e5b39b01f1061287b0e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:22:20 +0200 Subject: [PATCH 06/19] watchtower: add ListClosableSessions method This commit adds a new ListClosableSessions method to the tower client DB. This method will return a map of sessionIDs to block heights. The IDs belong to sessions that are considered closable and the block heights are the block height at which the last associated channel for the session was closed in. --- watchtower/wtclient/interface.go | 4 ++++ watchtower/wtdb/client_db.go | 40 +++++++++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 18 +++++++++++++- watchtower/wtmock/client_db.go | 14 +++++++++++ 4 files changed, 75 insertions(+), 1 deletion(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index e5fc5d22b..b5d2418ab 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -95,6 +95,10 @@ type DB interface { MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) ( []wtdb.SessionID, error) + // ListClosableSessions fetches and returns the IDs for all sessions + // marked as closable. + ListClosableSessions() (map[wtdb.SessionID]uint32, error) + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index d88fd631e..4a74c7bb4 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1385,6 +1385,46 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, return nil } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) { + sessions := make(map[SessionID]uint32) + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + csBkt := tx.ReadBucket(cClosableSessionsBkt) + if csBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + return csBkt.ForEach(func(dbIDBytes, heightBytes []byte) error { + dbID, err := readBigSize(dbIDBytes) + if err != nil { + return err + } + + sessID, err := getRealSessionID(sessIDIndexBkt, dbID) + if err != nil { + return err + } + + sessions[*sessID] = byteOrder.Uint32(heightBytes) + + return nil + }) + }, func() { + sessions = make(map[SessionID]uint32) + }) + if err != nil { + return nil, err + } + + return sessions, nil +} + // MarkChannelClosed will mark a registered channel as closed by setting its // closed-height as the given block height. It returns a list of session IDs for // sessions that are now considered closable due to the close of this channel. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 4f5f80749..73f4e5550 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -207,6 +207,17 @@ func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, return closableSessions } +func (h *clientDBHarness) listClosableSessions( + expErr error) map[wtdb.SessionID]uint32 { + + h.t.Helper() + + closableSessions, err := h.db.ListClosableSessions() + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -711,11 +722,16 @@ func testMarkChannelClosed(h *clientDBHarness) { // since it has an update for channel 6 which is still open. sl = h.markChannelClosed(chanID5, 1, nil) require.Empty(h.t, sl) + require.Empty(h.t, h.listClosableSessions(nil)) // Finally, if we close channel 6, session 1 _should_ be in the closable // list. - sl = h.markChannelClosed(chanID6, 1, nil) + sl = h.markChannelClosed(chanID6, 100, nil) require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID}) + slMap := h.listClosableSessions(nil) + require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{ + session1.ID: 100, + }, 0) } // testAckUpdate asserts the behavior of AckUpdate. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 2820d74cd..f439e9182 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -551,6 +551,20 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, return wtdb.ErrCommittedUpdateNotFound } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) + for id, height := range m.closableSessions { + cs[id] = height + } + + return cs, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { From e432261dab3593f04f8c9bc30e161bb3bb27b0bf Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:23:20 +0200 Subject: [PATCH 07/19] watchtower: add DeleteSession method Add a DeleteSession method to the tower client DB. This can be used to delete a closable session along with any references to the session. --- watchtower/wtclient/interface.go | 6 + watchtower/wtdb/client_db.go | 181 ++++++++++++++++++++++++++++++ watchtower/wtdb/client_db_test.go | 15 +++ watchtower/wtmock/client_db.go | 28 +++++ 4 files changed, 230 insertions(+) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index b5d2418ab..a69e9980f 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -99,6 +99,12 @@ type DB interface { // marked as closable. ListClosableSessions() (map[wtdb.SessionID]uint32, error) + // DeleteSession can be called when a session should be deleted from the + // DB. All references to the session will also be deleted from the DB. + // A session will only be deleted if it was previously marked as + // closable. + DeleteSession(id wtdb.SessionID) error + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 4a74c7bb4..c3e5447d6 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -168,6 +168,10 @@ var ( // not pass the filter func provided by the caller. ErrSessionFailedFilterFn = errors.New("session failed filter func") + // ErrSessionNotClosable is returned when a session is not found in the + // closable list. + ErrSessionNotClosable = errors.New("session is not closable") + // errSessionHasOpenChannels is an error used to indicate that a // session has updates for channels that are still open. errSessionHasOpenChannels = errors.New("session has open channels") @@ -175,6 +179,11 @@ var ( // errSessionHasUnackedUpdates is an error used to indicate that a // session has un-acked updates. errSessionHasUnackedUpdates = errors.New("session has un-acked updates") + + // errChannelHasMoreSessions is an error used to indicate that a channel + // has updates in other non-closed sessions. + errChannelHasMoreSessions = errors.New("channel has updates in " + + "other sessions") ) // NewBoltBackendCreator returns a function that creates a new bbolt backend for @@ -1053,6 +1062,7 @@ func (c *ClientDB) GetClientSession(id SessionID, } sess = session + return nil }, func() {}) @@ -1425,6 +1435,177 @@ func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) { return sessions, nil } +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if was previously marked as closable. +func (c *ClientDB) DeleteSession(id SessionID) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + closableBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadWriteBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadWriteBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + towerToSessBkt := tx.ReadWriteBucket(cTowerToSessionIndexBkt) + if towerToSessBkt == nil { + return ErrUninitializedDB + } + + // Get the sub-bucket for this session ID. If it does not exist + // then the session has already been deleted and so our work is + // done. + sessionBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessionBkt == nil { + return nil + } + + _, dbIDBytes, err := getDBSessionID(sessionsBkt, id) + if err != nil { + return err + } + + // First we check if the session has actually been marked as + // closable. + if closableBkt.Get(dbIDBytes) == nil { + return ErrSessionNotClosable + } + + sess, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return err + } + + // Delete from the tower-to-sessionID index. + towerIndexBkt := towerToSessBkt.NestedReadWriteBucket( + sess.TowerID.Bytes(), + ) + if towerIndexBkt == nil { + return fmt.Errorf("no entry in the tower-to-session "+ + "index found for tower ID %v", sess.TowerID) + } + + err = towerIndexBkt.Delete(id[:]) + if err != nil { + return err + } + + // Delete entry from session ID index. + err = sessIDIndexBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Delete the entry from the closable sessions index. + err = closableBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Get the acked updates range index for the session. This is + // used to get the list of channels that the session has updates + // for. + ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackRanges == nil { + // A session would only be considered closable if it + // was exhausted. Meaning that it should not be the + // case that it has no acked-updates. + return fmt.Errorf("cannot delete session %s since it "+ + "is not yet exhausted", id) + } + + // For each of the channels, delete the session ID entry. + err = ackRanges.ForEach(func(chanDBID, _ []byte) error { + chanDBIDInt, err := readBigSize(chanDBID) + if err != nil { + return err + } + + chanID, err := getRealChannelID( + chanIDIndexBkt, chanDBIDInt, + ) + if err != nil { + return err + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket( + chanID[:], + ) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + chanSessions := chanDetails.NestedReadWriteBucket( + cChanSessions, + ) + if chanSessions == nil { + return fmt.Errorf("no session list found for "+ + "channel %s", chanID) + } + + // Check that this session was actually listed in the + // session list for this channel. + if len(chanSessions.Get(dbIDBytes)) == 0 { + return fmt.Errorf("session %s not found in "+ + "the session list for channel %s", id, + chanID) + } + + // If it was, then delete it. + err = chanSessions.Delete(dbIDBytes) + if err != nil { + return err + } + + // If this was the last session for this channel, we can + // now delete the channel details for this channel + // completely. + err = chanSessions.ForEach(func(_, _ []byte) error { + return errChannelHasMoreSessions + }) + if errors.Is(err, errChannelHasMoreSessions) { + return nil + } else if err != nil { + return err + } + + // Delete the channel's entry from the channel-id-index. + dbID := chanDetails.Get(cChanDBID) + err = chanIDIndexBkt.Delete(dbID) + if err != nil { + return err + } + + // Delete the channel details. + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + }) + if err != nil { + return err + } + + // Delete the actual session. + return sessionsBkt.DeleteNestedBucket(id[:]) + }, func() {}) +} + // MarkChannelClosed will mark a registered channel as closed by setting its // closed-height as the given block height. It returns a list of session IDs for // sessions that are now considered closable due to the close of this channel. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 73f4e5550..b3d241175 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -218,6 +218,13 @@ func (h *clientDBHarness) listClosableSessions( return closableSessions } +func (h *clientDBHarness) deleteSession(id wtdb.SessionID, expErr error) { + h.t.Helper() + + err := h.db.DeleteSession(id) + require.ErrorIs(h.t, err, expErr) +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -724,6 +731,10 @@ func testMarkChannelClosed(h *clientDBHarness) { require.Empty(h.t, sl) require.Empty(h.t, h.listClosableSessions(nil)) + // Also check that attempting to delete the session will fail since it + // is not yet considered closable. + h.deleteSession(session1.ID, wtdb.ErrSessionNotClosable) + // Finally, if we close channel 6, session 1 _should_ be in the closable // list. sl = h.markChannelClosed(chanID6, 100, nil) @@ -732,6 +743,10 @@ func testMarkChannelClosed(h *clientDBHarness) { require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{ session1.ID: 100, }, 0) + + // Assert that we now can delete the session. + h.deleteSession(session1.ID, nil) + require.Empty(h.t, h.listClosableSessions(nil)) } // testAckUpdate asserts the behavior of AckUpdate. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index f439e9182..7213f17b6 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -703,6 +703,34 @@ func (m *ClientDB) GetClientSession(id wtdb.SessionID, return &session, nil } +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if it is considered closable. +func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { + m.mu.Lock() + defer m.mu.Unlock() + + _, ok := m.closableSessions[id] + if !ok { + return wtdb.ErrSessionNotClosable + } + + // For each of the channels, delete the session ID entry. + for chanID := range m.ackedUpdates[id] { + c, ok := m.channels[chanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + + delete(c.sessions, id) + } + + delete(m.closableSessions, id) + delete(m.activeSessions, id) + + return nil +} + // RegisterChannel registers a channel for use within the client database. For // now, all that is stored in the channel summary is the sweep pkscript that // we'd like any tower sweeps to pay into. In the future, this will be extended From 0ed5c750c8bec4ca66254aee26dcfffcb05298a8 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:24:02 +0200 Subject: [PATCH 08/19] watchtower: add GetTower to tower iterator Add a GetTower method to the tower iterator. --- watchtower/wtclient/candidate_iterator.go | 18 ++++++++++++++ .../wtclient/candidate_iterator_test.go | 24 ++++++++++++++++--- watchtower/wtclient/errors.go | 8 ++++++- 3 files changed, 46 insertions(+), 4 deletions(-) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index faf3169c6..10ef86465 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -29,6 +29,10 @@ type TowerCandidateIterator interface { // candidates available as long as they remain in the set. Reset() error + // GetTower gets the tower with the given ID from the iterator. If no + // such tower is found then ErrTowerNotInIterator is returned. + GetTower(id wtdb.TowerID) (*Tower, error) + // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. @@ -76,6 +80,20 @@ func (t *towerListIterator) Reset() error { return nil } +// GetTower gets the tower with the given ID from the iterator. If no such tower +// is found then ErrTowerNotInIterator is returned. +func (t *towerListIterator) GetTower(id wtdb.TowerID) (*Tower, error) { + t.mu.Lock() + defer t.mu.Unlock() + + tower, ok := t.candidates[id] + if !ok { + return nil, ErrTowerNotInIterator + } + + return tower, nil +} + // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 7fe6ba723..70dfb7505 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -83,9 +83,15 @@ func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { tower, err := i.Next() require.NoError(t, err) - require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) - require.Equal(t, tower.ID, c.ID) - require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) + assertTowersEqual(t, c, tower) +} + +func assertTowersEqual(t *testing.T, expected, actual *Tower) { + t.Helper() + + require.True(t, expected.IdentityKey.IsEqual(actual.IdentityKey)) + require.Equal(t, expected.ID, actual.ID) + require.Equal(t, expected.Addresses.GetAll(), actual.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -155,4 +161,16 @@ func TestTowerCandidateIterator(t *testing.T) { towerIterator.AddCandidate(secondTower) assertActiveCandidate(t, towerIterator, secondTower, true) assertNextCandidate(t, towerIterator, secondTower) + + // Assert that the GetTower correctly returns the tower too. + tower, err := towerIterator.GetTower(secondTower.ID) + require.NoError(t, err) + assertTowersEqual(t, secondTower, tower) + + // Now remove the tower and assert that GetTower returns expected error. + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) + + _, err = towerIterator.GetTower(secondTower.ID) + require.ErrorIs(t, err, ErrTowerNotInIterator) } diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index f496074bf..c6884bb35 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -1,6 +1,8 @@ package wtclient -import "errors" +import ( + "errors" +) var ( // ErrClientExiting signals that the watchtower client is shutting down. @@ -11,6 +13,10 @@ var ( ErrTowerCandidatesExhausted = errors.New("exhausted all tower " + "candidates") + // ErrTowerNotInIterator is returned when a requested tower was not + // found in the iterator. + ErrTowerNotInIterator = errors.New("tower not in iterator") + // ErrPermanentTowerFailure signals that the tower has reported that it // has permanently failed or the client believes this has happened based // on the tower's behavior. From 41e36c7ec788178fc111a26ea087fb12e64449ba Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:24:46 +0200 Subject: [PATCH 09/19] watchtower: add wtclient.ClientSession constructor --- watchtower/wtclient/client.go | 22 +++++++--------------- watchtower/wtclient/interface.go | 27 +++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 32622b7da..fda072840 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -435,27 +435,19 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, } for _, s := range sessions { - towerKeyDesc, err := keyRing.DeriveKey( - keychain.KeyLocator{ - Family: keychain.KeyFamilyTowerSession, - Index: s.KeyIndex, - }, + if !sessionFilter(s) { + continue + } + + cs, err := NewClientSessionFromDBSession( + s, tower, keyRing, ) if err != nil { return nil, err } - sessionKeyECDH := keychain.NewPubKeyECDH( - towerKeyDesc, keyRing, - ) - // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = &ClientSession{ - ID: s.ID, - ClientSessionBody: s.ClientSessionBody, - Tower: tower, - SessionKeyECDH: sessionKeyECDH, - } + candidateSessions[s.ID] = cs perActiveTower(tower) } diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index a69e9980f..4dd9176ae 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -198,3 +198,30 @@ type ClientSession struct { // key used to connect to the watchtower. SessionKeyECDH keychain.SingleKeyECDH } + +// NewClientSessionFromDBSession converts a wtdb.ClientSession to a +// ClientSession. +func NewClientSessionFromDBSession(s *wtdb.ClientSession, tower *Tower, + keyRing ECDHKeyRing) (*ClientSession, error) { + + towerKeyDesc, err := keyRing.DeriveKey( + keychain.KeyLocator{ + Family: keychain.KeyFamilyTowerSession, + Index: s.KeyIndex, + }, + ) + if err != nil { + return nil, err + } + + sessionKeyECDH := keychain.NewPubKeyECDH( + towerKeyDesc, keyRing, + ) + + return &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + }, nil +} From 16008c00321c1fb37b23f8e4d94744ab535a2870 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:29:25 +0200 Subject: [PATCH 10/19] watchtower: handle channel closures Add a channel-close handler that waits for channel close events and marks channels as closed in the tower client DB. --- server.go | 16 ++ watchtower/wtclient/client.go | 147 ++++++++++- watchtower/wtclient/client_test.go | 406 +++++++++++++++++++++++------ 3 files changed, 483 insertions(+), 86 deletions(-) diff --git a/server.go b/server.go index 7758ebe8c..61b31ffa6 100644 --- a/server.go +++ b/server.go @@ -1512,7 +1512,16 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) } + fetchClosedChannel := s.chanStateDB.FetchClosedChannelForID + s.towerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, @@ -1536,6 +1545,13 @@ func newServer(cfg *Config, listenAddrs []net.Addr, blob.Type(blob.FlagAnchorChannel) s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ + FetchClosedChannel: fetchClosedChannel, + SubscribeChannelEvents: func() (subscribe.Subscription, + error) { + + return s.channelNotifier. + SubscribeChannelEvents() + }, Signer: cc.Wallet.Cfg.Signer, NewAddress: newSweepPkScriptGen(cc.Wallet), SecretKeyRing: s.cc.KeyRing, diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index fda072840..c0a4c2331 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,6 +2,7 @@ package wtclient import ( "bytes" + "errors" "fmt" "net" "sync" @@ -12,10 +13,12 @@ import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -146,6 +149,16 @@ type Config struct { // transaction. Signer input.Signer + // SubscribeChannelEvents can be used to subscribe to channel event + // notifications. + SubscribeChannelEvents func() (subscribe.Subscription, error) + + // FetchClosedChannel can be used to fetch the info about a closed + // channel. If the channel is not found or not yet closed then + // channeldb.ErrClosedChannelNotFound will be returned. + FetchClosedChannel func(cid lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -269,6 +282,7 @@ type TowerClient struct { staleTowers chan *staleTowerMsg wg sync.WaitGroup + quit chan struct{} forceQuit chan struct{} } @@ -319,6 +333,7 @@ func New(config *Config) (*TowerClient, error) { newTowers: make(chan *newTowerMsg), staleTowers: make(chan *staleTowerMsg), forceQuit: make(chan struct{}), + quit: make(chan struct{}), } // perUpdate is a callback function that will be used to inspect the @@ -364,7 +379,7 @@ func New(config *Config) (*TowerClient, error) { return } - log.Infof("Using private watchtower %s, offering policy %s", + c.log.Infof("Using private watchtower %s, offering policy %s", tower, cfg.Policy) // Add the tower to the set of candidate towers. @@ -540,10 +555,45 @@ func (c *TowerClient) Start() error { } } + chanSub, err := c.cfg.SubscribeChannelEvents() + if err != nil { + returnErr = err + return + } + + // Iterate over the list of registered channels and check if + // any of them can be marked as closed. + for id := range c.summaries { + isClosed, closedHeight, err := c.isChannelClosed(id) + if err != nil { + returnErr = err + return + } + + if !isClosed { + continue + } + + _, err = c.cfg.DB.MarkChannelClosed(id, closedHeight) + if err != nil { + c.log.Errorf("could not mark channel(%s) as "+ + "closed: %v", id, err) + + continue + } + + // Since the channel has been marked as closed, we can + // also remove it from the channel summaries map. + delete(c.summaries, id) + } + + c.wg.Add(1) + go c.handleChannelCloses(chanSub) + // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. - err := c.negotiator.Start() + err = c.negotiator.Start() if err != nil { returnErr = err return @@ -591,6 +641,7 @@ func (c *TowerClient) Stop() error { // dispatcher to exit. The backup queue will signal it's // completion to the dispatcher, which releases the wait group // after all tasks have been assigned to session queues. + close(c.quit) c.wg.Wait() // 4. Since all valid tasks have been assigned to session @@ -772,6 +823,82 @@ func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { return c.getOrInitActiveQueue(candidateSession, updates), nil } +// handleChannelCloses listens for channel close events and marks channels as +// closed in the DB. +// +// NOTE: This method MUST be run as a goroutine. +func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { + defer c.wg.Done() + + c.log.Debugf("Starting channel close handler") + defer c.log.Debugf("Stopping channel close handler") + + for { + select { + case update, ok := <-chanSub.Updates(): + if !ok { + c.log.Debugf("Channel notifier has exited") + return + } + + // We only care about channel-close events. + event, ok := update.(channelnotifier.ClosedChannelEvent) + if !ok { + continue + } + + chanID := lnwire.NewChanIDFromOutPoint( + &event.CloseSummary.ChanPoint, + ) + + c.log.Debugf("Received ClosedChannelEvent for "+ + "channel: %s", chanID) + + err := c.handleClosedChannel( + chanID, event.CloseSummary.CloseHeight, + ) + if err != nil { + c.log.Errorf("Could not handle channel close "+ + "event for channel(%s): %v", chanID, + err) + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// handleClosedChannel handles the closure of a single channel. It will mark the +// channel as closed in the DB. +func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, + closeHeight uint32) error { + + c.backupMu.Lock() + defer c.backupMu.Unlock() + + // We only care about channels registered with the tower client. + if _, ok := c.summaries[chanID]; !ok { + return nil + } + + c.log.Debugf("Marking channel(%s) as closed", chanID) + + _, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) + if err != nil { + return fmt.Errorf("could not mark channel(%s) as closed: %w", + chanID, err) + } + + delete(c.summaries, chanID) + delete(c.chanCommitHeights, chanID) + + return nil +} + // backupDispatcher processes events coming from the taskPipeline and is // responsible for detecting when the client needs to renegotiate a session to // fulfill continuing demand. The event loop exits after all tasks have been @@ -1145,6 +1272,22 @@ func (c *TowerClient) initActiveQueue(s *ClientSession, return sq } +// isChanClosed can be used to check if the channel with the given ID has been +// closed. If it has been, the block height in which its closing transaction was +// mined will also be returned. +func (c *TowerClient) isChannelClosed(id lnwire.ChannelID) (bool, uint32, + error) { + + chanSum, err := c.cfg.FetchClosedChannel(id) + if errors.Is(err, channeldb.ErrClosedChannelNotFound) { + return false, 0, nil + } else if err != nil { + return false, 0, err + } + + return true, chanSum.CloseHeight, nil +} + // AddTower adds a new watchtower reachable at the given address and considers // it for new sessions. If the watchtower already exists, then any new addresses // included will be considered when dialing it for session negotiations and diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 1490c6d10..29c4e7a53 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1,6 +1,7 @@ package wtclient_test import ( + "bytes" "encoding/binary" "errors" "fmt" @@ -16,11 +17,13 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/subscribe" "github.com/lightningnetwork/lnd/tor" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" @@ -393,8 +396,12 @@ type testHarness struct { server *wtserver.Server net *mockNet - mu sync.Mutex - channels map[lnwire.ChannelID]*mockChannel + channelEvents *mockSubscription + sendUpdatesOn bool + + mu sync.Mutex + channels map[lnwire.ChannelID]*mockChannel + closedChannels map[lnwire.ChannelID]uint32 quit chan struct{} } @@ -441,13 +448,50 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { mockNet := newMockNet() clientDB := wtmock.NewClientDB() - clientCfg := &wtclient.Config{ - Signer: signer, - Dial: mockNet.Dial, - DB: clientDB, - AuthDial: mockNet.AuthDial, - SecretKeyRing: wtmock.NewSecretKeyRing(), - Policy: cfg.policy, + h := &testHarness{ + t: t, + cfg: cfg, + signer: signer, + capacity: cfg.localBalance + cfg.remoteBalance, + clientDB: clientDB, + serverAddr: towerAddr, + serverDB: serverDB, + serverCfg: serverCfg, + net: mockNet, + channelEvents: newMockSubscription(t), + channels: make(map[lnwire.ChannelID]*mockChannel), + closedChannels: make(map[lnwire.ChannelID]uint32), + quit: make(chan struct{}), + } + t.Cleanup(func() { + close(h.quit) + }) + + fetchChannel := func(id lnwire.ChannelID) ( + *channeldb.ChannelCloseSummary, error) { + + h.mu.Lock() + defer h.mu.Unlock() + + height, ok := h.closedChannels[id] + if !ok { + return nil, channeldb.ErrClosedChannelNotFound + } + + return &channeldb.ChannelCloseSummary{CloseHeight: height}, nil + } + + h.clientCfg = &wtclient.Config{ + Signer: signer, + SubscribeChannelEvents: func() (subscribe.Subscription, error) { + return h.channelEvents, nil + }, + FetchClosedChannel: fetchChannel, + Dial: mockNet.Dial, + DB: clientDB, + AuthDial: mockNet.AuthDial, + SecretKeyRing: wtmock.NewSecretKeyRing(), + Policy: cfg.policy, NewAddress: func() ([]byte, error) { return addrScript, nil }, @@ -458,24 +502,6 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { ForceQuitDelay: 10 * time.Second, } - h := &testHarness{ - t: t, - cfg: cfg, - signer: signer, - capacity: cfg.localBalance + cfg.remoteBalance, - clientDB: clientDB, - clientCfg: clientCfg, - serverAddr: towerAddr, - serverDB: serverDB, - serverCfg: serverCfg, - net: mockNet, - channels: make(map[lnwire.ChannelID]*mockChannel), - quit: make(chan struct{}), - } - t.Cleanup(func() { - close(h.quit) - }) - if !cfg.noServerStart { h.startServer() t.Cleanup(h.stopServer) @@ -576,6 +602,41 @@ func (h *testHarness) channel(id uint64) *mockChannel { return c } +// closeChannel marks a channel as closed. +// +// NOTE: The method fails if a channel for id does not exist. +func (h *testHarness) closeChannel(id uint64, height uint32) { + h.t.Helper() + + h.mu.Lock() + defer h.mu.Unlock() + + chanID := chanIDFromInt(id) + + _, ok := h.channels[chanID] + require.Truef(h.t, ok, "unable to fetch channel %d", id) + + h.closedChannels[chanID] = height + delete(h.channels, chanID) + + chanPointHash, err := chainhash.NewHash(chanID[:]) + require.NoError(h.t, err) + + if !h.sendUpdatesOn { + return + } + + h.channelEvents.sendUpdate(channelnotifier.ClosedChannelEvent{ + CloseSummary: &channeldb.ChannelCloseSummary{ + ChanPoint: wire.OutPoint{ + Hash: *chanPointHash, + Index: 0, + }, + CloseHeight: height, + }, + }) +} + // registerChannel registers the channel identified by id with the client. func (h *testHarness) registerChannel(id uint64) { h.t.Helper() @@ -624,7 +685,7 @@ func (h *testHarness) backupState(id, i uint64, expErr error) { err := h.client.BackupState( &chanID, retribution, channeldb.SingleFunderBit, ) - require.ErrorIs(h.t, expErr, err) + require.ErrorIs(h.t, err, expErr) } // sendPayments instructs the channel identified by id to send amt to the remote @@ -770,11 +831,94 @@ func (h *testHarness) removeTower(pubKey *btcec.PublicKey, addr net.Addr) { require.NoError(h.t, err) } +// relevantSessions returns a list of session IDs that have acked updates for +// the given channel ID. +func (h *testHarness) relevantSessions(chanID uint64) []wtdb.SessionID { + h.t.Helper() + + var ( + sessionIDs []wtdb.SessionID + cID = chanIDFromInt(chanID) + ) + + collectSessions := wtdb.WithPerNumAckedUpdates( + func(session *wtdb.ClientSession, id lnwire.ChannelID, + _ uint16) { + + if !bytes.Equal(id[:], cID[:]) { + return + } + + sessionIDs = append(sessionIDs, session.ID) + }, + ) + + _, err := h.clientDB.ListClientSessions(nil, nil, collectSessions) + require.NoError(h.t, err) + + return sessionIDs +} + +// isSessionClosable returns true if the given session has been marked as +// closable in the DB. +func (h *testHarness) isSessionClosable(id wtdb.SessionID) bool { + h.t.Helper() + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + _, ok := cs[id] + + return ok +} + +// mockSubscription is a mock subscription client that blocks on sends into the +// updates channel. +type mockSubscription struct { + t *testing.T + updates chan interface{} + + // Embed the subscription interface in this mock so that we satisfy it. + subscribe.Subscription +} + +// newMockSubscription creates a mock subscription. +func newMockSubscription(t *testing.T) *mockSubscription { + t.Helper() + + return &mockSubscription{ + t: t, + updates: make(chan interface{}), + } +} + +// sendUpdate sends an update into our updates channel, mocking the dispatch of +// an update from a subscription server. This call will fail the test if the +// update is not consumed within our timeout. +func (m *mockSubscription) sendUpdate(update interface{}) { + select { + case m.updates <- update: + + case <-time.After(waitTime): + m.t.Fatalf("update: %v timeout", update) + } +} + +// Updates returns the updates channel for the mock. +func (m *mockSubscription) Updates() <-chan interface{} { + return m.updates +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) ) +var defaultTxPolicy = wtpolicy.TxPolicy{ + BlobType: blob.TypeAltruistCommit, + SweepFeeRate: wtpolicy.DefaultSweepFeeRate, +} + type clientTest struct { name string cfg harnessCfg @@ -791,10 +935,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, noRegisterChan0: true, @@ -825,10 +966,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -860,10 +998,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -927,10 +1062,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 20000, }, }, @@ -1006,10 +1138,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1062,10 +1191,7 @@ var clientTests = []clientTest{ localBalance: 100000001, // ensure (% amt != 0) remoteBalance: 200000001, // ensure (% amt != 0) policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 1000, }, }, @@ -1106,10 +1232,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1156,10 +1279,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1212,10 +1332,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noAckCreateSession: true, @@ -1274,10 +1391,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 10, }, }, @@ -1333,10 +1447,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1381,10 +1492,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1489,10 +1597,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, }, @@ -1557,10 +1662,7 @@ var clientTests = []clientTest{ localBalance: localBalance, remoteBalance: remoteBalance, policy: wtpolicy.Policy{ - TxPolicy: wtpolicy.TxPolicy{ - BlobType: blob.TypeAltruistCommit, - SweepFeeRate: wtpolicy.DefaultSweepFeeRate, - }, + TxPolicy: defaultTxPolicy, MaxUpdates: 5, }, noServerStart: true, @@ -1654,6 +1756,142 @@ var clientTests = []clientTest{ }, waitTime) require.NoError(h.t, err) }, + }, { + name: "assert that sessions are correctly marked as closable", + cfg: harnessCfg{ + localBalance: localBalance, + remoteBalance: remoteBalance, + policy: wtpolicy.Policy{ + TxPolicy: defaultTxPolicy, + MaxUpdates: 5, + }, + }, + fn: func(h *testHarness) { + const numUpdates = 5 + + // In this test we assert that a channel is correctly + // marked as closed and that sessions are also correctly + // marked as closable. + + // We start with the sendUpdatesOn parameter set to + // false so that we can test that channels are correctly + // evaluated at startup. + h.sendUpdatesOn = false + + // Advance channel 0 to create all states and back them + // all up. This will saturate the session with updates + // for channel 0 which means that the session should be + // considered closable when channel 0 is closed. + hints := h.advanceChannelN(0, numUpdates) + h.backupStates(0, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // We expect only 1 session to have updates for this + // channel. + sessionIDs := h.relevantSessions(0) + require.Len(h.t, sessionIDs, 1) + + // Since channel 0 is still open, the session should not + // yet be closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Close the channel. + h.closeChannel(0, 1) + + // Since updates are currently not being sent, we expect + // the session to still not be marked as closable. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Restart the client. + h.client.ForceQuit() + h.startClient() + + // The session should now have been marked as closable. + err := wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we set sendUpdatesOn to true and do the same with + // a new channel. A restart should now not be necessary + // anymore. + h.sendUpdatesOn = true + + h.makeChannel( + 1, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(1) + + hints = h.advanceChannelN(1, numUpdates) + h.backupStates(1, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(1) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Assert that the session is not yet closable since + // the channel is still open. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Now close the channel. + h.closeChannel(1, 1) + + // Since the updates have been turned on, the session + // should now show up as closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + + // Now we test that a session must be exhausted with all + // channels closed before it is seen as closable. + h.makeChannel( + 2, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(2) + + // Fill up only half of the session updates. + hints = h.advanceChannelN(2, numUpdates) + h.backupStates(2, 0, numUpdates/2, nil) + h.waitServerUpdates(hints[:numUpdates/2], waitTime) + + // Determine the ID of the session of interest. + sessionIDs = h.relevantSessions(2) + + // We expect only 1 session to have updates for this + // channel. + require.Len(h.t, sessionIDs, 1) + + // Now close the channel. + h.closeChannel(2, 1) + + // The session should _not_ be closable due to it not + // being exhausted yet. + require.False(h.t, h.isSessionClosable(sessionIDs[0])) + + // Create a new channel. + h.makeChannel( + 3, h.cfg.localBalance, h.cfg.remoteBalance, + ) + h.registerChannel(3) + + hints = h.advanceChannelN(3, numUpdates) + h.backupStates(3, 0, numUpdates, nil) + h.waitServerUpdates(hints, waitTime) + + // Close it. + h.closeChannel(3, 1) + + // Now the session should be closable. + err = wait.Predicate(func() bool { + return h.isSessionClosable(sessionIDs[0]) + }, waitTime) + require.NoError(h.t, err) + }, }, } From 2b08d3443f1b82f96d1b7a2ffb17ef959026a1f4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 11:30:23 +0200 Subject: [PATCH 11/19] watchtowers: add thread safe min-heap In this commit, a thread-safe min-heap is implemented. It will carry sessionCloseItems which carry a sessionID and a block height at which the session should be closed. --- watchtower/wtclient/client.go | 27 +++--- watchtower/wtclient/sess_close_min_heap.go | 95 +++++++++++++++++++ .../wtclient/sess_close_min_heap_test.go | 52 ++++++++++ 3 files changed, 162 insertions(+), 12 deletions(-) create mode 100644 watchtower/wtclient/sess_close_min_heap.go create mode 100644 watchtower/wtclient/sess_close_min_heap_test.go diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index c0a4c2331..b2ce38351 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -271,6 +271,8 @@ type TowerClient struct { sessionQueue *sessionQueue prevTask *backupTask + closableSessionQueue *sessionCloseMinHeap + backupMu sync.Mutex summaries wtdb.ChannelSummaries chanCommitHeights map[lnwire.ChannelID]uint64 @@ -322,18 +324,19 @@ func New(config *Config) (*TowerClient, error) { } c := &TowerClient{ - cfg: cfg, - log: plog, - pipeline: newTaskPipeline(plog), - chanCommitHeights: make(map[lnwire.ChannelID]uint64), - activeSessions: make(sessionQueueSet), - summaries: chanSummaries, - statTicker: time.NewTicker(DefaultStatInterval), - stats: new(ClientStats), - newTowers: make(chan *newTowerMsg), - staleTowers: make(chan *staleTowerMsg), - forceQuit: make(chan struct{}), - quit: make(chan struct{}), + cfg: cfg, + log: plog, + pipeline: newTaskPipeline(plog), + chanCommitHeights: make(map[lnwire.ChannelID]uint64), + activeSessions: make(sessionQueueSet), + summaries: chanSummaries, + closableSessionQueue: newSessionCloseMinHeap(), + statTicker: time.NewTicker(DefaultStatInterval), + stats: new(ClientStats), + newTowers: make(chan *newTowerMsg), + staleTowers: make(chan *staleTowerMsg), + forceQuit: make(chan struct{}), + quit: make(chan struct{}), } // perUpdate is a callback function that will be used to inspect the diff --git a/watchtower/wtclient/sess_close_min_heap.go b/watchtower/wtclient/sess_close_min_heap.go new file mode 100644 index 000000000..c5f58ec1a --- /dev/null +++ b/watchtower/wtclient/sess_close_min_heap.go @@ -0,0 +1,95 @@ +package wtclient + +import ( + "sync" + + "github.com/lightningnetwork/lnd/queue" + "github.com/lightningnetwork/lnd/watchtower/wtdb" +) + +// sessionCloseMinHeap is a thread-safe min-heap implementation that stores +// sessionCloseItem items and prioritises the item with the lowest block height. +type sessionCloseMinHeap struct { + queue queue.PriorityQueue + mu sync.Mutex +} + +// newSessionCloseMinHeap constructs a new sessionCloseMineHeap. +func newSessionCloseMinHeap() *sessionCloseMinHeap { + return &sessionCloseMinHeap{} +} + +// Len returns the length of the queue. +func (h *sessionCloseMinHeap) Len() int { + h.mu.Lock() + defer h.mu.Unlock() + + return h.queue.Len() +} + +// Empty returns true if the queue is empty. +func (h *sessionCloseMinHeap) Empty() bool { + h.mu.Lock() + defer h.mu.Unlock() + + return h.queue.Empty() +} + +// Push adds an item to the priority queue. +func (h *sessionCloseMinHeap) Push(item *sessionCloseItem) { + h.mu.Lock() + defer h.mu.Unlock() + + h.queue.Push(item) +} + +// Pop removes the top most item from the queue. +func (h *sessionCloseMinHeap) Pop() *sessionCloseItem { + h.mu.Lock() + defer h.mu.Unlock() + + if h.queue.Empty() { + return nil + } + + item := h.queue.Pop() + + return item.(*sessionCloseItem) //nolint:forcetypeassert +} + +// Top returns the top most item from the queue without removing it. +func (h *sessionCloseMinHeap) Top() *sessionCloseItem { + h.mu.Lock() + defer h.mu.Unlock() + + if h.queue.Empty() { + return nil + } + + item := h.queue.Top() + + return item.(*sessionCloseItem) //nolint:forcetypeassert +} + +// sessionCloseItem represents a session that is ready to be deleted. +type sessionCloseItem struct { + // sessionID is the ID of the session in question. + sessionID wtdb.SessionID + + // deleteHeight is the block height after which we can delete the + // session. + deleteHeight uint32 +} + +// Less returns true if the current item's delete height is less than the +// other sessionCloseItem's delete height. This results in lower block heights +// being popped first from the heap. +// +// NOTE: this is part of the queue.PriorityQueueItem interface. +func (s *sessionCloseItem) Less(other queue.PriorityQueueItem) bool { + o := other.(*sessionCloseItem).deleteHeight //nolint:forcetypeassert + + return s.deleteHeight < o +} + +var _ queue.PriorityQueueItem = (*sessionCloseItem)(nil) diff --git a/watchtower/wtclient/sess_close_min_heap_test.go b/watchtower/wtclient/sess_close_min_heap_test.go new file mode 100644 index 000000000..9983f5b93 --- /dev/null +++ b/watchtower/wtclient/sess_close_min_heap_test.go @@ -0,0 +1,52 @@ +package wtclient + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +// TestSessionCloseMinHeap asserts that the sessionCloseMinHeap behaves as +// expected. +func TestSessionCloseMinHeap(t *testing.T) { + t.Parallel() + + heap := newSessionCloseMinHeap() + require.Nil(t, heap.Pop()) + require.Nil(t, heap.Top()) + require.True(t, heap.Empty()) + require.Zero(t, heap.Len()) + + // Add an item with height 10. + item1 := &sessionCloseItem{ + sessionID: [33]byte{1, 2, 3}, + deleteHeight: 10, + } + + heap.Push(item1) + require.Equal(t, item1, heap.Top()) + require.False(t, heap.Empty()) + require.EqualValues(t, 1, heap.Len()) + + // Add a bunch more items with heights 1, 2, 6, 11, 6, 30, 9. + heap.Push(&sessionCloseItem{deleteHeight: 1}) + heap.Push(&sessionCloseItem{deleteHeight: 2}) + heap.Push(&sessionCloseItem{deleteHeight: 6}) + heap.Push(&sessionCloseItem{deleteHeight: 11}) + heap.Push(&sessionCloseItem{deleteHeight: 6}) + heap.Push(&sessionCloseItem{deleteHeight: 30}) + heap.Push(&sessionCloseItem{deleteHeight: 9}) + + // Now pop from the queue and assert that the items are returned in + // ascending order. + require.EqualValues(t, 1, heap.Pop().deleteHeight) + require.EqualValues(t, 2, heap.Pop().deleteHeight) + require.EqualValues(t, 6, heap.Pop().deleteHeight) + require.EqualValues(t, 6, heap.Pop().deleteHeight) + require.EqualValues(t, 9, heap.Pop().deleteHeight) + require.EqualValues(t, 10, heap.Pop().deleteHeight) + require.EqualValues(t, 11, heap.Pop().deleteHeight) + require.EqualValues(t, 30, heap.Pop().deleteHeight) + require.Nil(t, heap.Pop()) + require.Zero(t, heap.Len()) +} From 0209e6feb8246b5a61776580c24f5f9541dbbb39 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Thu, 9 Feb 2023 17:05:23 +0200 Subject: [PATCH 12/19] watchtower/wtclient: add Copy method to AddressIterator --- watchtower/wtclient/addr_iterator.go | 33 +++++++++++++++++++ watchtower/wtclient/addr_iterator_test.go | 5 +++ .../wtclient/candidate_iterator_test.go | 6 +--- 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/watchtower/wtclient/addr_iterator.go b/watchtower/wtclient/addr_iterator.go index 87065c011..cb16d335a 100644 --- a/watchtower/wtclient/addr_iterator.go +++ b/watchtower/wtclient/addr_iterator.go @@ -69,6 +69,12 @@ type AddressIterator interface { // Reset clears the iterators state, and makes the address at the front // of the list the next item to be returned. Reset() + + // Copy constructs a new AddressIterator that has the same addresses + // as this iterator. + // + // NOTE that the address locks are not expected to be copied. + Copy() AddressIterator } // A compile-time check to ensure that addressIterator implements the @@ -324,6 +330,33 @@ func (a *addressIterator) GetAll() []net.Addr { a.mu.Lock() defer a.mu.Unlock() + return a.getAllUnsafe() +} + +// Copy constructs a new AddressIterator that has the same addresses +// as this iterator. +// +// NOTE that the address locks will not be copied. +func (a *addressIterator) Copy() AddressIterator { + a.mu.Lock() + defer a.mu.Unlock() + + addrs := a.getAllUnsafe() + + // Since newAddressIterator will only ever return an error if it is + // initialised with zero addresses, we can ignore the error here since + // we are initialising it with the set of addresses of this + // addressIterator which is by definition a non-empty list. + iter, _ := newAddressIterator(addrs...) + + return iter +} + +// getAllUnsafe returns a copy of all the addresses in the iterator. +// +// NOTE: this method is not thread safe and so must only be called once the +// addressIterator mutex is already being held. +func (a *addressIterator) getAllUnsafe() []net.Addr { var addrs []net.Addr cursor := a.addrList.Front() diff --git a/watchtower/wtclient/addr_iterator_test.go b/watchtower/wtclient/addr_iterator_test.go index d3674d985..89a35c4cb 100644 --- a/watchtower/wtclient/addr_iterator_test.go +++ b/watchtower/wtclient/addr_iterator_test.go @@ -97,6 +97,11 @@ func TestAddrIterator(t *testing.T) { addrList := iter.GetAll() require.ElementsMatch(t, addrList, []net.Addr{addr1, addr2, addr3}) + // Also check that an iterator constructed via the Copy method, also + // contains all the expected addresses. + newIterAddrs := iter.Copy().GetAll() + require.ElementsMatch(t, newIterAddrs, []net.Addr{addr1, addr2, addr3}) + // Let's now remove addr3. err = iter.Remove(addr3) require.NoError(t, err) diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 70dfb7505..b4df80f4d 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -52,14 +52,10 @@ func randTower(t *testing.T) *Tower { func copyTower(t *testing.T, tower *Tower) *Tower { t.Helper() - addrs := tower.Addresses.GetAll() - addrIterator, err := newAddressIterator(addrs...) - require.NoError(t, err) - return &Tower{ ID: tower.ID, IdentityKey: tower.IdentityKey, - Addresses: addrIterator, + Addresses: tower.Addresses.Copy(), } } From 8478b56ce6609ccc4b63f59ae09490fe69d92e76 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Mar 2023 10:56:23 +0200 Subject: [PATCH 13/19] watchtower: method to dial tower and send DeleteSession This commit adds a deleteSessionFromTower method which can be used to dial the tower that we created a given session with and then sends that tower the DeleteSession method. --- watchtower/wtclient/client.go | 124 ++++++++++++++++++++++++++++++++++ 1 file changed, 124 insertions(+) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index b2ce38351..464e93f16 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -902,6 +902,130 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, return nil } +// deleteSessionFromTower dials the tower that we created the session with and +// attempts to send the tower the DeleteSession message. +func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { + // First, we check if we have already loaded this tower in our + // candidate towers iterator. + tower, err := c.candidateTowers.GetTower(sess.TowerID) + if errors.Is(err, ErrTowerNotInIterator) { + // If not, then we attempt to load it from the DB. + dbTower, err := c.cfg.DB.LoadTowerByID(sess.TowerID) + if err != nil { + return err + } + + tower, err = NewTowerFromDBTower(dbTower) + if err != nil { + return err + } + } else if err != nil { + return err + } + + session, err := NewClientSessionFromDBSession( + sess, tower, c.cfg.SecretKeyRing, + ) + if err != nil { + return err + } + + localInit := wtwire.NewInitMessage( + lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), + c.cfg.ChainHash, + ) + + var ( + conn wtserver.Peer + + // addrIterator is a copy of the tower's address iterator. + // We use this copy so that iterating through the addresses does + // not affect any other threads using this iterator. + addrIterator = tower.Addresses.Copy() + towerAddr = addrIterator.Peek() + ) + // Attempt to dial the tower with its available addresses. + for { + conn, err = c.dial( + session.SessionKeyECDH, &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: towerAddr, + }, + ) + if err != nil { + // If there are more addrs available, immediately try + // those. + nextAddr, iteratorErr := addrIterator.Next() + if iteratorErr == nil { + towerAddr = nextAddr + continue + } + + // Otherwise, if we have exhausted the address list, + // exit. + addrIterator.Reset() + + return fmt.Errorf("failed to dial tower(%x) at any "+ + "available addresses", + tower.IdentityKey.SerializeCompressed()) + } + + break + } + defer conn.Close() + + // Send Init to tower. + err = c.sendMessage(conn, localInit) + if err != nil { + return err + } + + // Receive Init from tower. + remoteMsg, err := c.readMessage(conn) + if err != nil { + return err + } + + remoteInit, ok := remoteMsg.(*wtwire.Init) + if !ok { + return fmt.Errorf("watchtower %s responded with %T to Init", + towerAddr, remoteMsg) + } + + // Validate Init. + err = localInit.CheckRemoteInit(remoteInit, wtwire.FeatureNames) + if err != nil { + return err + } + + // Send DeleteSession to tower. + err = c.sendMessage(conn, &wtwire.DeleteSession{}) + if err != nil { + return err + } + + // Receive DeleteSessionReply from tower. + remoteMsg, err = c.readMessage(conn) + if err != nil { + return err + } + + deleteSessionReply, ok := remoteMsg.(*wtwire.DeleteSessionReply) + if !ok { + return fmt.Errorf("watchtower %s responded with %T to "+ + "DeleteSession", towerAddr, remoteMsg) + } + + switch deleteSessionReply.Code { + case wtwire.CodeOK, wtwire.DeleteSessionCodeNotFound: + return nil + default: + return fmt.Errorf("received error code %v in "+ + "DeleteSessionReply when attempting to delete "+ + "session from tower", deleteSessionReply.Code) + } +} + // backupDispatcher processes events coming from the taskPipeline and is // responsible for detecting when the client needs to renegotiate a session to // fulfill continuing demand. The event loop exits after all tasks have been From 26e628c0feaf814a1d6288c45ee33613232ebab7 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 20 Mar 2023 11:07:31 +0200 Subject: [PATCH 14/19] watchtowers: handle closable sessions Add a routine to the tower client that informs towers of sessions they can delete and also deletes any info about the session from the client DB. --- lncfg/wtclient.go | 5 + sample-lnd.conf | 6 + server.go | 9 ++ watchtower/wtclient/client.go | 186 ++++++++++++++++++++++++++++- watchtower/wtclient/client_test.go | 132 +++++++++++++++++++- 5 files changed, 331 insertions(+), 7 deletions(-) diff --git a/lncfg/wtclient.go b/lncfg/wtclient.go index 8b9f03939..7d4331112 100644 --- a/lncfg/wtclient.go +++ b/lncfg/wtclient.go @@ -17,6 +17,11 @@ type WtClient struct { // SweepFeeRate specifies the fee rate in sat/byte to be used when // constructing justice transactions sent to the tower. SweepFeeRate uint64 `long:"sweep-fee-rate" description:"Specifies the fee rate in sat/byte to be used when constructing justice transactions sent to the watchtower."` + + // SessionCloseRange is the range over which to choose a random number + // of blocks to wait after the last channel of a session is closed + // before sending the DeleteSession message to the tower server. + SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."` } // Validate ensures the user has provided a valid configuration. diff --git a/sample-lnd.conf b/sample-lnd.conf index 3dfc76da8..f0edda984 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -997,6 +997,12 @@ litecoin.node=ltcd ; supported at this time, if none are provided the tower will not be enabled. ; wtclient.private-tower-uris= +; The range over which to choose a random number of blocks to wait after the +; last channel of a session is closed before sending the DeleteSession message +; to the tower server. The default is currently 288. Note that setting this to +; a lower value will result in faster session cleanup _but_ that this comes +; along with reduced privacy from the tower server. +; wtclient.session-close-range=10 [healthcheck] diff --git a/server.go b/server.go index 61b31ffa6..eef43550b 100644 --- a/server.go +++ b/server.go @@ -1497,6 +1497,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr, policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight() } + sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange) + if cfg.WtClient.SessionCloseRange != 0 { + sessionCloseRange = cfg.WtClient.SessionCloseRange + } + if err := policy.Validate(); err != nil { return nil, err } @@ -1516,6 +1521,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.towerClient, err = wtclient.New(&wtclient.Config{ FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { @@ -1546,6 +1553,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.anchorTowerClient, err = wtclient.New(&wtclient.Config{ FetchClosedChannel: fetchClosedChannel, + SessionCloseRange: sessionCloseRange, + ChainNotifier: s.cc.ChainNotifier, SubscribeChannelEvents: func() (subscribe.Subscription, error) { diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 464e93f16..e92b8b4cf 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -2,8 +2,10 @@ package wtclient import ( "bytes" + "crypto/rand" "errors" "fmt" + "math/big" "net" "sync" "time" @@ -12,6 +14,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" @@ -43,6 +46,11 @@ const ( // client should abandon any pending updates or session negotiations // before terminating. DefaultForceQuitDelay = 10 * time.Second + + // DefaultSessionCloseRange is the range over which we will generate a + // random number of blocks to delay closing a session after its last + // channel has been closed. + DefaultSessionCloseRange = 288 ) // genSessionFilter constructs a filter that can be used to select sessions only @@ -159,6 +167,9 @@ type Config struct { FetchClosedChannel func(cid lnwire.ChannelID) ( *channeldb.ChannelCloseSummary, error) + // ChainNotifier can be used to subscribe to block notifications. + ChainNotifier chainntnfs.ChainNotifier + // NewAddress generates a new on-chain sweep pkscript. NewAddress func() ([]byte, error) @@ -214,6 +225,11 @@ type Config struct { // watchtowers. If the exponential backoff produces a timeout greater // than this value, the backoff will be clamped to MaxBackoff. MaxBackoff time.Duration + + // SessionCloseRange is the range over which we will generate a random + // number of blocks to delay closing a session after its last channel + // has been closed. + SessionCloseRange uint32 } // newTowerMsg is an internal message we'll use within the TowerClient to signal @@ -590,9 +606,34 @@ func (c *TowerClient) Start() error { delete(c.summaries, id) } + // Load all closable sessions. + closableSessions, err := c.cfg.DB.ListClosableSessions() + if err != nil { + returnErr = err + return + } + + err = c.trackClosableSessions(closableSessions) + if err != nil { + returnErr = err + return + } + c.wg.Add(1) go c.handleChannelCloses(chanSub) + // Subscribe to new block events. + blockEvents, err := c.cfg.ChainNotifier.RegisterBlockEpochNtfn( + nil, + ) + if err != nil { + returnErr = err + return + } + + c.wg.Add(1) + go c.handleClosableSessions(blockEvents) + // Now start the session negotiator, which will allow us to // request new session as soon as the backupDispatcher starts // up. @@ -876,7 +917,8 @@ func (c *TowerClient) handleChannelCloses(chanSub subscribe.Subscription) { } // handleClosedChannel handles the closure of a single channel. It will mark the -// channel as closed in the DB. +// channel as closed in the DB, then it will handle all the sessions that are +// now closable due to the channel closure. func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, closeHeight uint32) error { @@ -890,18 +932,146 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, c.log.Debugf("Marking channel(%s) as closed", chanID) - _, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) + sessions, err := c.cfg.DB.MarkChannelClosed(chanID, closeHeight) if err != nil { return fmt.Errorf("could not mark channel(%s) as closed: %w", chanID, err) } + closableSessions := make(map[wtdb.SessionID]uint32, len(sessions)) + for _, sess := range sessions { + closableSessions[sess] = closeHeight + } + + c.log.Debugf("Tracking %d new closable sessions as a result of "+ + "closing channel %s", len(closableSessions), chanID) + + err = c.trackClosableSessions(closableSessions) + if err != nil { + return fmt.Errorf("could not track closable sessions: %w", err) + } + delete(c.summaries, chanID) delete(c.chanCommitHeights, chanID) return nil } +// handleClosableSessions listens for new block notifications. For each block, +// it checks the closableSessionQueue to see if there is a closable session with +// a delete-height smaller than or equal to the new block, if there is then the +// tower is informed that it can delete the session, and then we also delete it +// from our DB. +func (c *TowerClient) handleClosableSessions( + blocksChan *chainntnfs.BlockEpochEvent) { + + defer c.wg.Done() + + c.log.Debug("Starting closable sessions handler") + defer c.log.Debug("Stopping closable sessions handler") + + for { + select { + case newBlock := <-blocksChan.Epochs: + if newBlock == nil { + return + } + + height := uint32(newBlock.Height) + for { + select { + case <-c.quit: + return + default: + } + + // If there are no closable sessions that we + // need to handle, then we are done and can + // reevaluate when the next block comes. + item := c.closableSessionQueue.Top() + if item == nil { + break + } + + // If there is closable session but the delete + // height we have set for it is after the + // current block height, then our work is done. + if item.deleteHeight > height { + break + } + + // Otherwise, we pop this item from the heap + // and handle it. + c.closableSessionQueue.Pop() + + // Fetch the session from the DB so that we can + // extract the Tower info. + sess, err := c.cfg.DB.GetClientSession( + item.sessionID, + ) + if err != nil { + c.log.Errorf("error calling "+ + "GetClientSession for "+ + "session %s: %v", + item.sessionID, err) + + continue + } + + err = c.deleteSessionFromTower(sess) + if err != nil { + c.log.Errorf("error deleting "+ + "session %s from tower: %v", + sess.ID, err) + + continue + } + + err = c.cfg.DB.DeleteSession(item.sessionID) + if err != nil { + c.log.Errorf("could not delete "+ + "session(%s) from DB: %w", + sess.ID, err) + + continue + } + } + + case <-c.forceQuit: + return + + case <-c.quit: + return + } + } +} + +// trackClosableSessions takes in a map of session IDs to the earliest block +// height at which the session should be deleted. For each of the sessions, +// a random delay is added to the block height and the session is added to the +// closableSessionQueue. +func (c *TowerClient) trackClosableSessions( + sessions map[wtdb.SessionID]uint32) error { + + // For each closable session, add a random delay to its close + // height and add it to the closableSessionQueue. + for sID, blockHeight := range sessions { + delay, err := newRandomDelay(c.cfg.SessionCloseRange) + if err != nil { + return err + } + + deleteHeight := blockHeight + delay + + c.closableSessionQueue.Push(&sessionCloseItem{ + sessionID: sID, + deleteHeight: deleteHeight, + }) + } + + return nil +} + // deleteSessionFromTower dials the tower that we created the session with and // attempts to send the tower the DeleteSession message. func (c *TowerClient) deleteSessionFromTower(sess *wtdb.ClientSession) error { @@ -1671,3 +1841,15 @@ func (c *TowerClient) logMessage( preposition, peer.RemotePub().SerializeCompressed(), peer.RemoteAddr()) } + +func newRandomDelay(max uint32) (uint32, error) { + var maxDelay big.Int + maxDelay.SetUint64(uint64(max)) + + randDelay, err := rand.Int(rand.Reader, &maxDelay) + if err != nil { + return 0, err + } + + return uint32(randDelay.Uint64()), nil +} diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 29c4e7a53..2657e691b 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -16,6 +16,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" @@ -396,6 +397,9 @@ type testHarness struct { server *wtserver.Server net *mockNet + blockEvents *mockBlockSub + height int32 + channelEvents *mockSubscription sendUpdatesOn bool @@ -458,6 +462,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { serverDB: serverDB, serverCfg: serverCfg, net: mockNet, + blockEvents: newMockBlockSub(t), channelEvents: newMockSubscription(t), channels: make(map[lnwire.ChannelID]*mockChannel), closedChannels: make(map[lnwire.ChannelID]uint32), @@ -487,6 +492,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return h.channelEvents, nil }, FetchClosedChannel: fetchChannel, + ChainNotifier: h.blockEvents, Dial: mockNet.Dial, DB: clientDB, AuthDial: mockNet.AuthDial, @@ -495,11 +501,12 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { NewAddress: func() ([]byte, error) { return addrScript, nil }, - ReadTimeout: timeout, - WriteTimeout: timeout, - MinBackoff: time.Millisecond, - MaxBackoff: time.Second, - ForceQuitDelay: 10 * time.Second, + ReadTimeout: timeout, + WriteTimeout: timeout, + MinBackoff: time.Millisecond, + MaxBackoff: time.Second, + ForceQuitDelay: 10 * time.Second, + SessionCloseRange: 1, } if !cfg.noServerStart { @@ -518,6 +525,16 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { return h } +// mine mimics the mining of new blocks by sending new block notifications. +func (h *testHarness) mine(numBlocks int) { + h.t.Helper() + + for i := 0; i < numBlocks; i++ { + h.height++ + h.blockEvents.sendNewBlock(h.height) + } +} + // startServer creates a new server using the harness's current serverCfg and // starts it after pointing the mockNet's callback to the new server. func (h *testHarness) startServer() { @@ -909,6 +926,44 @@ func (m *mockSubscription) Updates() <-chan interface{} { return m.updates } +// mockBlockSub mocks out the ChainNotifier. +type mockBlockSub struct { + t *testing.T + events chan *chainntnfs.BlockEpoch + + chainntnfs.ChainNotifier +} + +// newMockBlockSub creates a new mockBlockSub. +func newMockBlockSub(t *testing.T) *mockBlockSub { + t.Helper() + + return &mockBlockSub{ + t: t, + events: make(chan *chainntnfs.BlockEpoch), + } +} + +// RegisterBlockEpochNtfn returns a channel that can be used to listen for new +// blocks. +func (m *mockBlockSub) RegisterBlockEpochNtfn(_ *chainntnfs.BlockEpoch) ( + *chainntnfs.BlockEpochEvent, error) { + + return &chainntnfs.BlockEpochEvent{ + Epochs: m.events, + }, nil +} + +// sendNewBlock will send a new block on the notification channel. +func (m *mockBlockSub) sendNewBlock(height int32) { + select { + case m.events <- &chainntnfs.BlockEpoch{Height: height}: + + case <-time.After(waitTime): + m.t.Fatalf("timed out sending block: %d", height) + } +} + const ( localBalance = lnwire.MilliSatoshi(100000000) remoteBalance = lnwire.MilliSatoshi(200000000) @@ -1891,6 +1946,73 @@ var clientTests = []clientTest{ return h.isSessionClosable(sessionIDs[0]) }, waitTime) require.NoError(h.t, err) + + // Now we will mine a few blocks. This will cause the + // necessary session-close-range to be exceeded meaning + // that the client should send the DeleteSession message + // to the server. We will assert that both the client + // and server have deleted the appropriate sessions and + // channel info. + + // Before we mine blocks, assert that the client + // currently has 3 closable sessions. + closableSess, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + require.Len(h.t, closableSess, 3) + + // Assert that the server is also aware of all of these + // sessions. + for sid := range closableSess { + _, err := h.serverDB.GetSessionInfo(&sid) + require.NoError(h.t, err) + } + + // Also make a note of the total number of sessions the + // client has. + sessions, err := h.clientDB.ListClientSessions(nil, nil) + require.NoError(h.t, err) + require.Len(h.t, sessions, 4) + + h.mine(3) + + // The client should no longer have any closable + // sessions and the total list of client sessions should + // no longer include the three that it previously had + // marked as closable. The server should also no longer + // have these sessions in its DB. + err = wait.Predicate(func() bool { + sess, err := h.clientDB.ListClientSessions( + nil, nil, + ) + require.NoError(h.t, err) + + cs, err := h.clientDB.ListClosableSessions() + require.NoError(h.t, err) + + if len(sess) != 1 || len(cs) != 0 { + return false + } + + for sid := range closableSess { + _, ok := sess[sid] + if ok { + return false + } + + _, err := h.serverDB.GetSessionInfo( + &sid, + ) + if !errors.Is( + err, wtdb.ErrSessionNotFound, + ) { + return false + } + } + + return true + + }, waitTime) + require.NoError(h.t, err) }, }, } From d840761cc4a1c75678d29007f6e1ae8f5ad7b93c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 26 Oct 2022 10:31:39 +0200 Subject: [PATCH 15/19] watchtower: dont load closed channel details In this commit, the FetchChanSummaries method is adapted to skip loading any channel summaries if the channel has been marked as closed. --- watchtower/wtclient/interface.go | 3 ++- watchtower/wtdb/client_db.go | 10 +++++++++- watchtower/wtmock/client_db.go | 8 +++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 4dd9176ae..4eebef4e5 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -83,7 +83,8 @@ type DB interface { NumAckedUpdates(id *wtdb.SessionID) (uint64, error) // FetchChanSummaries loads a mapping from all registered channels to - // their channel summaries. + // their channel summaries. Only the channels that have not yet been + // marked as closed will be loaded. FetchChanSummaries() (wtdb.ChannelSummaries, error) // MarkChannelClosed will mark a registered channel as closed by setting diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index c3e5447d6..53c643b6f 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1284,7 +1284,8 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { } // FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. +// channel summaries. Only the channels that have not yet been marked as closed +// will be loaded. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { var summaries map[lnwire.ChannelID]ClientChanSummary @@ -1300,6 +1301,13 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { return ErrCorruptChanDetails } + // If this channel has already been marked as closed, + // then its summary does not need to be loaded. + closedHeight := chanDetails.Get(cChanClosedHeight) + if len(closedHeight) > 0 { + return nil + } + var chanID lnwire.ChannelID copy(chanID[:], k) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 7213f17b6..e004fcdaf 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -566,13 +566,19 @@ func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { } // FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. +// channel summaries. Only the channels that have not yet been marked as closed +// will be loaded. func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { m.mu.Lock() defer m.mu.Unlock() summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) for chanID, channel := range m.channels { + // Don't load the channel if it has been marked as closed. + if channel.closedHeight > 0 { + continue + } + summaries[chanID] = wtdb.ClientChanSummary{ SweepPkScript: cloneBytes( channel.summary.SweepPkScript, From bad80ff583fbbeecd76fd5645a98ffa7f781d4e4 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 6 Dec 2022 13:50:15 +0200 Subject: [PATCH 16/19] multi: make tower MaxUpdates configurable This is helpful in an itest environment where we want to quickly saturate a session. --- lncfg/wtclient.go | 4 ++++ sample-lnd.conf | 3 +++ server.go | 4 ++++ 3 files changed, 11 insertions(+) diff --git a/lncfg/wtclient.go b/lncfg/wtclient.go index 7d4331112..feaae464c 100644 --- a/lncfg/wtclient.go +++ b/lncfg/wtclient.go @@ -22,6 +22,10 @@ type WtClient struct { // of blocks to wait after the last channel of a session is closed // before sending the DeleteSession message to the tower server. SessionCloseRange uint32 `long:"session-close-range" description:"The range over which to choose a random number of blocks to wait after the last channel of a session is closed before sending the DeleteSession message to the tower server. Set to 1 for no delay."` + + // MaxUpdates is the maximum number of updates to be backed up in a + // single tower sessions. + MaxUpdates uint16 `long:"max-updates" description:"The maximum number of updates to be backed up in a single session."` } // Validate ensures the user has provided a valid configuration. diff --git a/sample-lnd.conf b/sample-lnd.conf index f0edda984..b0499294b 100644 --- a/sample-lnd.conf +++ b/sample-lnd.conf @@ -1004,6 +1004,9 @@ litecoin.node=ltcd ; along with reduced privacy from the tower server. ; wtclient.session-close-range=10 +; The maximum number of updates to include in a tower session. +; wtclient.max-updates=1024 + [healthcheck] ; The number of times we should attempt to query our chain backend before diff --git a/server.go b/server.go index eef43550b..2224976e5 100644 --- a/server.go +++ b/server.go @@ -1497,6 +1497,10 @@ func newServer(cfg *Config, listenAddrs []net.Addr, policy.SweepFeeRate = sweepRateSatPerVByte.FeePerKWeight() } + if cfg.WtClient.MaxUpdates != 0 { + policy.MaxUpdates = cfg.WtClient.MaxUpdates + } + sessionCloseRange := uint32(wtclient.DefaultSessionCloseRange) if cfg.WtClient.SessionCloseRange != 0 { sessionCloseRange = cfg.WtClient.SessionCloseRange From 31beacc2c45492c53b9caaef31d5b9c4e720f430 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 17 Mar 2023 12:51:06 +0200 Subject: [PATCH 17/19] lnrpc/wtclientrpc: populate sessions from legacy channels --- lnrpc/wtclientrpc/wtclient.go | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 415ae1702..335c370c8 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -309,6 +309,7 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context, } t.SessionInfo = append(t.SessionInfo, rpcTower.SessionInfo...) + t.Sessions = append(t.Sessions, rpcTower.Sessions...) } towers := make([]*Tower, 0, len(rpcTowers)) @@ -365,6 +366,9 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context, rpcTower.SessionInfo = append( rpcTower.SessionInfo, rpcLegacyTower.SessionInfo..., ) + rpcTower.Sessions = append( + rpcTower.Sessions, rpcLegacyTower.Sessions..., + ) return rpcTower, nil } From 6f4034f7d1100615e13229964fdb7716dea25395 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 6 Dec 2022 13:52:00 +0200 Subject: [PATCH 18/19] lntest/itest: add session deletion itest --- itest/list_on_test.go | 4 + itest/lnd_watchtower_test.go | 172 +++++++++++++++++++++++++++++++++++ lntest/rpc/watchtower.go | 14 +++ 3 files changed, 190 insertions(+) create mode 100644 itest/lnd_watchtower_test.go diff --git a/itest/list_on_test.go b/itest/list_on_test.go index d2253969f..19f083cbb 100644 --- a/itest/list_on_test.go +++ b/itest/list_on_test.go @@ -515,4 +515,8 @@ var allTestCases = []*lntest.TestCase{ Name: "lookup htlc resolution", TestFunc: testLookupHtlcResolution, }, + { + Name: "watchtower session management", + TestFunc: testWatchtowerSessionManagement, + }, } diff --git a/itest/lnd_watchtower_test.go b/itest/lnd_watchtower_test.go new file mode 100644 index 000000000..3432848d8 --- /dev/null +++ b/itest/lnd_watchtower_test.go @@ -0,0 +1,172 @@ +package itest + +import ( + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/lightningnetwork/lnd/funding" + "github.com/lightningnetwork/lnd/lnrpc" + "github.com/lightningnetwork/lnd/lnrpc/routerrpc" + "github.com/lightningnetwork/lnd/lnrpc/wtclientrpc" + "github.com/lightningnetwork/lnd/lntest" + "github.com/lightningnetwork/lnd/lntest/node" + "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/stretchr/testify/require" +) + +// testWatchtowerSessionManagement tests that session deletion is done +// correctly. +func testWatchtowerSessionManagement(ht *lntest.HarnessTest) { + const ( + chanAmt = funding.MaxBtcFundingAmount + paymentAmt = 10_000 + numInvoices = 5 + maxUpdates = numInvoices * 2 + externalIP = "1.2.3.4" + sessionCloseRange = 1 + ) + + // Set up Wallis the watchtower who will be used by Dave to watch over + // his channel commitment transactions. + wallis := ht.NewNode("Wallis", []string{ + "--watchtower.active", + "--watchtower.externalip=" + externalIP, + }) + + wallisInfo := wallis.RPC.GetInfoWatchtower() + + // Assert that Wallis has one listener and it is 0.0.0.0:9911 or + // [::]:9911. Since no listener is explicitly specified, one of these + // should be the default depending on whether the host supports IPv6 or + // not. + require.Len(ht, wallisInfo.Listeners, 1) + listener := wallisInfo.Listeners[0] + require.True(ht, listener == "0.0.0.0:9911" || listener == "[::]:9911") + + // Assert the Wallis's URIs properly display the chosen external IP. + require.Len(ht, wallisInfo.Uris, 1) + require.Contains(ht, wallisInfo.Uris[0], externalIP) + + // Dave will be the tower client. + daveArgs := []string{ + "--wtclient.active", + fmt.Sprintf("--wtclient.max-updates=%d", maxUpdates), + fmt.Sprintf( + "--wtclient.session-close-range=%d", sessionCloseRange, + ), + } + dave := ht.NewNode("Dave", daveArgs) + + addTowerReq := &wtclientrpc.AddTowerRequest{ + Pubkey: wallisInfo.Pubkey, + Address: listener, + } + dave.RPC.AddTower(addTowerReq) + + // Assert that there exists a session between Dave and Wallis. + err := wait.NoError(func() error { + info := dave.RPC.GetTowerInfo(&wtclientrpc.GetTowerInfoRequest{ + Pubkey: wallisInfo.Pubkey, + IncludeSessions: true, + }) + + var numSessions uint32 + for _, sessionType := range info.SessionInfo { + numSessions += sessionType.NumSessions + } + if numSessions > 0 { + return nil + } + + return fmt.Errorf("expected a non-zero number of sessions") + }, defaultTimeout) + require.NoError(ht, err) + + // Before we make a channel, we'll load up Dave with some coins sent + // directly from the miner. + ht.FundCoins(btcutil.SatoshiPerBitcoin, dave) + + // Connect Dave and Alice. + ht.ConnectNodes(dave, ht.Alice) + + // Open a channel between Dave and Alice. + params := lntest.OpenChannelParams{ + Amt: chanAmt, + } + chanPoint := ht.OpenChannel(dave, ht.Alice, params) + + // Since there are 2 updates made for every payment and the maximum + // number of updates per session has been set to 10, make 5 payments + // between the pair so that the session is exhausted. + alicePayReqs, _, _ := ht.CreatePayReqs( + ht.Alice, paymentAmt, numInvoices, + ) + + send := func(node *node.HarnessNode, payReq string) { + stream := node.RPC.SendPayment(&routerrpc.SendPaymentRequest{ + PaymentRequest: payReq, + TimeoutSeconds: 60, + FeeLimitMsat: noFeeLimitMsat, + }) + + ht.AssertPaymentStatusFromStream( + stream, lnrpc.Payment_SUCCEEDED, + ) + } + + for i := 0; i < numInvoices; i++ { + send(dave, alicePayReqs[i]) + } + + // assertNumBackups is a closure that asserts that Dave has a certain + // number of backups backed up to the tower. If mineOnFail is true, + // then a block will be mined each time the assertion fails. + assertNumBackups := func(expected int, mineOnFail bool) { + err = wait.NoError(func() error { + info := dave.RPC.GetTowerInfo( + &wtclientrpc.GetTowerInfoRequest{ + Pubkey: wallisInfo.Pubkey, + IncludeSessions: true, + }, + ) + + var numBackups uint32 + for _, sessionType := range info.SessionInfo { + for _, session := range sessionType.Sessions { + numBackups += session.NumBackups + } + } + + if numBackups == uint32(expected) { + return nil + } + + if mineOnFail { + ht.Miner.MineBlocksSlow(1) + } + + return fmt.Errorf("expected %d backups, got %d", + expected, numBackups) + }, defaultTimeout) + require.NoError(ht, err) + } + + // Assert that one of the sessions now has 10 backups. + assertNumBackups(10, false) + + // Now close the channel and wait for the close transaction to appear + // in the mempool so that it is included in a block when we mine. + ht.CloseChannelAssertPending(dave, chanPoint, false) + + // Mine enough blocks to surpass the session-close-range. This should + // trigger the session to be deleted. + ht.MineBlocksAndAssertNumTxes(sessionCloseRange+6, 1) + + // Wait for the session to be deleted. We know it has been deleted once + // the number of backups is back to zero. We check for number of backups + // instead of number of sessions because it is expected that the client + // would immediately negotiate another session after deleting the + // exhausted one. This time we set the "mineOnFail" parameter to true to + // ensure that the session deleting logic is run. + assertNumBackups(0, true) +} diff --git a/lntest/rpc/watchtower.go b/lntest/rpc/watchtower.go index 1b05d15ea..d5512e7a4 100644 --- a/lntest/rpc/watchtower.go +++ b/lntest/rpc/watchtower.go @@ -24,6 +24,20 @@ func (h *HarnessRPC) GetInfoWatchtower() *watchtowerrpc.GetInfoResponse { return info } +// GetTowerInfo makes an RPC call to the watchtower client of the given node and +// asserts. +func (h *HarnessRPC) GetTowerInfo( + req *wtclientrpc.GetTowerInfoRequest) *wtclientrpc.Tower { + + ctxt, cancel := context.WithTimeout(h.runCtx, DefaultTimeout) + defer cancel() + + info, err := h.WatchtowerClient.GetTowerInfo(ctxt, req) + h.NoError(err, "GetTowerInfo from WatchtowerClient") + + return info +} + // AddTower makes a RPC call to the WatchtowerClient of the given node and // asserts. func (h *HarnessRPC) AddTower( From ab98fc43fe1b459e444884dc26e2eeb3f48d9274 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 21 Oct 2022 22:08:27 +0200 Subject: [PATCH 19/19] docs: add release note for 7069 --- docs/release-notes/release-notes-0.16.1.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/release-notes/release-notes-0.16.1.md b/docs/release-notes/release-notes-0.16.1.md index 983f4f358..2b5cdc801 100644 --- a/docs/release-notes/release-notes-0.16.1.md +++ b/docs/release-notes/release-notes-0.16.1.md @@ -9,6 +9,9 @@ * [Allow caller to filter sessions at the time of reading them from disk](https://github.com/lightningnetwork/lnd/pull/7059) +* [Clean up sessions once all channels for which they have updates for are + closed. Also start sending the `DeleteSession` message to the + tower.](https://github.com/lightningnetwork/lnd/pull/7069) ## Misc