diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index e5fc5d22b..b5d2418ab 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -95,6 +95,10 @@ type DB interface { MarkChannelClosed(chanID lnwire.ChannelID, blockHeight uint32) ( []wtdb.SessionID, error) + // ListClosableSessions fetches and returns the IDs for all sessions + // marked as closable. + ListClosableSessions() (map[wtdb.SessionID]uint32, error) + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index d88fd631e..4a74c7bb4 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -1385,6 +1385,46 @@ func (c *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, return nil } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) { + sessions := make(map[SessionID]uint32) + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + csBkt := tx.ReadBucket(cClosableSessionsBkt) + if csBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + return csBkt.ForEach(func(dbIDBytes, heightBytes []byte) error { + dbID, err := readBigSize(dbIDBytes) + if err != nil { + return err + } + + sessID, err := getRealSessionID(sessIDIndexBkt, dbID) + if err != nil { + return err + } + + sessions[*sessID] = byteOrder.Uint32(heightBytes) + + return nil + }) + }, func() { + sessions = make(map[SessionID]uint32) + }) + if err != nil { + return nil, err + } + + return sessions, nil +} + // MarkChannelClosed will mark a registered channel as closed by setting its // closed-height as the given block height. It returns a list of session IDs for // sessions that are now considered closable due to the close of this channel. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 4f5f80749..73f4e5550 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -207,6 +207,17 @@ func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, return closableSessions } +func (h *clientDBHarness) listClosableSessions( + expErr error) map[wtdb.SessionID]uint32 { + + h.t.Helper() + + closableSessions, err := h.db.ListClosableSessions() + require.ErrorIs(h.t, err, expErr) + + return closableSessions +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -711,11 +722,16 @@ func testMarkChannelClosed(h *clientDBHarness) { // since it has an update for channel 6 which is still open. sl = h.markChannelClosed(chanID5, 1, nil) require.Empty(h.t, sl) + require.Empty(h.t, h.listClosableSessions(nil)) // Finally, if we close channel 6, session 1 _should_ be in the closable // list. - sl = h.markChannelClosed(chanID6, 1, nil) + sl = h.markChannelClosed(chanID6, 100, nil) require.ElementsMatch(h.t, sl, []wtdb.SessionID{session1.ID}) + slMap := h.listClosableSessions(nil) + require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{ + session1.ID: 100, + }, 0) } // testAckUpdate asserts the behavior of AckUpdate. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 2820d74cd..f439e9182 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -551,6 +551,20 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, return wtdb.ErrCommittedUpdateNotFound } +// ListClosableSessions fetches and returns the IDs for all sessions marked as +// closable. +func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) { + m.mu.Lock() + defer m.mu.Unlock() + + cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions)) + for id, height := range m.closableSessions { + cs[id] = height + } + + return cs, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {