diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index b5d2418ab..a69e9980f 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -99,6 +99,12 @@ type DB interface { // marked as closable. ListClosableSessions() (map[wtdb.SessionID]uint32, error) + // DeleteSession can be called when a session should be deleted from the + // DB. All references to the session will also be deleted from the DB. + // A session will only be deleted if it was previously marked as + // closable. + DeleteSession(id wtdb.SessionID) error + // RegisterChannel registers a channel for use within the client // database. For now, all that is stored in the channel summary is the // sweep pkscript that we'd like any tower sweeps to pay into. In the diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 4a74c7bb4..c3e5447d6 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -168,6 +168,10 @@ var ( // not pass the filter func provided by the caller. ErrSessionFailedFilterFn = errors.New("session failed filter func") + // ErrSessionNotClosable is returned when a session is not found in the + // closable list. + ErrSessionNotClosable = errors.New("session is not closable") + // errSessionHasOpenChannels is an error used to indicate that a // session has updates for channels that are still open. errSessionHasOpenChannels = errors.New("session has open channels") @@ -175,6 +179,11 @@ var ( // errSessionHasUnackedUpdates is an error used to indicate that a // session has un-acked updates. errSessionHasUnackedUpdates = errors.New("session has un-acked updates") + + // errChannelHasMoreSessions is an error used to indicate that a channel + // has updates in other non-closed sessions. + errChannelHasMoreSessions = errors.New("channel has updates in " + + "other sessions") ) // NewBoltBackendCreator returns a function that creates a new bbolt backend for @@ -1053,6 +1062,7 @@ func (c *ClientDB) GetClientSession(id SessionID, } sess = session + return nil }, func() {}) @@ -1425,6 +1435,177 @@ func (c *ClientDB) ListClosableSessions() (map[SessionID]uint32, error) { return sessions, nil } +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if was previously marked as closable. +func (c *ClientDB) DeleteSession(id SessionID) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessionsBkt := tx.ReadWriteBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + closableBkt := tx.ReadWriteBucket(cClosableSessionsBkt) + if closableBkt == nil { + return ErrUninitializedDB + } + + chanDetailsBkt := tx.ReadWriteBucket(cChanDetailsBkt) + if chanDetailsBkt == nil { + return ErrUninitializedDB + } + + sessIDIndexBkt := tx.ReadWriteBucket(cSessionIDIndexBkt) + if sessIDIndexBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadWriteBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + towerToSessBkt := tx.ReadWriteBucket(cTowerToSessionIndexBkt) + if towerToSessBkt == nil { + return ErrUninitializedDB + } + + // Get the sub-bucket for this session ID. If it does not exist + // then the session has already been deleted and so our work is + // done. + sessionBkt := sessionsBkt.NestedReadBucket(id[:]) + if sessionBkt == nil { + return nil + } + + _, dbIDBytes, err := getDBSessionID(sessionsBkt, id) + if err != nil { + return err + } + + // First we check if the session has actually been marked as + // closable. + if closableBkt.Get(dbIDBytes) == nil { + return ErrSessionNotClosable + } + + sess, err := getClientSessionBody(sessionsBkt, id[:]) + if err != nil { + return err + } + + // Delete from the tower-to-sessionID index. + towerIndexBkt := towerToSessBkt.NestedReadWriteBucket( + sess.TowerID.Bytes(), + ) + if towerIndexBkt == nil { + return fmt.Errorf("no entry in the tower-to-session "+ + "index found for tower ID %v", sess.TowerID) + } + + err = towerIndexBkt.Delete(id[:]) + if err != nil { + return err + } + + // Delete entry from session ID index. + err = sessIDIndexBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Delete the entry from the closable sessions index. + err = closableBkt.Delete(dbIDBytes) + if err != nil { + return err + } + + // Get the acked updates range index for the session. This is + // used to get the list of channels that the session has updates + // for. + ackRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex) + if ackRanges == nil { + // A session would only be considered closable if it + // was exhausted. Meaning that it should not be the + // case that it has no acked-updates. + return fmt.Errorf("cannot delete session %s since it "+ + "is not yet exhausted", id) + } + + // For each of the channels, delete the session ID entry. + err = ackRanges.ForEach(func(chanDBID, _ []byte) error { + chanDBIDInt, err := readBigSize(chanDBID) + if err != nil { + return err + } + + chanID, err := getRealChannelID( + chanIDIndexBkt, chanDBIDInt, + ) + if err != nil { + return err + } + + chanDetails := chanDetailsBkt.NestedReadWriteBucket( + chanID[:], + ) + if chanDetails == nil { + return ErrChannelNotRegistered + } + + chanSessions := chanDetails.NestedReadWriteBucket( + cChanSessions, + ) + if chanSessions == nil { + return fmt.Errorf("no session list found for "+ + "channel %s", chanID) + } + + // Check that this session was actually listed in the + // session list for this channel. + if len(chanSessions.Get(dbIDBytes)) == 0 { + return fmt.Errorf("session %s not found in "+ + "the session list for channel %s", id, + chanID) + } + + // If it was, then delete it. + err = chanSessions.Delete(dbIDBytes) + if err != nil { + return err + } + + // If this was the last session for this channel, we can + // now delete the channel details for this channel + // completely. + err = chanSessions.ForEach(func(_, _ []byte) error { + return errChannelHasMoreSessions + }) + if errors.Is(err, errChannelHasMoreSessions) { + return nil + } else if err != nil { + return err + } + + // Delete the channel's entry from the channel-id-index. + dbID := chanDetails.Get(cChanDBID) + err = chanIDIndexBkt.Delete(dbID) + if err != nil { + return err + } + + // Delete the channel details. + return chanDetailsBkt.DeleteNestedBucket(chanID[:]) + }) + if err != nil { + return err + } + + // Delete the actual session. + return sessionsBkt.DeleteNestedBucket(id[:]) + }, func() {}) +} + // MarkChannelClosed will mark a registered channel as closed by setting its // closed-height as the given block height. It returns a list of session IDs for // sessions that are now considered closable due to the close of this channel. diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 73f4e5550..b3d241175 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -218,6 +218,13 @@ func (h *clientDBHarness) listClosableSessions( return closableSessions } +func (h *clientDBHarness) deleteSession(id wtdb.SessionID, expErr error) { + h.t.Helper() + + err := h.db.DeleteSession(id) + require.ErrorIs(h.t, err, expErr) +} + // newTower is a helper function that creates a new tower with a randomly // generated public key and inserts it into the client DB. func (h *clientDBHarness) newTower() *wtdb.Tower { @@ -724,6 +731,10 @@ func testMarkChannelClosed(h *clientDBHarness) { require.Empty(h.t, sl) require.Empty(h.t, h.listClosableSessions(nil)) + // Also check that attempting to delete the session will fail since it + // is not yet considered closable. + h.deleteSession(session1.ID, wtdb.ErrSessionNotClosable) + // Finally, if we close channel 6, session 1 _should_ be in the closable // list. sl = h.markChannelClosed(chanID6, 100, nil) @@ -732,6 +743,10 @@ func testMarkChannelClosed(h *clientDBHarness) { require.InDeltaMapValues(h.t, slMap, map[wtdb.SessionID]uint32{ session1.ID: 100, }, 0) + + // Assert that we now can delete the session. + h.deleteSession(session1.ID, nil) + require.Empty(h.t, h.listClosableSessions(nil)) } // testAckUpdate asserts the behavior of AckUpdate. diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index f439e9182..7213f17b6 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -703,6 +703,34 @@ func (m *ClientDB) GetClientSession(id wtdb.SessionID, return &session, nil } +// DeleteSession can be called when a session should be deleted from the DB. +// All references to the session will also be deleted from the DB. Note that a +// session will only be deleted if it is considered closable. +func (m *ClientDB) DeleteSession(id wtdb.SessionID) error { + m.mu.Lock() + defer m.mu.Unlock() + + _, ok := m.closableSessions[id] + if !ok { + return wtdb.ErrSessionNotClosable + } + + // For each of the channels, delete the session ID entry. + for chanID := range m.ackedUpdates[id] { + c, ok := m.channels[chanID] + if !ok { + return wtdb.ErrChannelNotRegistered + } + + delete(c.sessions, id) + } + + delete(m.closableSessions, id) + delete(m.activeSessions, id) + + return nil +} + // RegisterChannel registers a channel for use within the client database. For // now, all that is stored in the channel summary is the sweep pkscript that // we'd like any tower sweeps to pay into. In the future, this will be extended