diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 412412c1e..f9036e87f 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -17,6 +17,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channelnotifier" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" @@ -295,9 +296,8 @@ type TowerClient struct { closableSessionQueue *sessionCloseMinHeap - backupMu sync.Mutex - summaries wtdb.ChannelSummaries - chanCommitHeights map[lnwire.ChannelID]uint64 + backupMu sync.Mutex + chanInfos wtdb.ChannelInfos statTicker *time.Ticker stats *ClientStats @@ -339,9 +339,7 @@ func New(config *Config) (*TowerClient, error) { plog := build.NewPrefixLog(prefix, log) - // Load the sweep pkscripts that have been generated for all previously - // registered channels. - chanSummaries, err := cfg.DB.FetchChanSummaries() + chanInfos, err := cfg.DB.FetchChanInfos() if err != nil { return nil, err } @@ -358,9 +356,8 @@ func New(config *Config) (*TowerClient, error) { cfg: cfg, log: plog, pipeline: queue, - chanCommitHeights: make(map[lnwire.ChannelID]uint64), activeSessions: newSessionQueueSet(), - summaries: chanSummaries, + chanInfos: chanInfos, closableSessionQueue: newSessionCloseMinHeap(), statTicker: time.NewTicker(DefaultStatInterval), stats: new(ClientStats), @@ -369,44 +366,6 @@ func New(config *Config) (*TowerClient, error) { quit: make(chan struct{}), } - // perUpdate is a callback function that will be used to inspect the - // full set of candidate client sessions loaded from disk, and to - // determine the highest known commit height for each channel. This - // allows the client to reject backups that it has already processed for - // its active policy. - perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID, - commitHeight uint64) { - - // We only want to consider accepted updates that have been - // accepted under an identical policy to the client's current - // policy. - if policy != c.cfg.Policy { - return - } - - c.backupMu.Lock() - defer c.backupMu.Unlock() - - // Take the highest commit height found in the session's acked - // updates. - height, ok := c.chanCommitHeights[chanID] - if !ok || commitHeight > height { - c.chanCommitHeights[chanID] = commitHeight - } - } - - perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID, - height uint64) { - - perUpdate(s.Policy, chanID, height) - } - - perCommittedUpdate := func(s *wtdb.ClientSession, - u *wtdb.CommittedUpdate) { - - perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight) - } - candidateTowers := newTowerListIterator() perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is @@ -429,8 +388,6 @@ func New(config *Config) (*TowerClient, error) { candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, perActiveTower, wtdb.WithPreEvalFilterFn(c.genSessionFilter(true)), - wtdb.WithPerMaxHeight(perMaxHeight), - wtdb.WithPerCommittedUpdate(perCommittedUpdate), wtdb.WithPostEvalFilterFn(ExhaustedSessionFilter()), ) if err != nil { @@ -594,7 +551,7 @@ func (c *TowerClient) Start() error { // Iterate over the list of registered channels and check if // any of them can be marked as closed. - for id := range c.summaries { + for id := range c.chanInfos { isClosed, closedHeight, err := c.isChannelClosed(id) if err != nil { returnErr = err @@ -615,7 +572,7 @@ func (c *TowerClient) Start() error { // Since the channel has been marked as closed, we can // also remove it from the channel summaries map. - delete(c.summaries, id) + delete(c.chanInfos, id) } // Load all closable sessions. @@ -732,7 +689,7 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // If a pkscript for this channel already exists, the channel has been // previously registered. - if _, ok := c.summaries[chanID]; ok { + if _, ok := c.chanInfos[chanID]; ok { return nil } @@ -752,8 +709,10 @@ func (c *TowerClient) RegisterChannel(chanID lnwire.ChannelID) error { // Finally, cache the pkscript in our in-memory cache to avoid db // lookups for the remainder of the daemon's execution. - c.summaries[chanID] = wtdb.ClientChanSummary{ - SweepPkScript: pkScript, + c.chanInfos[chanID] = &wtdb.ChannelInfo{ + ClientChanSummary: wtdb.ClientChanSummary{ + SweepPkScript: pkScript, + }, } return nil @@ -770,16 +729,23 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // Make sure that this channel is registered with the tower client. c.backupMu.Lock() - if _, ok := c.summaries[*chanID]; !ok { + info, ok := c.chanInfos[*chanID] + if !ok { c.backupMu.Unlock() return ErrUnregisteredChannel } // Ignore backups that have already been presented to the client. - height, ok := c.chanCommitHeights[*chanID] - if ok && stateNum <= height { + var duplicate bool + info.MaxHeight.WhenSome(func(maxHeight uint64) { + if stateNum <= maxHeight { + duplicate = true + } + }) + if duplicate { c.backupMu.Unlock() + c.log.Debugf("Ignoring duplicate backup for chanid=%v at "+ "height=%d", chanID, stateNum) @@ -789,7 +755,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, // This backup has a higher commit height than any known backup for this // channel. We'll update our tip so that we won't accept it again if the // link flaps. - c.chanCommitHeights[*chanID] = stateNum + c.chanInfos[*chanID].MaxHeight = fn.Some(stateNum) c.backupMu.Unlock() id := &wtdb.BackupID{ @@ -899,7 +865,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, defer c.backupMu.Unlock() // We only care about channels registered with the tower client. - if _, ok := c.summaries[chanID]; !ok { + if _, ok := c.chanInfos[chanID]; !ok { return nil } @@ -924,8 +890,7 @@ func (c *TowerClient) handleClosedChannel(chanID lnwire.ChannelID, return fmt.Errorf("could not track closable sessions: %w", err) } - delete(c.summaries, chanID) - delete(c.chanCommitHeights, chanID) + delete(c.chanInfos, chanID) return nil } @@ -1332,7 +1297,7 @@ func (c *TowerClient) backupDispatcher() { // the prevTask, and should be reprocessed after obtaining a new sessionQueue. func (c *TowerClient) processTask(task *wtdb.BackupID) { c.backupMu.Lock() - summary, ok := c.summaries[task.ChanID] + summary, ok := c.chanInfos[task.ChanID] if !ok { c.backupMu.Unlock() diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 0f7f1b539..e69155288 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -81,10 +81,10 @@ type DB interface { // successfully backed up using the given session. NumAckedUpdates(id *wtdb.SessionID) (uint64, error) - // FetchChanSummaries loads a mapping from all registered channels to - // their channel summaries. Only the channels that have not yet been + // FetchChanInfos loads a mapping from all registered channels to + // their wtdb.ChannelInfo. Only the channels that have not yet been // marked as closed will be loaded. - FetchChanSummaries() (wtdb.ChannelSummaries, error) + FetchChanInfos() (wtdb.ChannelInfos, error) // MarkChannelClosed will mark a registered channel as closed by setting // its closed-height as the given block height. It returns a list of diff --git a/watchtower/wtdb/client_chan_summary.go b/watchtower/wtdb/client_chan_summary.go index d4b3c3c38..6fec34c84 100644 --- a/watchtower/wtdb/client_chan_summary.go +++ b/watchtower/wtdb/client_chan_summary.go @@ -3,11 +3,29 @@ package wtdb import ( "io" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lnwire" ) -// ChannelSummaries is a map for a given channel id to it's ClientChanSummary. -type ChannelSummaries map[lnwire.ChannelID]ClientChanSummary +// ChannelInfos is a map for a given channel id to it's ChannelInfo. +type ChannelInfos map[lnwire.ChannelID]*ChannelInfo + +// ChannelInfo contains various useful things about a registered channel. +// +// NOTE: the reason for adding this struct which wraps ClientChanSummary +// instead of extending ClientChanSummary is for faster look-up of added fields. +// If we were to extend ClientChanSummary instead then we would need to decode +// the entire struct each time we want to read the new fields and then re-encode +// the struct each time we want to write to a new field. +type ChannelInfo struct { + ClientChanSummary + + // MaxHeight is the highest commitment height that the tower has been + // handed for this channel. An Option type is used to store this since + // a commitment height of zero is valid, and we need a way of knowing if + // we have seen a new height yet or not. + MaxHeight fn.Option[uint64] +} // ClientChanSummary tracks channel-specific information. A new // ClientChanSummary is inserted in the database the first time the client diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 5aec017d5..635c6cfa8 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -9,6 +9,7 @@ import ( "sync" "github.com/btcsuite/btcd/btcec/v2" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/tlv" @@ -1308,11 +1309,11 @@ func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { return numAcked, nil } -// FetchChanSummaries loads a mapping from all registered channels to their -// channel summaries. Only the channels that have not yet been marked as closed -// will be loaded. -func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { - var summaries map[lnwire.ChannelID]ClientChanSummary +// FetchChanInfos loads a mapping from all registered channels to their +// ChannelInfo. Only the channels that have not yet been marked as closed will +// be loaded. +func (c *ClientDB) FetchChanInfos() (ChannelInfos, error) { + var infos ChannelInfos err := kvdb.View(c.db, func(tx kvdb.RTx) error { chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) @@ -1325,34 +1326,47 @@ func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { if chanDetails == nil { return ErrCorruptChanDetails } - // If this channel has already been marked as closed, // then its summary does not need to be loaded. closedHeight := chanDetails.Get(cChanClosedHeight) if len(closedHeight) > 0 { return nil } - var chanID lnwire.ChannelID copy(chanID[:], k) - summary, err := getChanSummary(chanDetails) if err != nil { return err } - summaries[chanID] = *summary + info := &ChannelInfo{ + ClientChanSummary: *summary, + } + + maxHeightBytes := chanDetails.Get( + cChanMaxCommitmentHeight, + ) + if len(maxHeightBytes) != 0 { + height, err := readBigSize(maxHeightBytes) + if err != nil { + return err + } + + info.MaxHeight = fn.Some(height) + } + + infos[chanID] = info return nil }) }, func() { - summaries = make(map[lnwire.ChannelID]ClientChanSummary) + infos = make(ChannelInfos) }) if err != nil { return nil, err } - return summaries, nil + return infos, nil } // RegisterChannel registers a channel for use within the client database. For diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 6d11a6972..36cc049a9 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -156,13 +156,13 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, return tower } -func (h *clientDBHarness) fetchChanSummaries() map[lnwire.ChannelID]wtdb.ClientChanSummary { +func (h *clientDBHarness) fetchChanInfos() wtdb.ChannelInfos { h.t.Helper() - summaries, err := h.db.FetchChanSummaries() + infos, err := h.db.FetchChanInfos() require.NoError(h.t, err) - return summaries + return infos } func (h *clientDBHarness) registerChan(chanID lnwire.ChannelID, @@ -552,7 +552,7 @@ func testRemoveTower(h *clientDBHarness) { func testChanSummaries(h *clientDBHarness) { // First, assert that this channel is not already registered. var chanID lnwire.ChannelID - _, ok := h.fetchChanSummaries()[chanID] + _, ok := h.fetchChanInfos()[chanID] require.Falsef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) @@ -565,7 +565,7 @@ func testChanSummaries(h *clientDBHarness) { // Assert that the channel exists and that its sweep pkscript matches // the one we registered. - summary, ok := h.fetchChanSummaries()[chanID] + summary, ok := h.fetchChanInfos()[chanID] require.Truef(h.t, ok, "pkscript for channel %x should not exist yet", chanID) require.Equal(h.t, expPkScript, summary.SweepPkScript) @@ -767,6 +767,58 @@ func testRogueUpdates(h *clientDBHarness) { require.Len(h.t, closableSessionsMap, 1) } +// testMaxCommitmentHeights tests that the max known commitment height of a +// channel is properly persisted. +func testMaxCommitmentHeights(h *clientDBHarness) { + const maxUpdates = 5 + t := h.t + + // Initially, we expect no channels. + infos := h.fetchChanInfos() + require.Empty(t, infos) + + // Create a new tower. + tower := h.newTower() + + // Create and insert a new session. + session1 := h.randSession(t, tower.ID, maxUpdates) + h.insertSession(session1, nil) + + // Create a new channel and register it. + chanID1 := randChannelID(t) + h.registerChan(chanID1, nil, nil) + + // At this point, we expect one channel to be returned from + // fetchChanInfos but with an unset max height. + infos = h.fetchChanInfos() + require.Len(t, infos, 1) + + info, ok := infos[chanID1] + require.True(t, ok) + require.True(t, info.MaxHeight.IsNone()) + + // Commit and ACK some updates for this channel. + for i := 1; i <= maxUpdates; i++ { + update := randCommittedUpdateForChanWithHeight( + t, chanID1, uint16(i), uint64(i-1), + ) + lastApplied := h.commitUpdate(&session1.ID, update, nil) + h.ackUpdate(&session1.ID, uint16(i), lastApplied, nil) + } + + // Assert that the max height has now been set accordingly for this + // channel. + infos = h.fetchChanInfos() + require.Len(t, infos, 1) + + info, ok = infos[chanID1] + require.True(t, ok) + require.True(t, info.MaxHeight.IsSome()) + info.MaxHeight.WhenSome(func(u uint64) { + require.EqualValues(t, maxUpdates-1, u) + }) +} + // testMarkChannelClosed asserts the behaviour of MarkChannelClosed. func testMarkChannelClosed(h *clientDBHarness) { tower := h.newTower() @@ -1097,6 +1149,10 @@ func TestClientDB(t *testing.T) { name: "rogue updates", run: testRogueUpdates, }, + { + name: "max commitment heights", + run: testMaxCommitmentHeights, + }, } for _, database := range dbs { diff --git a/watchtower/wtdb/queue_test.go b/watchtower/wtdb/queue_test.go index 02c7b272c..ff2c5a0da 100644 --- a/watchtower/wtdb/queue_test.go +++ b/watchtower/wtdb/queue_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/lightningnetwork/lnd/kvdb" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/stretchr/testify/require" ) @@ -31,6 +32,14 @@ func TestDiskQueue(t *testing.T) { require.NoError(t, err) }) + // In order to test that the queue's `onItemWrite` call back (which in + // this case will be set to maybeUpdateMaxCommitHeight) is executed as + // expected, we need to register a channel so that we can later assert + // that it's max height field was updated properly. + var chanID lnwire.ChannelID + err = db.RegisterChannel(chanID, []byte{}) + require.NoError(t, err) + namespace := []byte("test-namespace") queue := db.GetDBQueue(namespace) @@ -110,4 +119,19 @@ func TestDiskQueue(t *testing.T) { // This should not have changed the order of the tasks, they should // still appear in the correct order. popAndAssert(task1, task2, task3, task4, task5, task6) + + // Finally, we check that the `onItemWrite` call back was executed by + // the queue. We do this by checking that the channel's recorded max + // commitment height was set correctly. It should be equal to the height + // recorded in task6. + infos, err := db.FetchChanInfos() + require.NoError(t, err) + require.Len(t, infos, 1) + + info, ok := infos[chanID] + require.True(t, ok) + require.True(t, info.MaxHeight.IsSome()) + info.MaxHeight.WhenSome(func(height uint64) { + require.EqualValues(t, task6.CommitHeight, height) + }) }