diff --git a/lnrpc/wtclientrpc/wtclient.go b/lnrpc/wtclientrpc/wtclient.go index 1d9d41849..6672c0c75 100644 --- a/lnrpc/wtclientrpc/wtclient.go +++ b/lnrpc/wtclientrpc/wtclient.go @@ -341,27 +341,27 @@ func constructFunctionalOptions(includeSessions bool) ( var ( opts []wtdb.ClientSessionListOption - ackCounts = make(map[wtdb.SessionID]uint16) committedUpdateCounts = make(map[wtdb.SessionID]uint16) + ackCounts = make(map[wtdb.SessionID]uint16) ) if !includeSessions { return opts, ackCounts, committedUpdateCounts } - perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, - _ wtdb.BackupID) { + perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID, + numUpdates uint16) { - ackCounts[s.ID]++ + ackCounts[s.ID] += numUpdates } perCommittedUpdate := func(s *wtdb.ClientSession, - _ *wtdb.CommittedUpdate) { + u *wtdb.CommittedUpdate) { committedUpdateCounts[s.ID]++ } opts = []wtdb.ClientSessionListOption{ - wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerNumAckedUpdates(perNumAckedUpdates), wtdb.WithPerCommittedUpdate(perCommittedUpdate), } @@ -438,7 +438,8 @@ func (c *WatchtowerClient) Policy(ctx context.Context, // marshallTower converts a client registered watchtower into its corresponding // RPC type. func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool, - ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower { + ackCounts map[wtdb.SessionID]uint16, + pendingCounts map[wtdb.SessionID]uint16) *Tower { rpcAddrs := make([]string, 0, len(tower.Addresses)) for _, addr := range tower.Addresses { diff --git a/watchtower/log.go b/watchtower/log.go index 8ce96131e..8e9062d95 100644 --- a/watchtower/log.go +++ b/watchtower/log.go @@ -4,6 +4,8 @@ import ( "github.com/btcsuite/btclog" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/watchtower/lookout" + "github.com/lightningnetwork/lnd/watchtower/wtclient" + "github.com/lightningnetwork/lnd/watchtower/wtdb" "github.com/lightningnetwork/lnd/watchtower/wtserver" ) @@ -30,4 +32,6 @@ func UseLogger(logger btclog.Logger) { log = logger lookout.UseLogger(logger) wtserver.UseLogger(logger) + wtclient.UseLogger(logger) + wtdb.UseLogger(logger) } diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index 3aa84f28c..685170eab 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -314,7 +314,9 @@ func New(config *Config) (*TowerClient, error) { // determine the highest known commit height for each channel. This // allows the client to reject backups that it has already processed for // its active policy. - perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) { + perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID, + commitHeight uint64) { + // We only want to consider accepted updates that have been // accepted under an identical policy to the client's current // policy. @@ -324,22 +326,22 @@ func New(config *Config) (*TowerClient, error) { // Take the highest commit height found in the session's acked // updates. - height, ok := c.chanCommitHeights[id.ChanID] - if !ok || id.CommitHeight > height { - c.chanCommitHeights[id.ChanID] = id.CommitHeight + height, ok := c.chanCommitHeights[chanID] + if !ok || commitHeight > height { + c.chanCommitHeights[chanID] = commitHeight } } - perAckedUpdate := func(s *wtdb.ClientSession, _ uint16, - id wtdb.BackupID) { + perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID, + height uint64) { - perUpdate(s.Policy, id) + perUpdate(s.Policy, chanID, height) } perCommittedUpdate := func(s *wtdb.ClientSession, u *wtdb.CommittedUpdate) { - perUpdate(s.Policy, u.BackupID) + perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight) } // Load all candidate sessions and towers from the database into the @@ -366,7 +368,7 @@ func New(config *Config) (*TowerClient, error) { candidateSessions, err := getTowerAndSessionCandidates( cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, - wtdb.WithPerAckedUpdate(perAckedUpdate), + wtdb.WithPerMaxHeight(perMaxHeight), wtdb.WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index ba6546328..63d718b60 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -68,6 +68,14 @@ type DB interface { FetchSessionCommittedUpdates(id *wtdb.SessionID) ( []wtdb.CommittedUpdate, error) + // IsAcked returns true if the given backup has been backed up using + // the given session. + IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, error) + + // NumAckedUpdates returns the number of backups that have been + // successfully backed up using the given session. + NumAckedUpdates(id *wtdb.SessionID) (uint64, error) + // FetchChanSummaries loads a mapping from all registered channels to // their channel summaries. FetchChanSummaries() (wtdb.ChannelSummaries, error) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 3b256ac0c..f8eda37bf 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -36,7 +36,7 @@ var ( // cSessionBkt is a top-level bucket storing: // session-id => cSessionBody -> encoded ClientSessionBody // => cSessionCommits => seqnum -> encoded CommittedUpdate - // => cSessionAcks => seqnum -> encoded BackupID + // => cSessionAckRangeIndex => db-chan-id => start -> end cSessionBkt = []byte("client-session-bucket") // cSessionBody is a sub-bucket of cSessionBkt storing only the body of @@ -47,9 +47,9 @@ var ( // seqnum -> encoded CommittedUpdate. cSessionCommits = []byte("client-session-commits") - // cSessionAcks is a sub-bucket of cSessionBkt storing: - // seqnum -> encoded BackupID. - cSessionAcks = []byte("client-session-acks") + // cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing + // chan-id => start -> end + cSessionAckRangeIndex = []byte("client-session-ack-range-index") // cChanIDIndexBkt is a top-level bucket storing: // db-assigned-id -> channel-ID @@ -422,6 +422,11 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { return ErrUninitializedDB } + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + // Don't return an error if the watchtower doesn't exist to act // as a NOP. pubKeyBytes := pubKey.SerializeCompressed() @@ -463,7 +468,8 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } towerSessions, err := c.listTowerSessions( - towerID, sessions, towersToSessionsIndex, + towerID, sessions, chanIDIndexBkt, + towersToSessionsIndex, WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -763,6 +769,149 @@ func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) { return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize)) } +// getRangeIndex checks the ClientDB's in-memory range index map to see if it +// has an entry for the given session and channel ID. If it does, this is +// returned, otherwise the range index is loaded from the DB. An optional db +// transaction parameter may be provided. If one is provided then it will be +// used to query the DB for the range index, otherwise, a new transaction will +// be created and used. +func (c *ClientDB) getRangeIndex(tx kvdb.RTx, sID SessionID, + chanID lnwire.ChannelID) (*RangeIndex, error) { + + c.ackedRangeIndexMu.Lock() + defer c.ackedRangeIndexMu.Unlock() + + if _, ok := c.ackedRangeIndex[sID]; !ok { + c.ackedRangeIndex[sID] = make(map[lnwire.ChannelID]*RangeIndex) + } + + // If the in-memory range-index map already includes an entry for this + // session ID and channel ID pair, then return it. + if index, ok := c.ackedRangeIndex[sID][chanID]; ok { + return index, nil + } + + // readRangeIndexFromBkt is a helper that is used to read in a + // RangeIndex structure from the passed in bucket and store it in the + // ackedRangeIndex map. + readRangeIndexFromBkt := func(rangesBkt kvdb.RBucket) (*RangeIndex, + error) { + + // Create a new in-memory RangeIndex by reading in ranges from + // the DB. + rangeIndex, err := readRangeIndex(rangesBkt) + if err != nil { + return nil, err + } + + c.ackedRangeIndex[sID][chanID] = rangeIndex + + return rangeIndex, nil + } + + // If a DB transaction is provided then use it to fetch the ranges + // bucket from the DB. + if tx != nil { + rangesBkt, err := getRangesReadBucket(tx, sID, chanID) + if err != nil { + return nil, err + } + + return readRangeIndexFromBkt(rangesBkt) + } + + // No DB transaction was provided. So create and use a new one. + var index *RangeIndex + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + rangesBkt, err := getRangesReadBucket(tx, sID, chanID) + if err != nil { + return err + } + + index, err = readRangeIndexFromBkt(rangesBkt) + + return err + }, func() {}) + if err != nil { + return nil, err + } + + return index, nil +} + +// getRangesReadBucket gets the range index bucket where the range index for the +// given session-channel pair is stored. If any sub-buckets along the way do not +// exist, then an error is returned. If the sub-buckets should be created +// instead, then use getRangesWriteBucket. +func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) ( + kvdb.RBucket, error) { + + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return nil, ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return nil, ErrUninitializedDB + } + + sessionBkt := sessions.NestedReadBucket(sID[:]) + if sessionsBkt == nil { + return nil, ErrNoRangeIndexFound + } + + // Get the DB representation of the channel-ID. + _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) + if err != nil { + return nil, err + } + + sessionAckRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + if sessionAckRanges == nil { + return nil, ErrNoRangeIndexFound + } + + return sessionAckRanges.NestedReadBucket(dbChanIDBytes), nil +} + +// getRangesWriteBucket gets the range index bucket where the range index for +// the given session-channel pair is stored. If any sub-buckets along the way do +// not exist, then they are created. +func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID, + chanID lnwire.ChannelID) (kvdb.RwBucket, error) { + + sessions := tx.ReadWriteBucket(cSessionBkt) + if sessions == nil { + return nil, ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return nil, ErrUninitializedDB + } + + sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:]) + if err != nil { + return nil, err + } + + // Get the DB representation of the channel-ID. + _, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID) + if err != nil { + return nil, err + } + + sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists( + cSessionAckRangeIndex, + ) + if err != nil { + return nil, err + } + + return sessionAckRanges.CreateBucketIfNotExists(dbChanIDBytes) +} + // createSessionKeyIndexKey returns the identifier used in the // session-key-index index, created as tower-id||blob-type. // @@ -825,13 +974,18 @@ func (c *ClientDB) ListClientSessions(id *TowerID, return ErrUninitializedDB } + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + var err error // If no tower ID is specified, then fetch all the sessions // known to the db. if id == nil { clientSessions, err = c.listClientAllSessions( - sessions, opts..., + sessions, chanIDIndexBkt, opts..., ) return err } @@ -843,7 +997,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } clientSessions, err = c.listTowerSessions( - *id, sessions, towerToSessionIndex, opts..., + *id, sessions, chanIDIndexBkt, towerToSessionIndex, + opts..., ) return err }, func() { @@ -857,7 +1012,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } // listClientAllSessions returns the set of all client sessions known to the db. -func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket, +func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket, opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -866,7 +1021,9 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := c.getClientSession(sessions, k, opts...) + session, err := c.getClientSession( + sessions, chanIDIndexBkt, k, opts..., + ) if err != nil { return err } @@ -884,7 +1041,7 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. -func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, +func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( map[SessionID]*ClientSession, error) { @@ -899,7 +1056,9 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := c.getClientSession(sessionsBkt, k, opts...) + session, err := c.getClientSession( + sessionsBkt, chanIDIndexBkt, k, opts..., + ) if err != nil { return err } @@ -944,6 +1103,73 @@ func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) ( return committedUpdates, nil } +// IsAcked returns true if the given backup has been backed up using the given +// session. +func (c *ClientDB) IsAcked(id *SessionID, backupID *BackupID) (bool, error) { + index, err := c.getRangeIndex(nil, *id, backupID.ChanID) + if errors.Is(err, ErrNoRangeIndexFound) { + return false, nil + } else if err != nil { + return false, err + } + + return index.IsInIndex(backupID.CommitHeight), nil +} + +// NumAckedUpdates returns the number of backups that have been successfully +// backed up using the given session. +func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) { + var numAcked uint64 + err := kvdb.View(c.db, func(tx kvdb.RTx) error { + sessions := tx.ReadBucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + sessionBkt := sessions.NestedReadBucket(id[:]) + if sessionsBkt == nil { + return nil + } + + sessionAckRanges := sessionBkt.NestedReadBucket( + cSessionAckRangeIndex, + ) + if sessionAckRanges == nil { + return nil + } + + // Iterate over the channel ID's in the sessionAckRanges + // bucket. + return sessionAckRanges.ForEach(func(dbChanID, _ []byte) error { + // Get the range index for the session-channel pair. + chanIDBytes := chanIDIndexBkt.Get(dbChanID) + var chanID lnwire.ChannelID + copy(chanID[:], chanIDBytes) + + index, err := c.getRangeIndex(tx, *id, chanID) + if err != nil { + return err + } + + numAcked += index.NumInSet() + + return nil + }) + }, func() { + numAcked = 0 + }) + if err != nil { + return 0, err + } + + return numAcked, nil +} + // FetchChanSummaries loads a mapping from all registered channels to their // channel summaries. func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { @@ -1174,6 +1400,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return ErrUninitializedDB } + chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + // We'll only load the ClientSession body for performance, since // we primarily need to inspect its SeqNum and TowerLastApplied // fields. The CommittedUpdates and AckedUpdates will be @@ -1242,25 +1473,24 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16, return err } - // Ensure that the session acks sub-bucket is initialized, so we - // can insert an entry. - sessionAcks, err := sessionBkt.CreateBucketIfNotExists( - cSessionAcks, - ) + chanID := committedUpdate.BackupID.ChanID + height := committedUpdate.BackupID.CommitHeight + + // Get the ranges write bucket before getting the range index to + // ensure that the session acks sub-bucket is initialized, so + // that we can insert an entry. + rangesBkt, err := getRangesWriteBucket(tx, *id, chanID) if err != nil { return err } - // The session acks only need to track the backup id of the - // update, so we can discard the blob and hint. - var b bytes.Buffer - err = committedUpdate.BackupID.Encode(&b) + // Get the range index for the given session-channel pair. + index, err := c.getRangeIndex(tx, *id, chanID) if err != nil { return err } - // Finally, insert the ack into the sessionAcks sub-bucket. - return sessionAcks.Put(seqNumBuf[:], b.Bytes()) + return index.Add(height, rangesBkt) }, func() {}) } @@ -1293,9 +1523,15 @@ func getClientSessionBody(sessions kvdb.RBucket, return &session, nil } -// PerAckedUpdateCB describes the signature of a callback function that can be -// called for each of a session's acked updates. -type PerAckedUpdateCB func(*ClientSession, uint16, BackupID) +// PerMaxHeightCB describes the signature of a callback function that can be +// called for each channel that a session has updates for to communicate the +// maximum commitment height that the session has backed up for the channel. +type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64) + +// PerNumAckedUpdatesCB describes the signature of a callback function that can +// be called for each channel that a session has updates for to communicate the +// number of updates that the session has for the channel. +type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16) // PerCommittedUpdateCB describes the signature of a callback function that can // be called for each of a session's committed updates (updates that the client @@ -1310,9 +1546,15 @@ type ClientSessionListOption func(cfg *ClientSessionListCfg) // ClientSessionListCfg defines various query parameters that will be used when // querying the DB for client sessions. type ClientSessionListCfg struct { - // PerAckedUpdate will, if set, be called for each of the session's - // acked updates. - PerAckedUpdate PerAckedUpdateCB + // PerNumAckedUpdates will, if set, be called for each of the session's + // channels to communicate the number of updates stored for that + // channel. + PerNumAckedUpdates PerNumAckedUpdatesCB + + // PerMaxHeight will, if set, be called for each of the session's + // channels to communicate the highest commit height of updates stored + // for that channel. + PerMaxHeight PerMaxHeightCB // PerCommittedUpdate will, if set, be called for each of the session's // committed (un-acked) updates. @@ -1324,11 +1566,22 @@ func NewClientSessionCfg() *ClientSessionListCfg { return &ClientSessionListCfg{} } -// WithPerAckedUpdate constructs a functional option that will set a call-back -// function to be called for each of a client's acked updates. -func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption { +// WithPerMaxHeight constructs a functional option that will set a call-back +// function to be called for each of a session's channels to communicate the +// maximum commitment height that the session has stored for the channel. +func WithPerMaxHeight(cb PerMaxHeightCB) ClientSessionListOption { return func(cfg *ClientSessionListCfg) { - cfg.PerAckedUpdate = cb + cfg.PerMaxHeight = cb + } +} + +// WithPerNumAckedUpdates constructs a functional option that will set a +// call-back function to be called for each of a session's channels to +// communicate the number of updates that the session has stored for the +// channel. +func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption { + return func(cfg *ClientSessionListCfg) { + cfg.PerNumAckedUpdates = cb } } @@ -1343,21 +1596,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte, - opts ...ClientSessionListOption) (*ClientSession, error) { +func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket, + idBytes []byte, opts ...ClientSessionListOption) (*ClientSession, + error) { cfg := NewClientSessionCfg() for _, o := range opts { o(cfg) } - session, err := getClientSessionBody(sessions, idBytes) + session, err := getClientSessionBody(sessionsBkt, idBytes) if err != nil { return nil, err } // Can't fail because client session body has already been read. - sessionBkt := sessions.NestedReadBucket(idBytes) + sessionBkt := sessionsBkt.NestedReadBucket(idBytes) // Pass the session's committed (un-acked) updates through the call-back // if one is provided. @@ -1370,7 +1624,10 @@ func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte, // Pass the session's acked updates through the call-back if one is // provided. - err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate) + err = c.filterClientSessionAcks( + sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight, + cfg.PerNumAckedUpdates, + ) if err != nil { return nil, err } @@ -1419,35 +1676,43 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession, // filterClientSessionAcks retrieves all acked updates for the session // identified by the serialized session id and passes them to the provided // call back if one is provided. -func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, - cb PerAckedUpdateCB) error { +func (c *ClientDB) filterClientSessionAcks(sessionBkt, + chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB, + perNumAckedUpdates PerNumAckedUpdatesCB) error { - if cb == nil { + if perMaxCb == nil && perNumAckedUpdates == nil { return nil } - sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) - if sessionAcks == nil { + sessionAcksRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + if sessionAcksRanges == nil { return nil } - err := sessionAcks.ForEach(func(k, v []byte) error { - seqNum := byteOrder.Uint16(k) + return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error { + rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID) + if rangeBkt == nil { + return nil + } - var backupID BackupID - err := backupID.Decode(bytes.NewReader(v)) + index, err := readRangeIndex(rangeBkt) if err != nil { return err } - cb(s, seqNum, backupID) + chanIDBytes := chanIDIndexBkt.Get(dbChanID) + var chanID lnwire.ChannelID + copy(chanID[:], chanIDBytes) + + if perMaxCb != nil { + perMaxCb(s, chanID, index.MaxHeight()) + } + + if perNumAckedUpdates != nil { + perNumAckedUpdates(s, chanID, uint16(index.NumInSet())) + } return nil }) - if err != nil { - return err - } - - return nil } // filterClientSessionCommits retrieves all committed updates for the session diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index f75a0c2bc..2cafc160b 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -221,6 +221,26 @@ func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID, return updates } +func (h *clientDBHarness) isAcked(id *wtdb.SessionID, backupID *wtdb.BackupID, + expErr error) bool { + + h.t.Helper() + + isAcked, err := h.db.IsAcked(id, backupID) + require.ErrorIs(h.t, err, expErr) + + return isAcked +} + +func (h *clientDBHarness) numAcked(id *wtdb.SessionID, expErr error) uint64 { + h.t.Helper() + + numAcked, err := h.db.NumAckedUpdates(id) + require.ErrorIs(h.t, err, expErr) + + return numAcked +} + // testCreateClientSession asserts various conditions regarding the creation of // a new ClientSession. The test asserts: // - client sessions can only be created if a session key index is reserved. @@ -453,6 +473,7 @@ func testRemoveTower(h *clientDBHarness) { } h.insertSession(session, nil) update := randCommittedUpdate(h.t, 1) + h.registerChan(update.BackupID.ChanID, nil, nil) h.commitUpdate(&session.ID, update, nil) // We should not be able to fully remove it from the database since @@ -583,16 +604,6 @@ func testCommitUpdate(h *clientDBHarness) { }, nil) } -func perAckedUpdate(updates map[uint16]wtdb.BackupID) func( - _ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) { - - return func(_ *wtdb.ClientSession, seq uint16, - id wtdb.BackupID) { - - updates[seq] = id - } -} - // testAckUpdate asserts the behavior of AckUpdate. func testAckUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit @@ -628,6 +639,8 @@ func testAckUpdate(h *clientDBHarness) { // Commit to a random update at seqnum 1. update1 := randCommittedUpdate(h.t, 1) + + h.registerChan(update1.BackupID.ChanID, nil, nil) lastApplied := h.commitUpdate(&session.ID, update1, nil) require.Zero(h.t, lastApplied) @@ -654,6 +667,7 @@ func testAckUpdate(h *clientDBHarness) { // value is 1, since this was what was provided in the last successful // ack. update2 := randCommittedUpdate(h.t, 2) + h.registerChan(update2.BackupID.ChanID, nil, nil) lastApplied = h.commitUpdate(&session.ID, update2, nil) require.EqualValues(h.t, 1, lastApplied) @@ -681,13 +695,16 @@ func (h *clientDBHarness) assertUpdates(id wtdb.SessionID, expectedPending []wtdb.CommittedUpdate, expectedAcked map[uint16]wtdb.BackupID) { - ackedUpdates := make(map[uint16]wtdb.BackupID) - _ = h.listSessions( - nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)), - ) - committedUpates := h.fetchSessionCommittedUpdates(&id, nil) - checkCommittedUpdates(h.t, committedUpates, expectedPending) - checkAckedUpdates(h.t, ackedUpdates, expectedAcked) + committedUpdates := h.fetchSessionCommittedUpdates(&id, nil) + checkCommittedUpdates(h.t, committedUpdates, expectedPending) + + // Check acked updates. + numAcked := h.numAcked(&id, nil) + require.EqualValues(h.t, len(expectedAcked), numAcked) + for _, backupID := range expectedAcked { + isAcked := h.isAcked(&id, &backupID, nil) + require.True(h.t, isAcked) + } } // checkCommittedUpdates asserts that the CommittedUpdates on session match the @@ -707,21 +724,6 @@ func checkCommittedUpdates(t *testing.T, actualUpdates, require.Equal(t, expUpdates, actualUpdates) } -// checkAckedUpdates asserts that the AckedUpdates on a session match the -// expUpdates provided. -func checkAckedUpdates(t *testing.T, actualUpdates, - expUpdates map[uint16]wtdb.BackupID) { - - // We promote nil expUpdates to an initialized map since the database - // should never return a nil map. This promotion is done purely out of - // convenience for the testing framework. - if expUpdates == nil { - expUpdates = make(map[uint16]wtdb.BackupID) - } - - require.Equal(t, expUpdates, actualUpdates) -} - // TestClientDB asserts the behavior of a fresh client db, a reopened client db, // and the mock implementation. This ensures that all databases function // identically, especially in the negative paths. diff --git a/watchtower/wtdb/log.go b/watchtower/wtdb/log.go index 63be43d28..134cee4a7 100644 --- a/watchtower/wtdb/log.go +++ b/watchtower/wtdb/log.go @@ -6,6 +6,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration2" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" ) // log is a logger that is initialized with no output filters. This @@ -32,6 +33,7 @@ func UseLogger(logger btclog.Logger) { migration1.UseLogger(logger) migration2.UseLogger(logger) migration3.UseLogger(logger) + migration4.UseLogger(logger) } // logClosure is used to provide a closure over expensive logging operations so diff --git a/watchtower/wtdb/version.go b/watchtower/wtdb/version.go index fd3f5f762..ebc09fb8d 100644 --- a/watchtower/wtdb/version.go +++ b/watchtower/wtdb/version.go @@ -8,6 +8,7 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb/migration1" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration2" "github.com/lightningnetwork/lnd/watchtower/wtdb/migration3" + "github.com/lightningnetwork/lnd/watchtower/wtdb/migration4" ) // txMigration is a function which takes a prior outdated version of the @@ -49,6 +50,11 @@ var clientDBVersions = []version{ { txMigration: migration3.MigrateChannelIDIndex, }, + { + dbMigration: migration4.MigrateAckedUpdates( + migration4.DefaultSessionsPerTx, + ), + }, } // getLatestDBVersion returns the last known database version. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index b12fe2780..095e8cbac 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -1,6 +1,7 @@ package wtmock import ( + "encoding/binary" "net" "sync" "sync/atomic" @@ -11,6 +12,8 @@ import ( "github.com/lightningnetwork/lnd/watchtower/wtdb" ) +var byteOrder = binary.BigEndian + type towerPK [33]byte type keyIndexKey struct { @@ -18,18 +21,23 @@ type keyIndexKey struct { blobType blob.Type } +type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex + +type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore + // ClientDB is a mock, in-memory database or testing the watchtower client // behavior. type ClientDB struct { nextTowerID uint64 // to be used atomically - mu sync.Mutex - summaries map[lnwire.ChannelID]wtdb.ClientChanSummary - activeSessions map[wtdb.SessionID]wtdb.ClientSession - ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID - committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate - towerIndex map[towerPK]wtdb.TowerID - towers map[wtdb.TowerID]*wtdb.Tower + mu sync.Mutex + summaries map[lnwire.ChannelID]wtdb.ClientChanSummary + activeSessions map[wtdb.SessionID]wtdb.ClientSession + ackedUpdates rangeIndexArrayMap + persistedAckedUpdates rangeIndexKVStore + committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate + towerIndex map[towerPK]wtdb.TowerID + towers map[wtdb.TowerID]*wtdb.Tower nextIndex uint32 indexes map[keyIndexKey]uint32 @@ -39,14 +47,21 @@ type ClientDB struct { // NewClientDB initializes a new mock ClientDB. func NewClientDB() *ClientDB { return &ClientDB{ - summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), - activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), - ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID), - committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate), - towerIndex: make(map[towerPK]wtdb.TowerID), - towers: make(map[wtdb.TowerID]*wtdb.Tower), - indexes: make(map[keyIndexKey]uint32), - legacyIndexes: make(map[wtdb.TowerID]uint32), + summaries: make( + map[lnwire.ChannelID]wtdb.ClientChanSummary, + ), + activeSessions: make( + map[wtdb.SessionID]wtdb.ClientSession, + ), + ackedUpdates: make(rangeIndexArrayMap), + persistedAckedUpdates: make(rangeIndexKVStore), + committedUpdates: make( + map[wtdb.SessionID][]wtdb.CommittedUpdate, + ), + towerIndex: make(map[towerPK]wtdb.TowerID), + towers: make(map[wtdb.TowerID]*wtdb.Tower), + indexes: make(map[keyIndexKey]uint32), + legacyIndexes: make(map[wtdb.TowerID]uint32), } } @@ -233,9 +248,20 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, } sessions[session.ID] = &session - if cfg.PerAckedUpdate != nil { - for seq, id := range m.ackedUpdates[session.ID] { - cfg.PerAckedUpdate(&session, seq, id) + if cfg.PerMaxHeight != nil { + for chanID, index := range m.ackedUpdates[session.ID] { + cfg.PerMaxHeight( + &session, chanID, index.MaxHeight(), + ) + } + } + + if cfg.PerNumAckedUpdates != nil { + for chanID, index := range m.ackedUpdates[session.ID] { + cfg.PerNumAckedUpdates( + &session, chanID, + uint16(index.NumInSet()), + ) } } @@ -266,6 +292,37 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) ( return updates, nil } +// IsAcked returns true if the given backup has been backed up using the given +// session. +func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, + error) { + + m.mu.Lock() + defer m.mu.Unlock() + + index, ok := m.ackedUpdates[*id][backupID.ChanID] + if !ok { + return false, nil + } + + return index.IsInIndex(backupID.CommitHeight), nil +} + +// NumAckedUpdates returns the number of backups that have been successfully +// backed up using the given session. +func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) { + m.mu.Lock() + defer m.mu.Unlock() + + var numAcked uint64 + + for _, index := range m.ackedUpdates[*id] { + numAcked += index.NumInSet() + } + + return numAcked, nil +} + // CreateClientSession records a newly negotiated client session in the set of // active sessions. The session can be identified by its SessionID. func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { @@ -311,7 +368,10 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { RewardPkScript: cloneBytes(session.RewardPkScript), }, } - m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID) + m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex) + m.persistedAckedUpdates[session.ID] = make( + map[lnwire.ChannelID]*mockKVStore, + ) m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0) return nil @@ -443,7 +503,25 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, updates[len(updates)-1] = wtdb.CommittedUpdate{} m.committedUpdates[session.ID] = updates[:len(updates)-1] - m.ackedUpdates[*id][seqNum] = update.BackupID + chanID := update.BackupID.ChanID + if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok { + index, err := wtdb.NewRangeIndex(nil) + if err != nil { + return err + } + + m.ackedUpdates[*id][chanID] = index + m.persistedAckedUpdates[*id][chanID] = newMockKVStore() + } + + err := m.ackedUpdates[*id][chanID].Add( + update.BackupID.CommitHeight, + m.persistedAckedUpdates[*id][chanID], + ) + if err != nil { + return err + } + session.TowerLastApplied = lastApplied m.activeSessions[*id] = session @@ -512,3 +590,39 @@ func copyTower(tower *wtdb.Tower) *wtdb.Tower { return t } + +type mockKVStore struct { + kv map[uint64]uint64 + + err error +} + +func newMockKVStore() *mockKVStore { + return &mockKVStore{ + kv: make(map[uint64]uint64), + } +} + +func (m *mockKVStore) Put(key, value []byte) error { + if m.err != nil { + return m.err + } + + k := byteOrder.Uint64(key) + v := byteOrder.Uint64(value) + + m.kv[k] = v + + return nil +} + +func (m *mockKVStore) Delete(key []byte) error { + if m.err != nil { + return m.err + } + + k := byteOrder.Uint64(key) + delete(m.kv, k) + + return nil +}