diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index a2a20c987..0f7f1b539 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -135,6 +135,10 @@ type DB interface { // GetDBQueue returns a BackupID Queue instance under the given name // space. GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] + + // DeleteCommittedUpdate deletes the committed update belonging to the + // given session and with the given sequence number from the db. + DeleteCommittedUpdate(id *wtdb.SessionID, seqNum uint16) error } // AuthDialer connects to a remote node using an authenticated transport, such diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 9491015d6..41d80587d 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2073,6 +2073,42 @@ func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] { ) } +// DeleteCommittedUpdate deletes the committed update with the given sequence +// number from the given session. +func (c *ClientDB) DeleteCommittedUpdate(id *SessionID, seqNum uint16) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessions := tx.ReadWriteBucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + sessionBkt := sessions.NestedReadWriteBucket(id[:]) + if sessionBkt == nil { + return fmt.Errorf("session bucket %s not found", + id.String()) + } + + // If the commits sub-bucket doesn't exist, there can't possibly + // be a corresponding update to remove. + sessionCommits := sessionBkt.NestedReadWriteBucket( + cSessionCommits, + ) + if sessionCommits == nil { + return ErrCommittedUpdateNotFound + } + + var seqNumBuf [2]byte + byteOrder.PutUint16(seqNumBuf[:], seqNum) + + if sessionCommits.Get(seqNumBuf[:]) == nil { + return ErrCommittedUpdateNotFound + } + + // Remove the corresponding committed update. + return sessionCommits.Delete(seqNumBuf[:]) + }, func() {}) +} + // putChannelToSessionMapping adds the given session ID to a channel's // cChanSessions bucket. func putChannelToSessionMapping(chanDetails kvdb.RwBucket, diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 80557f1f8..5bfb4dab5 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -195,6 +195,15 @@ func (h *clientDBHarness) ackUpdate(id *wtdb.SessionID, seqNum uint16, require.ErrorIs(h.t, err, expErr) } +func (h *clientDBHarness) deleteCommittedUpdate(id *wtdb.SessionID, + seqNum uint16, expErr error) { + + h.t.Helper() + + err := h.db.DeleteCommittedUpdate(id, seqNum) + require.ErrorIs(h.t, err, expErr) +} + func (h *clientDBHarness) markChannelClosed(id lnwire.ChannelID, blockHeight uint32, expErr error) []wtdb.SessionID { @@ -567,7 +576,8 @@ func testChanSummaries(h *clientDBHarness) { h.registerChan(chanID, expPkScript, wtdb.ErrChannelAlreadyRegistered) } -// testCommitUpdate tests the behavior of CommitUpdate, ensuring that they can +// testCommitUpdate tests the behavior of CommitUpdate and +// DeleteCommittedUpdate. func testCommitUpdate(h *clientDBHarness) { const blobType = blob.TypeAltruistCommit @@ -648,6 +658,22 @@ func testCommitUpdate(h *clientDBHarness) { *update1, *update2, }, nil) + + // We will now also test that the DeleteCommittedUpdates method also + // works. + // First, try to delete a committed update that does not exist. + h.deleteCommittedUpdate( + &session.ID, update4.SeqNum, wtdb.ErrCommittedUpdateNotFound, + ) + + // Now delete an existing committed update and ensure that it succeeds. + h.deleteCommittedUpdate(&session.ID, update1.SeqNum, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{ + *update2, + }, nil) + + h.deleteCommittedUpdate(&session.ID, update2.SeqNum, nil) + h.assertUpdates(session.ID, []wtdb.CommittedUpdate{}, nil) } // testMarkChannelClosed asserts the behaviour of MarkChannelClosed. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 60838ab42..f5625d35b 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -586,6 +586,37 @@ func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] { return q } +// DeleteCommittedUpdate deletes the committed update with the given sequence +// number from the given session. +func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID, + seqNum uint16) error { + + m.mu.Lock() + defer m.mu.Unlock() + + // Fail if session doesn't exist. + session, ok := m.activeSessions[*id] + if !ok { + return wtdb.ErrClientSessionNotFound + } + + // Retrieve the committed update, failing if none is found. + updates := m.committedUpdates[session.ID] + for i, update := range updates { + if update.SeqNum != seqNum { + continue + } + + // Remove the committed update from "disk". + updates = append(updates[:i], updates[i+1:]...) + m.committedUpdates[session.ID] = updates + + return nil + } + + return wtdb.ErrCommittedUpdateNotFound +} + // ListClosableSessions fetches and returns the IDs for all sessions marked as // closable. func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {