From f889c9b1cca3d006b8514dd1f3c21c2f6d75b27f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 12 Sep 2023 19:53:16 +0200 Subject: [PATCH 1/7] watchtower: use bbolt db instead of mock DB for client tests The watchtower client test framework currently uses a mock version of the tower client DB. This can lead to bugs if the mock DB works slightly differently to the actual bbolt DB. So this commit ensures that we only use the bbolt db for the tower client tests. We also increment the `waitTime` used in the tests a bit to account for the slightly longer DB read and write times. Doing this switch resulted in one bug being caught: we were not removing sessions from the in-memory set on deletion of the session and so that is fixed here too. --- watchtower/wtclient/client.go | 13 +++++++++++++ watchtower/wtclient/client_test.go | 28 +++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f3b8d3307..412412c1e 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -977,6 +977,19 @@ func (c *TowerClient) handleClosableSessions( // and handle it. c.closableSessionQueue.Pop() + // Stop the session and remove it from the + // in-memory set. + err := c.activeSessions.StopAndRemove( + item.sessionID, + ) + if err != nil { + c.log.Errorf("could not remove "+ + "session(%s) from in-memory "+ + "set: %v", item.sessionID, err) + + return + } + // Fetch the session from the DB so that we can // extract the Tower info. sess, err := c.cfg.DB.GetClientSession( diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 5cb998c9a..a5b774c10 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -21,6 +21,7 @@ import ( "github.com/lightningnetwork/lnd/channelnotifier" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntest/wait" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" @@ -72,7 +73,7 @@ var ( addrScript, _ = txscript.PayToAddrScript(addr) - waitTime = 5 * time.Second + waitTime = 15 * time.Second defaultTxPolicy = wtpolicy.TxPolicy{ BlobType: blob.TypeAltruistCommit, @@ -398,7 +399,7 @@ type testHarness struct { cfg harnessCfg signer *wtmock.MockSigner capacity lnwire.MilliSatoshi - clientDB *wtmock.ClientDB + clientDB *wtdb.ClientDB clientCfg *wtclient.Config client wtclient.Client server *serverHarness @@ -426,10 +427,26 @@ type harnessCfg struct { noServerStart bool } +func newClientDB(t *testing.T) *wtdb.ClientDB { + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, + } + + // Construct the ClientDB. 
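+	// t.TempDir gives each test an isolated directory that the testing
+	// framework removes automatically once the test ends, so no explicit
+	// DB-file cleanup is needed here.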
+ dir := t.TempDir() + bdb, err := wtdb.NewBoltBackendCreator(true, dir, "wtclient.db")(dbCfg) + require.NoError(t, err) + + clientDB, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + return clientDB +} + func newHarness(t *testing.T, cfg harnessCfg) *testHarness { signer := wtmock.NewMockSigner() mockNet := newMockNet() - clientDB := wtmock.NewClientDB() + clientDB := newClientDB(t) server := newServerHarness( t, mockNet, towerAddrStr, func(serverCfg *wtserver.Config) { @@ -509,6 +526,7 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness { h.startClient() t.Cleanup(func() { require.NoError(t, h.client.Stop()) + require.NoError(t, h.clientDB.Close()) }) h.makeChannel(0, h.cfg.localBalance, h.cfg.remoteBalance) @@ -1342,7 +1360,7 @@ var clientTests = []clientTest{ // Wait for all the updates to be populated in the // server's database. - h.server.waitForUpdates(hints, 3*time.Second) + h.server.waitForUpdates(hints, waitTime) }, }, { @@ -2053,7 +2071,7 @@ var clientTests = []clientTest{ // Now stop the client and reset its database. require.NoError(h.t, h.client.Stop()) - db := wtmock.NewClientDB() + db := newClientDB(h.t) h.clientDB = db h.clientCfg.DB = db From ff0d8fc6190481126dc654095a3302f79a71141c Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 12:21:04 +0200 Subject: [PATCH 2/7] watchtower: completely remove the mock tower client DB Remove the use of the mock tower client DB and use the actual bbolt DB everywhere instead. --- watchtower/wtclient/queue_test.go | 92 ++-- watchtower/wtdb/client_db_test.go | 7 - watchtower/wtdb/queue_test.go | 59 +- watchtower/wtmock/client_db.go | 887 ------------------------------ 4 files changed, 46 insertions(+), 999 deletions(-) delete mode 100644 watchtower/wtmock/client_db.go diff --git a/watchtower/wtclient/queue_test.go b/watchtower/wtclient/queue_test.go index 81f96bb7f..529acb49a 100644 --- a/watchtower/wtclient/queue_test.go +++ b/watchtower/wtclient/queue_test.go @@ -18,51 +18,13 @@ const ( waitTime = time.Second * 2 ) -type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID] - // TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected. 
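 // Each sub-test now runs against a single bbolt-backed queue created in a
 // fresh temp dir (see initDB below) rather than being parameterized over
 // multiple DB backends.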
func TestDiskOverflowQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init initQueue - }{ - { - name: "kvdb", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - db.Close() - }) - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - { - name: "mock", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - db := wtmock.NewClientDB() - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - } - tests := []struct { name string - run func(*testing.T, initQueue) + run func(*testing.T, wtdb.Queue[*wtdb.BackupID]) }{ { name: "overflow to disk", @@ -78,29 +40,43 @@ func TestDiskOverflowQueue(t *testing.T) { }, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + initDB := func() wtdb.Queue[*wtdb.BackupID] { + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - test.run(t, db.init) - }) - } + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) + + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + return db.GetDBQueue([]byte("test-namespace")) + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + + test.run(tt, initDB()) }) } } // testOverflowToDisk is a basic test that ensures that the queue correctly // overflows items to disk and then correctly reloads them. -func testOverflowToDisk(t *testing.T, initQueue initQueue) { +func testOverflowToDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Init the DB. - db := initQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -146,7 +122,9 @@ func testOverflowToDisk(t *testing.T, initQueue initQueue) { // testRestartWithSmallerBufferSize tests that if the queue is restarted with // a smaller in-memory buffer size that it was initially started with, then // tasks are still loaded in the correct order. -func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { +func testRestartWithSmallerBufferSize(t *testing.T, + db wtdb.Queue[*wtdb.BackupID]) { + const ( firstMaxInMemItems = 5 secondMaxInMemItems = 2 @@ -155,9 +133,6 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Create a db. - db := newQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -223,14 +198,11 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // testStartStopQueue is a stress test that pushes a large number of tasks // through the queue while also restarting the queue a couple of times // throughout. -func testStartStopQueue(t *testing.T, newQueue initQueue) { +func testStartStopQueue(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate a lot of backup IDs that we want to add to the // queue one after the other. tasks := genBackupIDs(200_000) - // Construct the ClientDB. - db := newQueue(t) - // New mock logger. 
log := newMockLogger(t.Logf) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 5bfb4dab5..8be729f21 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/stretchr/testify/require" ) @@ -964,12 +963,6 @@ func TestClientDB(t *testing.T) { return db }, }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, } tests := []struct { diff --git a/watchtower/wtdb/queue_test.go b/watchtower/wtdb/queue_test.go index a864125cf..02c7b272c 100644 --- a/watchtower/wtdb/queue_test.go +++ b/watchtower/wtdb/queue_test.go @@ -4,9 +4,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/stretchr/testify/require" ) @@ -15,53 +13,24 @@ import ( func TestDiskQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init clientDBInit - }{ - { - name: "bbolt", - init: func(t *testing.T) wtclient.DB { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - // Construct the ClientDB. - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) - - return db - }, - }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + // Construct the ClientDB. + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) - testQueue(t, db.init(t)) - }) - } -} + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) -func testQueue(t *testing.T, db wtclient.DB) { namespace := []byte("test-namespace") queue := db.GetDBQueue(namespace) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go deleted file mode 100644 index f5625d35b..000000000 --- a/watchtower/wtmock/client_db.go +++ /dev/null @@ -1,887 +0,0 @@ -package wtmock - -import ( - "encoding/binary" - "net" - "sync" - "sync/atomic" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/watchtower/blob" - "github.com/lightningnetwork/lnd/watchtower/wtdb" -) - -var byteOrder = binary.BigEndian - -type towerPK [33]byte - -type keyIndexKey struct { - towerID wtdb.TowerID - blobType blob.Type -} - -type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex - -type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore - -type channel struct { - summary *wtdb.ClientChanSummary - closedHeight uint32 - sessions map[wtdb.SessionID]bool -} - -// ClientDB is a mock, in-memory database or testing the watchtower client -// behavior. 
-type ClientDB struct { - nextTowerID uint64 // to be used atomically - - mu sync.Mutex - channels map[lnwire.ChannelID]*channel - activeSessions map[wtdb.SessionID]wtdb.ClientSession - ackedUpdates rangeIndexArrayMap - persistedAckedUpdates rangeIndexKVStore - committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower - closableSessions map[wtdb.SessionID]uint32 - - nextIndex uint32 - indexes map[keyIndexKey]uint32 - legacyIndexes map[wtdb.TowerID]uint32 - - queues map[string]wtdb.Queue[*wtdb.BackupID] -} - -// NewClientDB initializes a new mock ClientDB. -func NewClientDB() *ClientDB { - return &ClientDB{ - channels: make(map[lnwire.ChannelID]*channel), - activeSessions: make( - map[wtdb.SessionID]wtdb.ClientSession, - ), - ackedUpdates: make(rangeIndexArrayMap), - persistedAckedUpdates: make(rangeIndexKVStore), - committedUpdates: make( - map[wtdb.SessionID][]wtdb.CommittedUpdate, - ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), - closableSessions: make(map[wtdb.SessionID]uint32), - queues: make(map[string]wtdb.Queue[*wtdb.BackupID]), - } -} - -// CreateTower initialize an address record used to communicate with a -// watchtower. Each Tower is assigned a unique ID, that is used to amortize -// storage costs of the public key when used by multiple sessions. If the tower -// already exists, the address is appended to the list of all addresses used to -// that tower previously and its corresponding sessions are marked as active. -func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var towerPubKey towerPK - copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) - - var tower *wtdb.Tower - towerID, ok := m.towerIndex[towerPubKey] - if ok { - tower = m.towers[towerID] - tower.AddAddress(lnAddr.Address) - - towerSessions, err := m.listClientSessions(&towerID) - if err != nil { - return nil, err - } - for id, session := range towerSessions { - session.Status = wtdb.CSessionActive - m.activeSessions[id] = *session - } - } else { - towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) - tower = &wtdb.Tower{ - ID: towerID, - IdentityKey: lnAddr.IdentityKey, - Addresses: []net.Addr{lnAddr.Address}, - } - } - - m.towerIndex[towerPubKey] = towerID - m.towers[towerID] = tower - - return copyTower(tower), nil -} - -// RemoveTower modifies a tower's record within the database. If an address is -// provided, then _only_ the address record should be removed from the tower's -// persisted state. Otherwise, we'll attempt to mark the tower as inactive by -// marking all of its sessions inactive. If any of its sessions has unacked -// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have -// any sessions at all, it'll be completely removed from the database. -// -// NOTE: An error is not returned if the tower doesn't exist. 
-func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { - m.mu.Lock() - defer m.mu.Unlock() - - tower, err := m.loadTower(pubKey) - if err == wtdb.ErrTowerNotFound { - return nil - } - if err != nil { - return err - } - - if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr - } - m.towers[tower.ID] = tower - return nil - } - - towerSessions, err := m.listClientSessions(&tower.ID) - if err != nil { - return err - } - if len(towerSessions) == 0 { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - delete(m.towerIndex, towerPK) - delete(m.towers, tower.ID) - return nil - } - - for id, session := range towerSessions { - if len(m.committedUpdates[session.ID]) > 0 { - return wtdb.ErrTowerUnackedUpdates - } - session.Status = wtdb.CSessionInactive - m.activeSessions[id] = *session - } - - return nil -} - -// LoadTower retrieves a tower by its public key. -func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - return m.loadTower(pubKey) -} - -// loadTower retrieves a tower by its public key. -// -// NOTE: This method requires the database's lock to be acquired. -func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - - towerID, ok := m.towerIndex[towerPK] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - tower, ok := m.towers[towerID] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - - return copyTower(tower), nil -} - -// LoadTowerByID retrieves a tower by its tower ID. -func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if tower, ok := m.towers[towerID]; ok { - return copyTower(tower), nil - } - - return nil, wtdb.ErrTowerNotFound -} - -// ListTowers retrieves the list of towers available within the database. -func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - towers := make([]*wtdb.Tower, 0, len(m.towers)) - for _, tower := range m.towers { - towers = append(towers, copyTower(tower)) - } - - return towers, nil -} - -// MarkBackupIneligible records that particular commit height is ineligible for -// backup. This allows the client to track which updates it should not attempt -// to retry after startup. -func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { - return nil -} - -// ListClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - return m.listClientSessions(tower, opts...) -} - -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. 
-func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, session := range m.activeSessions { - session := session - if tower != nil && *tower != session.TowerID { - continue - } - - if cfg.PreEvaluateFilterFn != nil && - !cfg.PreEvaluateFilterFn(&session) { - - continue - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight( - &session, chanID, index.MaxHeight(), - ) - } - } - - if cfg.PerNumAckedUpdates != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerNumAckedUpdates( - &session, chanID, - uint16(index.NumInSet()), - ) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - if cfg.PostEvaluateFilterFn != nil && - !cfg.PostEvaluateFilterFn(&session) { - - continue - } - - sessions[session.ID] = &session - } - - return sessions, nil -} - -// FetchSessionCommittedUpdates retrieves the current set of un-acked updates -// of the given session. -func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( - []wtdb.CommittedUpdate, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - updates, ok := m.committedUpdates[*id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - return updates, nil -} - -// IsAcked returns true if the given backup has been backed up using the given -// session. -func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, - error) { - - m.mu.Lock() - defer m.mu.Unlock() - - index, ok := m.ackedUpdates[*id][backupID.ChanID] - if !ok { - return false, nil - } - - return index.IsInIndex(backupID.CommitHeight), nil -} - -// NumAckedUpdates returns the number of backups that have been successfully -// backed up using the given session. -func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var numAcked uint64 - - for _, index := range m.ackedUpdates[*id] { - numAcked += index.NumInSet() - } - - return numAcked, nil -} - -// CreateClientSession records a newly negotiated client session in the set of -// active sessions. The session can be identified by its SessionID. -func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { - m.mu.Lock() - defer m.mu.Unlock() - - // Ensure that we aren't overwriting an existing session. - if _, ok := m.activeSessions[session.ID]; ok { - return wtdb.ErrClientSessionAlreadyExists - } - - key := keyIndexKey{ - towerID: session.TowerID, - blobType: session.Policy.BlobType, - } - - // Ensure that a session key index has been reserved for this tower. - keyIndex, err := m.getSessionKeyIndex(key) - if err != nil { - return err - } - - // Ensure that the session's index matches the reserved index. - if keyIndex != session.KeyIndex { - return wtdb.ErrIncorrectKeyIndex - } - - // Remove the key index reservation for this tower. Once committed, this - // permits us to create another session with this tower. 
- delete(m.indexes, key) - if key.blobType == blob.TypeAltruistCommit { - delete(m.legacyIndexes, key.towerID) - } - - m.activeSessions[session.ID] = wtdb.ClientSession{ - ID: session.ID, - ClientSessionBody: wtdb.ClientSessionBody{ - SeqNum: session.SeqNum, - TowerLastApplied: session.TowerLastApplied, - TowerID: session.TowerID, - KeyIndex: session.KeyIndex, - Policy: session.Policy, - RewardPkScript: cloneBytes(session.RewardPkScript), - }, - } - m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex) - m.persistedAckedUpdates[session.ID] = make( - map[lnwire.ChannelID]*mockKVStore, - ) - m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) - - return nil -} - -// NextSessionKeyIndex reserves a new session key derivation index for a -// particular tower id. The index is reserved for that tower until -// CreateClientSession is invoked for that tower and index, at which point a new -// index for that tower can be reserved. Multiple calls to this method before -// CreateClientSession is invoked should return the same index unless forceNext -// is set to true. -func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, blobType blob.Type, - forceNext bool) (uint32, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - key := keyIndexKey{ - towerID: towerID, - blobType: blobType, - } - - if !forceNext { - if index, err := m.getSessionKeyIndex(key); err == nil { - return index, nil - } - } - - // By default, we use the next available bucket sequence as the key - // index. But if forceNext is true, then it is assumed that some data - // loss occurred and so the sequence is incremented a by a jump of 1000 - // so that we can arrive at a brand new key index quicker. - nextIndex := m.nextIndex + 1 - if forceNext { - nextIndex = m.nextIndex + 1000 - } - m.nextIndex = nextIndex - m.indexes[key] = nextIndex - - return nextIndex, nil -} - -func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) { - if index, ok := m.indexes[key]; ok { - return index, nil - } - - if key.blobType == blob.TypeAltruistCommit { - if index, ok := m.legacyIndexes[key.towerID]; ok { - return index, nil - } - } - - return 0, wtdb.ErrNoReservedKeyIndex -} - -// CommitUpdate persists the CommittedUpdate provided in the slot for (session, -// seqNum). This allows the client to retransmit this update on startup. -func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, - update *wtdb.CommittedUpdate) (uint16, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return 0, wtdb.ErrClientSessionNotFound - } - - // Check if an update has already been committed for this state. - for _, dbUpdate := range m.committedUpdates[session.ID] { - if dbUpdate.SeqNum == update.SeqNum { - // If the breach hint matches, we'll just return the - // last applied value so the client can retransmit. - if dbUpdate.Hint == update.Hint { - return session.TowerLastApplied, nil - } - - // Otherwise, fail since the breach hint doesn't match. - return 0, wtdb.ErrUpdateAlreadyCommitted - } - } - - // Sequence number must increment. - if update.SeqNum != session.SeqNum+1 { - return 0, wtdb.ErrCommitUnorderedUpdate - } - - // Save the update and increment the sequence number. - m.committedUpdates[session.ID] = append( - m.committedUpdates[session.ID], *update, - ) - session.SeqNum++ - m.activeSessions[*id] = session - - return session.TowerLastApplied, nil -} - -// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. 
This -// removes the update from the set of committed updates, and validates the -// lastApplied value returned from the tower. -func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, - lastApplied uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Ensure the returned last applied value does not exceed the highest - // allocated sequence number. - if lastApplied > session.SeqNum { - return wtdb.ErrUnallocatedLastApplied - } - - // Ensure the last applied value isn't lower than a previous one sent by - // the tower. - if lastApplied < session.TowerLastApplied { - return wtdb.ErrLastAppliedReversion - } - - // Retrieve the committed update, failing if none is found. We should - // only receive acks for state updates that we send. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Add sessionID to channel. - channel, ok := m.channels[update.BackupID.ChanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - channel.sessions[*id] = true - - // Remove the committed update from disk and mark the update as - // acked. The tower last applied value is also recorded to send - // along with the next update. - copy(updates[:i], updates[i+1:]) - updates[len(updates)-1] = wtdb.CommittedUpdate{} - m.committedUpdates[session.ID] = updates[:len(updates)-1] - - chanID := update.BackupID.ChanID - if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok { - index, err := wtdb.NewRangeIndex(nil) - if err != nil { - return err - } - - m.ackedUpdates[*id][chanID] = index - m.persistedAckedUpdates[*id][chanID] = newMockKVStore() - } - - err := m.ackedUpdates[*id][chanID].Add( - update.BackupID.CommitHeight, - m.persistedAckedUpdates[*id][chanID], - ) - if err != nil { - return err - } - - session.TowerLastApplied = lastApplied - - m.activeSessions[*id] = session - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// GetDBQueue returns a BackupID Queue instance under the given name space. -func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] { - m.mu.Lock() - defer m.mu.Unlock() - - if q, ok := m.queues[string(namespace)]; ok { - return q - } - - q := NewQueueDB[*wtdb.BackupID]() - m.queues[string(namespace)] = q - - return q -} - -// DeleteCommittedUpdate deletes the committed update with the given sequence -// number from the given session. -func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID, - seqNum uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Retrieve the committed update, failing if none is found. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Remove the committed update from "disk". - updates = append(updates[:i], updates[i+1:]...) - m.committedUpdates[session.ID] = updates - - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// ListClosableSessions fetches and returns the IDs for all sessions marked as -// closable. 
-func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { - m.mu.Lock() - defer m.mu.Unlock() - - cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) - for id, height := range m.closableSessions { - cs[id] = height - } - - return cs, nil -} - -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { - m.mu.Lock() - defer m.mu.Unlock() - - summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, channel := range m.channels { - // Don't load the channel if it has been marked as closed. - if channel.closedHeight > 0 { - continue - } - - summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes( - channel.summary.SweepPkScript, - ), - } - } - - return summaries, nil -} - -// MarkChannelClosed will mark a registered channel as closed by setting -// its closed-height as the given block height. It returns a list of -// session IDs for sessions that are now considered closable due to the -// close of this channel. -func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, - blockHeight uint32) ([]wtdb.SessionID, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - channel, ok := m.channels[chanID] - if !ok { - return nil, wtdb.ErrChannelNotRegistered - } - - // If there are no sessions for this channel, the channel details can be - // deleted. - if len(channel.sessions) == 0 { - delete(m.channels, chanID) - return nil, nil - } - - // Mark the channel as closed. - channel.closedHeight = blockHeight - - // Now iterate through all the sessions of the channel to check if any - // of them are closeable. - var closableSessions []wtdb.SessionID - for sessID := range channel.sessions { - isClosable, err := m.isSessionClosable(sessID) - if err != nil { - return nil, err - } - - if !isClosable { - continue - } - - closableSessions = append(closableSessions, sessID) - - // Add session to "closableSessions" list and add the block - // height that this last channel was closed in. This will be - // used in future to determine when we should delete the - // session. - m.closableSessions[sessID] = blockHeight - } - - return closableSessions, nil -} - -// isSessionClosable returns true if a session is considered closable. A session -// is considered closable only if: -// 1) It has no un-acked updates -// 2) It is exhausted (ie it cant accept any more updates) -// 3) All the channels that it has acked-updates for are closed. -func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { - // The session is not closable if it has un-acked updates. - if len(m.committedUpdates[id]) > 0 { - return false, nil - } - - sess, ok := m.activeSessions[id] - if !ok { - return false, wtdb.ErrClientSessionNotFound - } - - // The session is not closable if it is not yet exhausted. - if sess.SeqNum != sess.Policy.MaxUpdates { - return false, nil - } - - // Iterate over each of the channels that the session has acked-updates - // for. If any of those channels are not closed, then the session is - // not yet closable. - for chanID := range m.ackedUpdates[id] { - channel, ok := m.channels[chanID] - if !ok { - continue - } - - // Channel is not yet closed, and so we can not yet delete the - // session. - if channel.closedHeight == 0 { - return false, nil - } - } - - return true, nil -} - -// GetClientSession loads the ClientSession with the given ID from the DB. 
-func (m *ClientDB) GetClientSession(id wtdb.SessionID, - opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - session, ok := m.activeSessions[id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - return &session, nil -} - -// DeleteSession can be called when a session should be deleted from the DB. -// All references to the session will also be deleted from the DB. Note that a -// session will only be deleted if it is considered closable. -func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { - m.mu.Lock() - defer m.mu.Unlock() - - _, ok := m.closableSessions[id] - if !ok { - return wtdb.ErrSessionNotClosable - } - - // For each of the channels, delete the session ID entry. - for chanID := range m.ackedUpdates[id] { - c, ok := m.channels[chanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - - delete(c.sessions, id) - } - - delete(m.closableSessions, id) - delete(m.activeSessions, id) - - return nil -} - -// RegisterChannel registers a channel for use within the client database. For -// now, all that is stored in the channel summary is the sweep pkscript that -// we'd like any tower sweeps to pay into. In the future, this will be extended -// to contain more info to allow the client efficiently request historical -// states to be backed up under the client's active policy. -func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, - sweepPkScript []byte) error { - - m.mu.Lock() - defer m.mu.Unlock() - - if _, ok := m.channels[chanID]; ok { - return wtdb.ErrChannelAlreadyRegistered - } - - m.channels[chanID] = &channel{ - summary: &wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), - }, - sessions: make(map[wtdb.SessionID]bool), - } - - return nil -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - - bb := make([]byte, len(b)) - copy(bb, b) - - return bb -} - -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ - ID: tower.ID, - IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), - } - copy(t.Addresses, tower.Addresses) - - return t -} - -type mockKVStore struct { - kv map[uint64]uint64 - - err error -} - -func newMockKVStore() *mockKVStore { - return &mockKVStore{ - kv: make(map[uint64]uint64), - } -} - -func (m *mockKVStore) Put(key, value []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - v := byteOrder.Uint64(value) - - m.kv[k] = v - - return nil -} - -func (m *mockKVStore) Delete(key []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - delete(m.kv, k) - - return nil -} From adb87dcfb8fe985b192d71b7c899f046af89eb86 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 12:33:51 +0200 Subject: [PATCH 3/7] wtclient: demo un-acked update of closed channel bug This commit adds a new test to the tower client to demonstrate a bug that can happen if a channel is closed while an update for it has yet to be acked by the tower server. This will be fixed in an upcomming commit. 
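To make the failure sequence concrete, here is a toy, self-contained model of
the bug (the names are illustrative only; in lnd the failing lookup happens
inside the client DB's AckUpdate):

    package main

    import (
    	"errors"
    	"fmt"
    )

    var errChannelNotRegistered = errors.New("channel not registered")

    // ackUpdate stands in for the client DB method that must look up the
    // update's channel before it can record the ack.
    func ackUpdate(chanDetails map[string]struct{}, chanID string) error {
    	if _, ok := chanDetails[chanID]; !ok {
    		return errChannelNotRegistered
    	}

    	return nil
    }

    func main() {
    	chanDetails := map[string]struct{}{"chan-0": {}}

    	// The channel closes and its sessions are swept, so every trace
    	// of the channel is deleted from the DB...
    	delete(chanDetails, "chan-0")

    	// ...and only afterwards does the tower ack the still-queued
    	// update, which now has no channel record to attach to.
    	fmt.Println(ackUpdate(chanDetails, "chan-0"))
    }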
---
 watchtower/wtclient/client_test.go | 138 +++++++++++++++++++++++++++++
 1 file changed, 138 insertions(+)

diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go
index a5b774c10..4f375f487 100644
--- a/watchtower/wtclient/client_test.go
+++ b/watchtower/wtclient/client_test.go
@@ -2416,6 +2416,144 @@ var clientTests = []clientTest{
 			server2.waitForUpdates(hints[numUpdates/2:], waitTime)
 		},
 	},
+	{
+		// This test demonstrates a bug that will be addressed in a
+		// follow-up commit. It shows that if a channel is closed while
+		// an update for that channel still exists in an in-memory queue
+		// somewhere then it is possible that all the data for that
+		// channel gets deleted from the tower client DB. This results
+		// in an error being thrown in the DB AckUpdate method since it
+		// will try to find the associated channel data but will not
+		// find it.
+		name: "channel closed while update is un-acked",
+		cfg: harnessCfg{
+			localBalance:  localBalance,
+			remoteBalance: remoteBalance,
+			policy: wtpolicy.Policy{
+				TxPolicy:   defaultTxPolicy,
+				MaxUpdates: 5,
+			},
+		},
+		fn: func(h *testHarness) {
+			const (
+				numUpdates = 10
+				chanIDInt  = 0
+			)
+
+			h.sendUpdatesOn = true
+
+			// Advance the channel with a few updates.
+			hints := h.advanceChannelN(chanIDInt, numUpdates)
+
+			// Back up a few of these updates and wait for them to
+			// arrive at the server. Note that we back up enough
+			// updates to saturate the session so that the session
+			// is considered closable when the channel is deleted.
+			h.backupStates(chanIDInt, 0, numUpdates/2, nil)
+			h.server.waitForUpdates(hints[:numUpdates/2], waitTime)
+
+			// Now, restart the server in a state where it will not
+			// ack updates. This will allow us to wait for an
+			// update to be un-acked and persisted.
+			h.server.restart(func(cfg *wtserver.Config) {
+				cfg.NoAckUpdates = true
+			})
+
+			// Back up a few more of the updates. These should
+			// remain in the client as un-acked.
+			h.backupStates(
+				chanIDInt, numUpdates/2, numUpdates-1, nil,
+			)
+
+			// Wait for the tasks to be bound to sessions.
+			fetchSessions := h.clientDB.FetchSessionCommittedUpdates
+			err := wait.Predicate(func() bool {
+				sessions, err := h.clientDB.ListClientSessions(
+					nil,
+				)
+				require.NoError(h.t, err)
+
+				var updates []wtdb.CommittedUpdate
+				for id := range sessions {
+					updates, err = fetchSessions(&id)
+					require.NoError(h.t, err)
+
+					if len(updates) == numUpdates/2-1 {
+						return true
+					}
+				}
+
+				return false
+			}, waitTime)
+			require.NoError(h.t, err)
+
+			// Now we close this channel while the update for it has
+			// not yet been acked.
+			h.closeChannel(chanIDInt, 1)
+
+			// There should now be exactly one closable session.
+			err = wait.Predicate(func() bool {
+				cs, err := h.clientDB.ListClosableSessions()
+				require.NoError(h.t, err)
+
+				return len(cs) == 1
+			}, waitTime)
+			require.NoError(h.t, err)
+
+			// Now, restart the server and allow it to ack updates
+			// again.
+			h.server.restart(func(cfg *wtserver.Config) {
+				cfg.NoAckUpdates = false
+			})
+
+			// Mine a few blocks so that the session close range is
+			// surpassed.
+			h.mine(3)
+
+			// Wait for there to be no more closable sessions on the
+			// client side.
+			err = wait.Predicate(func() bool {
+				cs, err := h.clientDB.ListClosableSessions()
+				require.NoError(h.t, err)
+
+				return len(cs) == 0
+			}, waitTime)
+			require.NoError(h.t, err)
+
+			// Wait for the channel to be "unregistered".
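+			// (BackupState fails with ErrUnregisteredChannel once
+			// the channel's details are gone from the client DB.)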
+ chanID := chanIDFromInt(chanIDInt) + err = wait.Predicate(func() bool { + err := h.client.BackupState(&chanID, 0) + + return errors.Is( + err, wtclient.ErrUnregisteredChannel, + ) + }, waitTime) + require.NoError(h.t, err) + + // Show that the committed update for the closed channel + // remains in the client's DB. + err = wait.Predicate(func() bool { + sessions, err := h.clientDB.ListClientSessions( + nil, + ) + require.NoError(h.t, err) + + var updates []wtdb.CommittedUpdate + for id := range sessions { + updates, err = fetchSessions(&id) + require.NoError(h.t, err) + + if len(updates) != 0 { + return true + } + } + + return false + }, waitTime) + require.NoError(h.t, err) + }, + }, } // TestClient executes the client test suite, asserting the ability to backup From c33cd0ea3897bc41cf468989bee4a682bc5a707e Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 12:51:27 +0200 Subject: [PATCH 4/7] wtdb: refactor getRangesWriteBucket A pure refactor commit that passes the required buckets to the `getRangesWriteBucket` instead of re-fetching them. --- watchtower/wtdb/client_db.go | 35 +++++++++++------------------------ 1 file changed, 11 insertions(+), 24 deletions(-) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 41d80587d..8eef8f03c 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -980,29 +980,8 @@ func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) ( // getRangesWriteBucket gets the range index bucket where the range index for // the given session-channel pair is stored. If any sub-buckets along the way do // not exist, then they are created. -func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID, - chanID lnwire.ChannelID) (kvdb.RwBucket, error) { - - sessions := tx.ReadWriteBucket(cSessionBkt) - if sessions == nil { - return nil, ErrUninitializedDB - } - - chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) - if chanDetailsBkt == nil { - return nil, ErrUninitializedDB - } - - sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:]) - if err != nil { - return nil, err - } - - // Get the DB representation of the channel-ID. - _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) - if err != nil { - return nil, err - } +func getRangesWriteBucket(sessionBkt kvdb.RwBucket, dbChanIDBytes []byte) ( + kvdb.RwBucket, error) { sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists( cSessionAckRangeIndex, @@ -2029,10 +2008,18 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, chanID := committedUpdate.BackupID.ChanID height := committedUpdate.BackupID.CommitHeight + // Get the DB representation of the channel-ID. + _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) + if err != nil { + return err + } + // Get the ranges write bucket before getting the range index to // ensure that the session acks sub-bucket is initialized, so // that we can insert an entry. - rangesBkt, err := getRangesWriteBucket(tx, *id, chanID) + rangesBkt, err := getRangesWriteBucket( + sessionBkt, dbChanIDBytes, + ) if err != nil { return err } From 2a9339805e9111595f9d23c517e66c5d195e8064 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 13:11:00 +0200 Subject: [PATCH 5/7] watchtower: account for rogue updates In this commit, we introduce the concept of a rogue update. An update is rogue if we need to ACK it but we have already deleted all the data for the associated channel due to the channel being closed. 
In this case, we now no longer error out and instead keep count of how many rogue updates a session has backed-up. --- watchtower/wtclient/client_test.go | 18 ++-- watchtower/wtdb/client_db.go | 153 ++++++++++++++++++++++++++--- watchtower/wtdb/client_db_test.go | 126 +++++++++++++++++++++++- 3 files changed, 271 insertions(+), 26 deletions(-) diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 4f375f487..51dd16e09 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -2417,14 +2417,10 @@ var clientTests = []clientTest{ }, }, { - // This test demonstrates a bug that will be addressed in a - // follow-up commit. It shows that if a channel is closed while - // an update for that channel still exists in an in-memory queue - // somewhere then it is possible that all the data for that - // channel gets deleted from the tower client DB. This results - // in an error being thrown in the DB AckUpdate method since it - // will try to find the associated channel data but will not - // find it. + // This test shows that if a channel is closed while an update + // for that channel still exists in an in-memory queue + // somewhere then it is handled correctly by treating it as a + // rogue update. name: "channel closed while update is un-acked", cfg: harnessCfg{ localBalance: localBalance, @@ -2532,7 +2528,7 @@ var clientTests = []clientTest{ require.NoError(h.t, err) // Show that the committed update for the closed channel - // remains in the client's DB. + // is cleared from the DB. err = wait.Predicate(func() bool { sessions, err := h.clientDB.ListClientSessions( nil, @@ -2545,11 +2541,11 @@ var clientTests = []clientTest{ require.NoError(h.t, err) if len(updates) != 0 { - return true + return false } } - return false + return true }, waitTime) require.NoError(h.t, err) }, diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 8eef8f03c..e7b9c1137 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -50,6 +50,7 @@ var ( // => cSessionDBID -> db-assigned-id // => cSessionCommits => seqnum -> encoded CommittedUpdate // => cSessionAckRangeIndex => db-chan-id => start -> end + // => cSessionRogueUpdateCount -> count cSessionBkt = []byte("client-session-bucket") // cSessionDBID is a key used in the cSessionBkt to store the @@ -68,6 +69,12 @@ var ( // chan-id => start -> end cSessionAckRangeIndex = []byte("client-session-ack-range-index") + // cSessionRogueUpdateCount is a key in the cSessionBkt bucket storing + // the number of rogue updates that were backed up using the session. + // Rogue updates are updates for channels that have been closed already + // at the time of the back-up. + cSessionRogueUpdateCount = []byte("client-session-rogue-update-count") + // cChanIDIndexBkt is a top-level bucket storing: // db-assigned-id -> channel-ID cChanIDIndexBkt = []byte("client-channel-id-index") @@ -1242,10 +1249,23 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { } sessionBkt := sessions.NestedReadBucket(id[:]) - if sessionsBkt == nil { + if sessionBkt == nil { return nil } + // First, account for any rogue updates. + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return err + } + + numAcked += rogueCount + } + + // Then, check if the session-ack-ranges contains any entries + // to account for. 
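+	// Together, the rogue count and the per-channel acked ranges make up
+	// the session's total acked-update count.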
sessionAckRanges := sessionBkt.NestedReadBucket( cSessionAckRangeIndex, ) @@ -1525,14 +1545,37 @@ func (c *ClientDB) DeleteSession(id SessionID) error { return err } - // Get the acked updates range index for the session. This is - // used to get the list of channels that the session has updates - // for. ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + + // There is a small chance that the session only contains rogue + // updates. In that case, there will be no ack-ranges index but + // the rogue update count will be equal the MaxUpdates. + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return err + } + + maxUpdates := sess.ClientSessionBody.Policy.MaxUpdates + if rogueCount == uint64(maxUpdates) { + // Do a sanity check to ensure that the acked + // ranges bucket does not exist in this case. + if ackRanges != nil { + return fmt.Errorf("acked updates "+ + "exist for session with a "+ + "max-updates(%d) rogue count", + rogueCount) + } + + return sessionsBkt.DeleteNestedBucket(id[:]) + } + } + + // A session would only be considered closable if it was + // exhausted. Meaning that it should not be the case that it has + // no acked-updates. if ackRanges == nil { - // A session would only be considered closable if it - // was exhausted. Meaning that it should not be the - // case that it has no acked-updates. return fmt.Errorf("cannot delete session %s since it "+ "is not yet exhausted", id) } @@ -1763,6 +1806,22 @@ func isSessionClosable(sessionsBkt, chanDetailsBkt, chanIDIndexBkt kvdb.RBucket, return false, nil } + // Either the acked-update bucket should exist _or_ the rogue update + // count must be equal to the session's MaxUpdates value, otherwise + // something is wrong because the above check ensures that the session + // has been exhausted. + rogueCountBytes := sessBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + rogueCount, err := readBigSize(rogueCountBytes) + if err != nil { + return false, err + } + + if rogueCount == uint64(session.Policy.MaxUpdates) { + return true, nil + } + } + // If the session has no acked-updates, then something is wrong since // the above check ensures that this session has been exhausted meaning // that it should have MaxUpdates acked updates. @@ -2005,12 +2064,83 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } + dbSessionID, dbSessIDBytes, err := getDBSessionID(sessions, *id) + if err != nil { + return err + } + chanID := committedUpdate.BackupID.ChanID height := committedUpdate.BackupID.CommitHeight - // Get the DB representation of the channel-ID. + // Get the DB representation of the channel-ID. There is a + // chance that the channel corresponding to this update has been + // closed and that the details for this channel no longer exist + // in the tower client DB. In that case, we consider this a + // rogue update and all we do is make sure to keep track of the + // number of rogue updates for this session. 
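+		// The count lives in the session's own bucket under the
+		// cSessionRogueUpdateCount key, encoded as a BigSize integer.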
_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) - if err != nil { + if errors.Is(err, ErrChannelNotRegistered) { + var ( + count uint64 + err error + ) + + rogueCountBytes := sessionBkt.Get( + cSessionRogueUpdateCount, + ) + if len(rogueCountBytes) != 0 { + count, err = readBigSize(rogueCountBytes) + if err != nil { + return err + } + } + + rogueCount := count + 1 + countBytes, err := writeBigSize(rogueCount) + if err != nil { + return err + } + + err = sessionBkt.Put( + cSessionRogueUpdateCount, countBytes, + ) + if err != nil { + return err + } + + // In the rare chance that this session only has rogue + // updates, we check here if the count is equal to the + // MaxUpdate of the session. If it is, then we mark the + // session as closable. + if rogueCount != uint64(session.Policy.MaxUpdates) { + return nil + } + + // Before we mark the session as closable, we do a + // sanity check to ensure that this session has no + // acked-update index. + sessionAckRanges := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if sessionAckRanges != nil { + return fmt.Errorf("session(%s) has an "+ + "acked ranges index but has a rogue "+ + "count indicating saturation", + session.ID) + } + + closableSessBkt := tx.ReadWriteBucket( + cClosableSessionsBkt, + ) + if closableSessBkt == nil { + return ErrUninitializedDB + } + + var height [4]byte + byteOrder.PutUint32(height[:], 0) + + return closableSessBkt.Put(dbSessIDBytes, height[:]) + } else if err != nil { return err } @@ -2024,11 +2154,6 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } - dbSessionID, _, err := getDBSessionID(sessions, *id) - if err != nil { - return err - } - chanDetails := chanDetailsBkt.NestedReadWriteBucket( committedUpdate.BackupID.ChanID[:], ) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 8be729f21..6d11a6972 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -675,6 +675,98 @@ func testCommitUpdate(h *clientDBHarness) { h.assertUpdates(session.ID, []wtdb.CommittedUpdate{}, nil) } +// testRogueUpdates asserts that rogue updates (updates for channels that are +// backed up after the channel has been closed and the channel details deleted +// from the DB) are handled correctly. +func testRogueUpdates(h *clientDBHarness) { + const maxUpdates = 5 + + tower := h.newTower() + + // Create and insert a new session. + session1 := h.randSession(h.t, tower.ID, maxUpdates) + h.insertSession(session1, nil) + + // Create a new channel and register it. + chanID1 := randChannelID(h.t) + h.registerChan(chanID1, nil, nil) + + // Num acked updates should be 0. + require.Zero(h.t, h.numAcked(&session1.ID, nil)) + + // Commit and ACK enough updates for this channel to fill the session. + for i := 1; i <= maxUpdates; i++ { + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(i), uint64(i), + ) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil) + } + + // Num acked updates should now be 5. + require.EqualValues(h.t, maxUpdates, h.numAcked(&session1.ID, nil)) + + // Commit one more update for the channel but this time do not ACK it. + // This update will be put in a new session since the previous one has + // been exhausted. 
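+	// (Session 1 already holds maxUpdates acked updates, so it cannot
+	// accept any more.)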
+ session2 := h.randSession(h.t, tower.ID, maxUpdates) + sess2Seq := 1 + h.insertSession(session2, nil) + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(sess2Seq), uint64(maxUpdates+1), + ) + lastApplied := h.commitUpdate(&session2.ID, update, nil) + + // Session 2 should not have any acked updates yet. + require.Zero(h.t, h.numAcked(&session2.ID, nil)) + + // There should currently be no closable sessions. + require.Empty(h.t, h.listClosableSessions(nil)) + + // Now mark the channel as closed. + h.markChannelClosed(chanID1, 1, nil) + + // Assert that session 1 is now seen as closable. + closableSessionsMap := h.listClosableSessions(nil) + require.Len(h.t, closableSessionsMap, 1) + _, ok := closableSessionsMap[session1.ID] + require.True(h.t, ok) + + // Delete session 1. + h.deleteSession(session1.ID, nil) + + // Now try to ACK the update for the channel. This should succeed and + // the update should be considered a rogue update. + h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil) + + // Show that the number of acked updates is now 1. + require.EqualValues(h.t, 1, h.numAcked(&session2.ID, nil)) + + // We also want to test the extreme case where all the updates for a + // particular session are rogue updates. In this case, the session + // should be seen as closable if it is saturated. + + // First show that the session is not yet considered closable. + require.Empty(h.t, h.listClosableSessions(nil)) + + // Then, let's continue adding rogue updates for the closed channel to + // session 2. + for i := maxUpdates + 2; i <= maxUpdates*2; i++ { + sess2Seq++ + + update := randCommittedUpdateForChanWithHeight( + h.t, chanID1, uint16(sess2Seq), uint64(i), + ) + lastApplied := h.commitUpdate(&session2.ID, update, nil) + h.ackUpdate(&session2.ID, uint16(sess2Seq), lastApplied, nil) + } + + // At this point, session 2 is saturated with rogue updates. Assert that + // it is now closable. + closableSessionsMap = h.listClosableSessions(nil) + require.Len(h.t, closableSessionsMap, 1) +} + // testMarkChannelClosed asserts the behaviour of MarkChannelClosed. func testMarkChannelClosed(h *clientDBHarness) { tower := h.newTower() @@ -762,7 +854,7 @@ func testMarkChannelClosed(h *clientDBHarness) { require.EqualValues(h.t, 4, lastApplied) h.ackUpdate(&session1.ID, 5, 5, nil) - // The session is no exhausted. + // The session is now exhausted. // If we now close channel 5, session 1 should still not be closable // since it has an update for channel 6 which is still open. sl = h.markChannelClosed(chanID5, 1, nil) @@ -1001,6 +1093,10 @@ func TestClientDB(t *testing.T) { name: "mark channel closed", run: testMarkChannelClosed, }, + { + name: "rogue updates", + run: testRogueUpdates, + }, } for _, database := range dbs { @@ -1066,6 +1162,34 @@ func randCommittedUpdateForChannel(t *testing.T, chanID lnwire.ChannelID, } } +// randCommittedUpdateForChanWithHeight generates a random committed update for +// the given channel ID using the given commit height. 
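+// Letting the caller pick the height directly is what allows the
+// rogue-update test above to keep commit heights unique across sessions.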
+func randCommittedUpdateForChanWithHeight(t *testing.T, chanID lnwire.ChannelID, + seqNum uint16, height uint64) *wtdb.CommittedUpdate { + + t.Helper() + + var hint blob.BreachHint + _, err := io.ReadFull(crand.Reader, hint[:]) + require.NoError(t, err) + + encBlob := make([]byte, blob.Size(blob.FlagCommitOutputs.Type())) + _, err = io.ReadFull(crand.Reader, encBlob) + require.NoError(t, err) + + return &wtdb.CommittedUpdate{ + SeqNum: seqNum, + CommittedUpdateBody: wtdb.CommittedUpdateBody{ + BackupID: wtdb.BackupID{ + ChanID: chanID, + CommitHeight: height, + }, + Hint: hint, + EncryptedBlob: encBlob, + }, + } +} + func (h *clientDBHarness) randSession(t *testing.T, towerID wtdb.TowerID, maxUpdates uint16) *wtdb.ClientSession { From 273b934f62fd1ea1e8daeb56aeb226e71a23843b Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 14:10:45 +0200 Subject: [PATCH 6/7] wtdb+lnrpc: return correct NumAckedUpdates from wtclientrpc --- lnrpc/wtclientrpc/wtclient.go | 5 +++++ watchtower/wtdb/client_db.go | 39 +++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 1f45335fb..228877743 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -390,6 +390,10 @@ func constructFunctionalOptions(includeSessions, return opts, ackCounts, committedUpdateCounts } + perNumRogueUpdates := func(s *wtdb.ClientSession, numUpdates uint16) { + ackCounts[s.ID] += numUpdates + } + perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID, numUpdates uint16) { @@ -405,6 +409,7 @@ func constructFunctionalOptions(includeSessions, opts = []wtdb.ClientSessionListOption{ wtdb.WithPerNumAckedUpdates(perNumAckedUpdates), wtdb.WithPerCommittedUpdate(perCommittedUpdate), + wtdb.WithPerRogueUpdateCount(perNumRogueUpdates), } if excludeExhaustedSessions { diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index e7b9c1137..084f2dcfe 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2285,6 +2285,11 @@ type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64) // number of updates that the session has for the channel. type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16) +// PerRogueUpdateCountCB describes the signature of a callback function that can +// be called for each session with the number of rogue updates that the session +// has. +type PerRogueUpdateCountCB func(*ClientSession, uint16) + // PerAckedUpdateCB describes the signature of a callback function that can be // called for each of a session's acked updates. type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) @@ -2307,6 +2312,10 @@ type ClientSessionListCfg struct { // channel. PerNumAckedUpdates PerNumAckedUpdatesCB + // PerRogueUpdateCount will, if set, be called with the number of rogue + // updates that the session has backed up. + PerRogueUpdateCount PerRogueUpdateCountCB + // PerMaxHeight will, if set, be called for each of the session's // channels to communicate the highest commit height of updates stored // for that channel. @@ -2354,6 +2363,15 @@ func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption { } } +// WithPerRogueUpdateCount constructs a functional option that will set a +// call-back function to be called with the number of rogue updates that the +// session has backed up. 
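+//
+// Caller-side sketch (mirroring the wtclientrpc usage above; the ackCounts
+// map is illustrative):
+//
+//	opts = append(opts, wtdb.WithPerRogueUpdateCount(
+//		func(s *wtdb.ClientSession, numUpdates uint16) {
+//			ackCounts[s.ID] += numUpdates
+//		},
+//	))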
+func WithPerRogueUpdateCount(cb PerRogueUpdateCountCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerRogueUpdateCount = cb + } +} + // WithPerCommittedUpdate constructs a functional option that will set a // call-back function to be called for each of a client's un-acked updates. func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { @@ -2422,7 +2440,7 @@ func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, // provided. err = c.filterClientSessionAcks( sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight, - cfg.PerNumAckedUpdates, + cfg.PerNumAckedUpdates, cfg.PerRogueUpdateCount, ) if err != nil { return nil, err @@ -2480,7 +2498,24 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, // call back if one is provided. func (c *ClientDB) filterClientSessionAcks(sessionBkt, chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB, - perNumAckedUpdates PerNumAckedUpdatesCB) error { + perNumAckedUpdates PerNumAckedUpdatesCB, + perRogueUpdateCount PerRogueUpdateCountCB) error { + + if perRogueUpdateCount != nil { + var ( + count uint64 + err error + ) + rogueCountBytes := sessionBkt.Get(cSessionRogueUpdateCount) + if len(rogueCountBytes) != 0 { + count, err = readBigSize(rogueCountBytes) + if err != nil { + return err + } + } + + perRogueUpdateCount(s, uint16(count)) + } if perMaxCb == nil && perNumAckedUpdates == nil { return nil From 95c2bfe181be9619e255987ab75fb313b77c6ca2 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 13 Sep 2023 14:41:24 +0200 Subject: [PATCH 7/7] docs: add entry for 7981 --- docs/release-notes/release-notes-0.17.0.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/release-notes/release-notes-0.17.0.md b/docs/release-notes/release-notes-0.17.0.md index a74a80421..9b0329da6 100644 --- a/docs/release-notes/release-notes-0.17.0.md +++ b/docs/release-notes/release-notes-0.17.0.md @@ -71,6 +71,10 @@ fails](https://github.com/lightningnetwork/lnd/pull/7876). retried](https://github.com/lightningnetwork/lnd/pull/7927) with an exponential back off. +* In the watchtower client, we [now explicitly + handle](https://github.com/lightningnetwork/lnd/pull/7981) the scenario where + a channel is closed while we still have an in-memory update for it. + # New Features ## Functional Enhancements ### Protocol Features
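The `readBigSize`/`writeBigSize` helpers used for the rogue-update count are
pre-existing `wtdb` internals and so do not appear in these diffs. A minimal
sketch of what such helpers look like, assuming lnd's `tlv` package provides
the BigSize varint primitives (the exact helpers in `wtdb` may differ):

    package wtdb

    import (
    	"bytes"

    	"github.com/lightningnetwork/lnd/tlv"
    )

    // writeBigSize encodes the given uint64 as a BigSize varint.
    func writeBigSize(i uint64) ([]byte, error) {
    	var b bytes.Buffer
    	if err := tlv.WriteVarInt(&b, i, &[8]byte{}); err != nil {
    		return nil, err
    	}

    	return b.Bytes(), nil
    }

    // readBigSize decodes a BigSize varint from the given bytes.
    func readBigSize(b []byte) (uint64, error) {
    	i, err := tlv.ReadVarInt(bytes.NewReader(b), &[8]byte{})
    	if err != nil {
    		return 0, err
    	}

    	return i, nil
    }

BigSize is the same variable-length integer encoding used elsewhere in the
Lightning protocol, so small counts cost a single byte on disk.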