diff --git a/watchtower/wtclient/queue_test.go b/watchtower/wtclient/queue_test.go index 81f96bb7f..529acb49a 100644 --- a/watchtower/wtclient/queue_test.go +++ b/watchtower/wtclient/queue_test.go @@ -18,51 +18,13 @@ const ( waitTime = time.Second * 2 ) -type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID] - // TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected. func TestDiskOverflowQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init initQueue - }{ - { - name: "kvdb", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - db.Close() - }) - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - { - name: "mock", - init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] { - db := wtmock.NewClientDB() - - return db.GetDBQueue([]byte("test-namespace")) - }, - }, - } - tests := []struct { name string - run func(*testing.T, initQueue) + run func(*testing.T, wtdb.Queue[*wtdb.BackupID]) }{ { name: "overflow to disk", @@ -78,29 +40,43 @@ func TestDiskOverflowQueue(t *testing.T) { }, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + initDB := func() wtdb.Queue[*wtdb.BackupID] { + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, + } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - test.run(t, db.init) - }) - } + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) + + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + require.NoError(t, db.Close()) + }) + + return db.GetDBQueue([]byte("test-namespace")) + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(tt *testing.T) { + tt.Parallel() + + test.run(tt, initDB()) }) } } // testOverflowToDisk is a basic test that ensures that the queue correctly // overflows items to disk and then correctly reloads them. -func testOverflowToDisk(t *testing.T, initQueue initQueue) { +func testOverflowToDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Init the DB. - db := initQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -146,7 +122,9 @@ func testOverflowToDisk(t *testing.T, initQueue initQueue) { // testRestartWithSmallerBufferSize tests that if the queue is restarted with // a smaller in-memory buffer size that it was initially started with, then // tasks are still loaded in the correct order. -func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { +func testRestartWithSmallerBufferSize(t *testing.T, + db wtdb.Queue[*wtdb.BackupID]) { + const ( firstMaxInMemItems = 5 secondMaxInMemItems = 2 @@ -155,9 +133,6 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // Generate some backup IDs that we want to add to the queue. tasks := genBackupIDs(10) - // Create a db. - db := newQueue(t) - // New mock logger. log := newMockLogger(t.Logf) @@ -223,14 +198,11 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) { // testStartStopQueue is a stress test that pushes a large number of tasks // through the queue while also restarting the queue a couple of times // throughout. -func testStartStopQueue(t *testing.T, newQueue initQueue) { +func testStartStopQueue(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) { // Generate a lot of backup IDs that we want to add to the // queue one after the other. tasks := genBackupIDs(200_000) - // Construct the ClientDB. - db := newQueue(t) - // New mock logger. log := newMockLogger(t.Logf) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 5bfb4dab5..8be729f21 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/stretchr/testify/require" ) @@ -964,12 +963,6 @@ func TestClientDB(t *testing.T) { return db }, }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, } tests := []struct { diff --git a/watchtower/wtdb/queue_test.go b/watchtower/wtdb/queue_test.go index a864125cf..02c7b272c 100644 --- a/watchtower/wtdb/queue_test.go +++ b/watchtower/wtdb/queue_test.go @@ -4,9 +4,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/kvdb" - "github.com/lightningnetwork/lnd/watchtower/wtclient" "github.com/lightningnetwork/lnd/watchtower/wtdb" - "github.com/lightningnetwork/lnd/watchtower/wtmock" "github.com/stretchr/testify/require" ) @@ -15,53 +13,24 @@ import ( func TestDiskQueue(t *testing.T) { t.Parallel() - dbs := []struct { - name string - init clientDBInit - }{ - { - name: "bbolt", - init: func(t *testing.T) wtclient.DB { - dbCfg := &kvdb.BoltConfig{ - DBTimeout: kvdb.DefaultDBTimeout, - } - - // Construct the ClientDB. - bdb, err := wtdb.NewBoltBackendCreator( - true, t.TempDir(), "wtclient.db", - )(dbCfg) - require.NoError(t, err) - - db, err := wtdb.OpenClientDB(bdb) - require.NoError(t, err) - - t.Cleanup(func() { - err = db.Close() - require.NoError(t, err) - }) - - return db - }, - }, - { - name: "mock", - init: func(t *testing.T) wtclient.DB { - return wtmock.NewClientDB() - }, - }, + dbCfg := &kvdb.BoltConfig{ + DBTimeout: kvdb.DefaultDBTimeout, } - for _, database := range dbs { - db := database - t.Run(db.name, func(t *testing.T) { - t.Parallel() + // Construct the ClientDB. + bdb, err := wtdb.NewBoltBackendCreator( + true, t.TempDir(), "wtclient.db", + )(dbCfg) + require.NoError(t, err) - testQueue(t, db.init(t)) - }) - } -} + db, err := wtdb.OpenClientDB(bdb) + require.NoError(t, err) + + t.Cleanup(func() { + err = db.Close() + require.NoError(t, err) + }) -func testQueue(t *testing.T, db wtclient.DB) { namespace := []byte("test-namespace") queue := db.GetDBQueue(namespace) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go deleted file mode 100644 index f5625d35b..000000000 --- a/watchtower/wtmock/client_db.go +++ /dev/null @@ -1,887 +0,0 @@ -package wtmock - -import ( - "encoding/binary" - "net" - "sync" - "sync/atomic" - - "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" - "github.com/lightningnetwork/lnd/watchtower/blob" - "github.com/lightningnetwork/lnd/watchtower/wtdb" -) - -var byteOrder = binary.BigEndian - -type towerPK [33]byte - -type keyIndexKey struct { - towerID wtdb.TowerID - blobType blob.Type -} - -type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex - -type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore - -type channel struct { - summary *wtdb.ClientChanSummary - closedHeight uint32 - sessions map[wtdb.SessionID]bool -} - -// ClientDB is a mock, in-memory database or testing the watchtower client -// behavior. -type ClientDB struct { - nextTowerID uint64 // to be used atomically - - mu sync.Mutex - channels map[lnwire.ChannelID]*channel - activeSessions map[wtdb.SessionID]wtdb.ClientSession - ackedUpdates rangeIndexArrayMap - persistedAckedUpdates rangeIndexKVStore - committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower - closableSessions map[wtdb.SessionID]uint32 - - nextIndex uint32 - indexes map[keyIndexKey]uint32 - legacyIndexes map[wtdb.TowerID]uint32 - - queues map[string]wtdb.Queue[*wtdb.BackupID] -} - -// NewClientDB initializes a new mock ClientDB. -func NewClientDB() *ClientDB { - return &ClientDB{ - channels: make(map[lnwire.ChannelID]*channel), - activeSessions: make( - map[wtdb.SessionID]wtdb.ClientSession, - ), - ackedUpdates: make(rangeIndexArrayMap), - persistedAckedUpdates: make(rangeIndexKVStore), - committedUpdates: make( - map[wtdb.SessionID][]wtdb.CommittedUpdate, - ), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), - closableSessions: make(map[wtdb.SessionID]uint32), - queues: make(map[string]wtdb.Queue[*wtdb.BackupID]), - } -} - -// CreateTower initialize an address record used to communicate with a -// watchtower. Each Tower is assigned a unique ID, that is used to amortize -// storage costs of the public key when used by multiple sessions. If the tower -// already exists, the address is appended to the list of all addresses used to -// that tower previously and its corresponding sessions are marked as active. -func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var towerPubKey towerPK - copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed()) - - var tower *wtdb.Tower - towerID, ok := m.towerIndex[towerPubKey] - if ok { - tower = m.towers[towerID] - tower.AddAddress(lnAddr.Address) - - towerSessions, err := m.listClientSessions(&towerID) - if err != nil { - return nil, err - } - for id, session := range towerSessions { - session.Status = wtdb.CSessionActive - m.activeSessions[id] = *session - } - } else { - towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) - tower = &wtdb.Tower{ - ID: towerID, - IdentityKey: lnAddr.IdentityKey, - Addresses: []net.Addr{lnAddr.Address}, - } - } - - m.towerIndex[towerPubKey] = towerID - m.towers[towerID] = tower - - return copyTower(tower), nil -} - -// RemoveTower modifies a tower's record within the database. If an address is -// provided, then _only_ the address record should be removed from the tower's -// persisted state. Otherwise, we'll attempt to mark the tower as inactive by -// marking all of its sessions inactive. If any of its sessions has unacked -// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have -// any sessions at all, it'll be completely removed from the database. -// -// NOTE: An error is not returned if the tower doesn't exist. -func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { - m.mu.Lock() - defer m.mu.Unlock() - - tower, err := m.loadTower(pubKey) - if err == wtdb.ErrTowerNotFound { - return nil - } - if err != nil { - return err - } - - if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr - } - m.towers[tower.ID] = tower - return nil - } - - towerSessions, err := m.listClientSessions(&tower.ID) - if err != nil { - return err - } - if len(towerSessions) == 0 { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - delete(m.towerIndex, towerPK) - delete(m.towers, tower.ID) - return nil - } - - for id, session := range towerSessions { - if len(m.committedUpdates[session.ID]) > 0 { - return wtdb.ErrTowerUnackedUpdates - } - session.Status = wtdb.CSessionInactive - m.activeSessions[id] = *session - } - - return nil -} - -// LoadTower retrieves a tower by its public key. -func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - return m.loadTower(pubKey) -} - -// loadTower retrieves a tower by its public key. -// -// NOTE: This method requires the database's lock to be acquired. -func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) { - var towerPK towerPK - copy(towerPK[:], pubKey.SerializeCompressed()) - - towerID, ok := m.towerIndex[towerPK] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - tower, ok := m.towers[towerID] - if !ok { - return nil, wtdb.ErrTowerNotFound - } - - return copyTower(tower), nil -} - -// LoadTowerByID retrieves a tower by its tower ID. -func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - if tower, ok := m.towers[towerID]; ok { - return copyTower(tower), nil - } - - return nil, wtdb.ErrTowerNotFound -} - -// ListTowers retrieves the list of towers available within the database. -func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) { - m.mu.Lock() - defer m.mu.Unlock() - - towers := make([]*wtdb.Tower, 0, len(m.towers)) - for _, tower := range m.towers { - towers = append(towers, copyTower(tower)) - } - - return towers, nil -} - -// MarkBackupIneligible records that particular commit height is ineligible for -// backup. This allows the client to track which updates it should not attempt -// to retry after startup. -func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error { - return nil -} - -// ListClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - return m.listClientSessions(tower, opts...) -} - -// listClientSessions returns the set of all client sessions known to the db. An -// optional tower ID can be used to filter out any client sessions in the -// response that do not correspond to this tower. -func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, - opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, session := range m.activeSessions { - session := session - if tower != nil && *tower != session.TowerID { - continue - } - - if cfg.PreEvaluateFilterFn != nil && - !cfg.PreEvaluateFilterFn(&session) { - - continue - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight( - &session, chanID, index.MaxHeight(), - ) - } - } - - if cfg.PerNumAckedUpdates != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerNumAckedUpdates( - &session, chanID, - uint16(index.NumInSet()), - ) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - if cfg.PostEvaluateFilterFn != nil && - !cfg.PostEvaluateFilterFn(&session) { - - continue - } - - sessions[session.ID] = &session - } - - return sessions, nil -} - -// FetchSessionCommittedUpdates retrieves the current set of un-acked updates -// of the given session. -func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( - []wtdb.CommittedUpdate, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - updates, ok := m.committedUpdates[*id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - return updates, nil -} - -// IsAcked returns true if the given backup has been backed up using the given -// session. -func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, - error) { - - m.mu.Lock() - defer m.mu.Unlock() - - index, ok := m.ackedUpdates[*id][backupID.ChanID] - if !ok { - return false, nil - } - - return index.IsInIndex(backupID.CommitHeight), nil -} - -// NumAckedUpdates returns the number of backups that have been successfully -// backed up using the given session. -func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) { - m.mu.Lock() - defer m.mu.Unlock() - - var numAcked uint64 - - for _, index := range m.ackedUpdates[*id] { - numAcked += index.NumInSet() - } - - return numAcked, nil -} - -// CreateClientSession records a newly negotiated client session in the set of -// active sessions. The session can be identified by its SessionID. -func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { - m.mu.Lock() - defer m.mu.Unlock() - - // Ensure that we aren't overwriting an existing session. - if _, ok := m.activeSessions[session.ID]; ok { - return wtdb.ErrClientSessionAlreadyExists - } - - key := keyIndexKey{ - towerID: session.TowerID, - blobType: session.Policy.BlobType, - } - - // Ensure that a session key index has been reserved for this tower. - keyIndex, err := m.getSessionKeyIndex(key) - if err != nil { - return err - } - - // Ensure that the session's index matches the reserved index. - if keyIndex != session.KeyIndex { - return wtdb.ErrIncorrectKeyIndex - } - - // Remove the key index reservation for this tower. Once committed, this - // permits us to create another session with this tower. - delete(m.indexes, key) - if key.blobType == blob.TypeAltruistCommit { - delete(m.legacyIndexes, key.towerID) - } - - m.activeSessions[session.ID] = wtdb.ClientSession{ - ID: session.ID, - ClientSessionBody: wtdb.ClientSessionBody{ - SeqNum: session.SeqNum, - TowerLastApplied: session.TowerLastApplied, - TowerID: session.TowerID, - KeyIndex: session.KeyIndex, - Policy: session.Policy, - RewardPkScript: cloneBytes(session.RewardPkScript), - }, - } - m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex) - m.persistedAckedUpdates[session.ID] = make( - map[lnwire.ChannelID]*mockKVStore, - ) - m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) - - return nil -} - -// NextSessionKeyIndex reserves a new session key derivation index for a -// particular tower id. The index is reserved for that tower until -// CreateClientSession is invoked for that tower and index, at which point a new -// index for that tower can be reserved. Multiple calls to this method before -// CreateClientSession is invoked should return the same index unless forceNext -// is set to true. -func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, blobType blob.Type, - forceNext bool) (uint32, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - key := keyIndexKey{ - towerID: towerID, - blobType: blobType, - } - - if !forceNext { - if index, err := m.getSessionKeyIndex(key); err == nil { - return index, nil - } - } - - // By default, we use the next available bucket sequence as the key - // index. But if forceNext is true, then it is assumed that some data - // loss occurred and so the sequence is incremented a by a jump of 1000 - // so that we can arrive at a brand new key index quicker. - nextIndex := m.nextIndex + 1 - if forceNext { - nextIndex = m.nextIndex + 1000 - } - m.nextIndex = nextIndex - m.indexes[key] = nextIndex - - return nextIndex, nil -} - -func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) { - if index, ok := m.indexes[key]; ok { - return index, nil - } - - if key.blobType == blob.TypeAltruistCommit { - if index, ok := m.legacyIndexes[key.towerID]; ok { - return index, nil - } - } - - return 0, wtdb.ErrNoReservedKeyIndex -} - -// CommitUpdate persists the CommittedUpdate provided in the slot for (session, -// seqNum). This allows the client to retransmit this update on startup. -func (m *ClientDB) CommitUpdate(id *wtdb.SessionID, - update *wtdb.CommittedUpdate) (uint16, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return 0, wtdb.ErrClientSessionNotFound - } - - // Check if an update has already been committed for this state. - for _, dbUpdate := range m.committedUpdates[session.ID] { - if dbUpdate.SeqNum == update.SeqNum { - // If the breach hint matches, we'll just return the - // last applied value so the client can retransmit. - if dbUpdate.Hint == update.Hint { - return session.TowerLastApplied, nil - } - - // Otherwise, fail since the breach hint doesn't match. - return 0, wtdb.ErrUpdateAlreadyCommitted - } - } - - // Sequence number must increment. - if update.SeqNum != session.SeqNum+1 { - return 0, wtdb.ErrCommitUnorderedUpdate - } - - // Save the update and increment the sequence number. - m.committedUpdates[session.ID] = append( - m.committedUpdates[session.ID], *update, - ) - session.SeqNum++ - m.activeSessions[*id] = session - - return session.TowerLastApplied, nil -} - -// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This -// removes the update from the set of committed updates, and validates the -// lastApplied value returned from the tower. -func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, - lastApplied uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Ensure the returned last applied value does not exceed the highest - // allocated sequence number. - if lastApplied > session.SeqNum { - return wtdb.ErrUnallocatedLastApplied - } - - // Ensure the last applied value isn't lower than a previous one sent by - // the tower. - if lastApplied < session.TowerLastApplied { - return wtdb.ErrLastAppliedReversion - } - - // Retrieve the committed update, failing if none is found. We should - // only receive acks for state updates that we send. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Add sessionID to channel. - channel, ok := m.channels[update.BackupID.ChanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - channel.sessions[*id] = true - - // Remove the committed update from disk and mark the update as - // acked. The tower last applied value is also recorded to send - // along with the next update. - copy(updates[:i], updates[i+1:]) - updates[len(updates)-1] = wtdb.CommittedUpdate{} - m.committedUpdates[session.ID] = updates[:len(updates)-1] - - chanID := update.BackupID.ChanID - if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok { - index, err := wtdb.NewRangeIndex(nil) - if err != nil { - return err - } - - m.ackedUpdates[*id][chanID] = index - m.persistedAckedUpdates[*id][chanID] = newMockKVStore() - } - - err := m.ackedUpdates[*id][chanID].Add( - update.BackupID.CommitHeight, - m.persistedAckedUpdates[*id][chanID], - ) - if err != nil { - return err - } - - session.TowerLastApplied = lastApplied - - m.activeSessions[*id] = session - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// GetDBQueue returns a BackupID Queue instance under the given name space. -func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] { - m.mu.Lock() - defer m.mu.Unlock() - - if q, ok := m.queues[string(namespace)]; ok { - return q - } - - q := NewQueueDB[*wtdb.BackupID]() - m.queues[string(namespace)] = q - - return q -} - -// DeleteCommittedUpdate deletes the committed update with the given sequence -// number from the given session. -func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID, - seqNum uint16) error { - - m.mu.Lock() - defer m.mu.Unlock() - - // Fail if session doesn't exist. - session, ok := m.activeSessions[*id] - if !ok { - return wtdb.ErrClientSessionNotFound - } - - // Retrieve the committed update, failing if none is found. - updates := m.committedUpdates[session.ID] - for i, update := range updates { - if update.SeqNum != seqNum { - continue - } - - // Remove the committed update from "disk". - updates = append(updates[:i], updates[i+1:]...) - m.committedUpdates[session.ID] = updates - - return nil - } - - return wtdb.ErrCommittedUpdateNotFound -} - -// ListClosableSessions fetches and returns the IDs for all sessions marked as -// closable. -func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { - m.mu.Lock() - defer m.mu.Unlock() - - cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) - for id, height := range m.closableSessions { - cs[id] = height - } - - return cs, nil -} - -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) { - m.mu.Lock() - defer m.mu.Unlock() - - summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary) - for chanID, channel := range m.channels { - // Don't load the channel if it has been marked as closed. - if channel.closedHeight > 0 { - continue - } - - summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes( - channel.summary.SweepPkScript, - ), - } - } - - return summaries, nil -} - -// MarkChannelClosed will mark a registered channel as closed by setting -// its closed-height as the given block height. It returns a list of -// session IDs for sessions that are now considered closable due to the -// close of this channel. -func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID, - blockHeight uint32) ([]wtdb.SessionID, error) { - - m.mu.Lock() - defer m.mu.Unlock() - - channel, ok := m.channels[chanID] - if !ok { - return nil, wtdb.ErrChannelNotRegistered - } - - // If there are no sessions for this channel, the channel details can be - // deleted. - if len(channel.sessions) == 0 { - delete(m.channels, chanID) - return nil, nil - } - - // Mark the channel as closed. - channel.closedHeight = blockHeight - - // Now iterate through all the sessions of the channel to check if any - // of them are closeable. - var closableSessions []wtdb.SessionID - for sessID := range channel.sessions { - isClosable, err := m.isSessionClosable(sessID) - if err != nil { - return nil, err - } - - if !isClosable { - continue - } - - closableSessions = append(closableSessions, sessID) - - // Add session to "closableSessions" list and add the block - // height that this last channel was closed in. This will be - // used in future to determine when we should delete the - // session. - m.closableSessions[sessID] = blockHeight - } - - return closableSessions, nil -} - -// isSessionClosable returns true if a session is considered closable. A session -// is considered closable only if: -// 1) It has no un-acked updates -// 2) It is exhausted (ie it cant accept any more updates) -// 3) All the channels that it has acked-updates for are closed. -func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) { - // The session is not closable if it has un-acked updates. - if len(m.committedUpdates[id]) > 0 { - return false, nil - } - - sess, ok := m.activeSessions[id] - if !ok { - return false, wtdb.ErrClientSessionNotFound - } - - // The session is not closable if it is not yet exhausted. - if sess.SeqNum != sess.Policy.MaxUpdates { - return false, nil - } - - // Iterate over each of the channels that the session has acked-updates - // for. If any of those channels are not closed, then the session is - // not yet closable. - for chanID := range m.ackedUpdates[id] { - channel, ok := m.channels[chanID] - if !ok { - continue - } - - // Channel is not yet closed, and so we can not yet delete the - // session. - if channel.closedHeight == 0 { - return false, nil - } - } - - return true, nil -} - -// GetClientSession loads the ClientSession with the given ID from the DB. -func (m *ClientDB) GetClientSession(id wtdb.SessionID, - opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) { - - cfg := wtdb.NewClientSessionCfg() - for _, o := range opts { - o(cfg) - } - - session, ok := m.activeSessions[id] - if !ok { - return nil, wtdb.ErrClientSessionNotFound - } - - if cfg.PerMaxHeight != nil { - for chanID, index := range m.ackedUpdates[session.ID] { - cfg.PerMaxHeight(&session, chanID, index.MaxHeight()) - } - } - - if cfg.PerCommittedUpdate != nil { - for _, update := range m.committedUpdates[session.ID] { - update := update - cfg.PerCommittedUpdate(&session, &update) - } - } - - return &session, nil -} - -// DeleteSession can be called when a session should be deleted from the DB. -// All references to the session will also be deleted from the DB. Note that a -// session will only be deleted if it is considered closable. -func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { - m.mu.Lock() - defer m.mu.Unlock() - - _, ok := m.closableSessions[id] - if !ok { - return wtdb.ErrSessionNotClosable - } - - // For each of the channels, delete the session ID entry. - for chanID := range m.ackedUpdates[id] { - c, ok := m.channels[chanID] - if !ok { - return wtdb.ErrChannelNotRegistered - } - - delete(c.sessions, id) - } - - delete(m.closableSessions, id) - delete(m.activeSessions, id) - - return nil -} - -// RegisterChannel registers a channel for use within the client database. For -// now, all that is stored in the channel summary is the sweep pkscript that -// we'd like any tower sweeps to pay into. In the future, this will be extended -// to contain more info to allow the client efficiently request historical -// states to be backed up under the client's active policy. -func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID, - sweepPkScript []byte) error { - - m.mu.Lock() - defer m.mu.Unlock() - - if _, ok := m.channels[chanID]; ok { - return wtdb.ErrChannelAlreadyRegistered - } - - m.channels[chanID] = &channel{ - summary: &wtdb.ClientChanSummary{ - SweepPkScript: cloneBytes(sweepPkScript), - }, - sessions: make(map[wtdb.SessionID]bool), - } - - return nil -} - -func cloneBytes(b []byte) []byte { - if b == nil { - return nil - } - - bb := make([]byte, len(b)) - copy(bb, b) - - return bb -} - -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ - ID: tower.ID, - IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), - } - copy(t.Addresses, tower.Addresses) - - return t -} - -type mockKVStore struct { - kv map[uint64]uint64 - - err error -} - -func newMockKVStore() *mockKVStore { - return &mockKVStore{ - kv: make(map[uint64]uint64), - } -} - -func (m *mockKVStore) Put(key, value []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - v := byteOrder.Uint64(value) - - m.kv[k] = v - - return nil -} - -func (m *mockKVStore) Delete(key []byte) error { - if m.err != nil { - return m.err - } - - k := byteOrder.Uint64(key) - delete(m.kv, k) - - return nil -}