diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index ffc6ea694..1051da0d5 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -33,6 +33,10 @@ type DB interface { // NOTE: An error is not returned if the tower doesn't exist. RemoveTower(*btcec.PublicKey, net.Addr) error + // TerminateSession sets the given session's status to CSessionTerminal + // meaning that it will not be usable again. + TerminateSession(id wtdb.SessionID) error + // LoadTower retrieves a tower by its public key. LoadTower(*btcec.PublicKey) (*wtdb.Tower, error) diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index a88cd49c4..9731dfef8 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -2259,6 +2259,53 @@ func (c *ClientDB) GetDBQueue(namespace []byte) Queue[*BackupID] { ) } +// TerminateSession sets the given session's status to CSessionTerminal meaning +// that it will not be usable again. An error will be returned if the given +// session still has un-acked updates that should be attended to. +func (c *ClientDB) TerminateSession(id SessionID) error { + return kvdb.Update(c.db, func(tx kvdb.RwTx) error { + sessions := tx.ReadWriteBucket(cSessionBkt) + if sessions == nil { + return ErrUninitializedDB + } + + sessionsBkt := tx.ReadBucket(cSessionBkt) + if sessionsBkt == nil { + return ErrUninitializedDB + } + + chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt) + if chanIDIndexBkt == nil { + return ErrUninitializedDB + } + + // Collect any un-acked updates for this session. + committedUpdateCount := make(map[SessionID]uint16) + perCommittedUpdate := func(s *ClientSession, + _ *CommittedUpdate) { + + committedUpdateCount[s.ID]++ + } + + session, err := c.getClientSession( + sessionsBkt, chanIDIndexBkt, id[:], + WithPerCommittedUpdate(perCommittedUpdate), + ) + if err != nil { + return err + } + + // If there are any un-acked updates for this session then + // we don't allow the change of status as these updates must + // first be dealt with somehow. + if committedUpdateCount[id] > 0 { + return ErrSessionHasUnackedUpdates + } + + return markSessionStatus(sessions, session, CSessionTerminal) + }, func() {}) +} + // DeleteCommittedUpdates deletes all the committed updates for the given // session. func (c *ClientDB) DeleteCommittedUpdates(id *SessionID) error { diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index 01ea0d8d8..dd6479b74 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -173,6 +173,24 @@ func (h *clientDBHarness) loadTowerByID(id wtdb.TowerID, return tower } +func (h *clientDBHarness) terminateSession(id wtdb.SessionID, expErr error) { + h.t.Helper() + + err := h.db.TerminateSession(id) + require.ErrorIs(h.t, err, expErr) +} + +func (h *clientDBHarness) getClientSession(id wtdb.SessionID, + expErr error) *wtdb.ClientSession { + + h.t.Helper() + + session, err := h.db.GetClientSession(id) + require.ErrorIs(h.t, err, expErr) + + return session +} + func (h *clientDBHarness) fetchChanInfos() wtdb.ChannelInfos { h.t.Helper() @@ -557,6 +575,55 @@ func testRemoveTower(h *clientDBHarness) { h.removeTower(pk, nil, true, nil) } +func testTerminateSession(h *clientDBHarness) { + const blobType = blob.TypeAltruistCommit + + tower := h.newTower() + + // Create a new session that the updates in this will be tied to. + session := &wtdb.ClientSession{ + ClientSessionBody: wtdb.ClientSessionBody{ + TowerID: tower.ID, + Policy: wtpolicy.Policy{ + TxPolicy: wtpolicy.TxPolicy{ + BlobType: blobType, + }, + MaxUpdates: 100, + }, + RewardPkScript: []byte{0x01, 0x02, 0x03}, + }, + ID: wtdb.SessionID([33]byte{0x03}), + } + + // Reserve a session key and insert the client session. + session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType, false) + h.insertSession(session, nil) + + // Commit to a random update at seqnum 1. + update1 := randCommittedUpdate(h.t, 1) + h.registerChan(update1.BackupID.ChanID, nil, nil) + lastApplied := h.commitUpdate(&session.ID, update1, nil) + require.Zero(h.t, lastApplied) + + // Terminating the session now should fail since the session has an + // un-acked update. + h.terminateSession(session.ID, wtdb.ErrSessionHasUnackedUpdates) + + // Fetch the session and assert that the status is still active. + sess := h.getClientSession(session.ID, nil) + require.Equal(h.t, wtdb.CSessionActive, sess.Status) + + // Delete the update. + h.deleteCommittedUpdates(&session.ID, nil) + + // Terminating the session now should succeed. + h.terminateSession(session.ID, nil) + + // Fetch the session again and assert that its status is now Terminal. + sess = h.getClientSession(session.ID, nil) + require.Equal(h.t, wtdb.CSessionTerminal, sess.Status) +} + // testTowerStatusChange tests that the Tower status is updated accordingly // given a variety of commands. func testTowerStatusChange(h *clientDBHarness) { @@ -1258,6 +1325,10 @@ func TestClientDB(t *testing.T) { name: "test tower status change", run: testTowerStatusChange, }, + { + name: "terminate session", + run: testTerminateSession, + }, } for _, database := range dbs {