watchtower+lnrpc: remove AckedUpdates from ClientSession struct

In this commit, we start making use of the new ListClientSession
functional options added in the previous commit. We use the functional
options in order to calculate the max commit heights per channel on the
construction of the tower client. We also use the options to count the
total number of acked and committed updates. With this commit, we are
also able to completely remove the AckedUpdates member of the
ClientSession since it is no longer used anywhere in the code.
This commit is contained in:
Elle Mouton
2022-10-13 14:24:15 +02:00
parent 40e0ebf417
commit 15858cae1c
6 changed files with 194 additions and 137 deletions

View File

@@ -265,12 +265,16 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
return nil, err return nil, err
} }
anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers() opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
req.IncludeSessions,
)
anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
legacyTowers, err := c.cfg.Client.RegisteredTowers() legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -286,7 +290,10 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
rpcTowers := make([]*Tower, 0, len(towers)) rpcTowers := make([]*Tower, 0, len(towers))
for _, tower := range towers { for _, tower := range towers {
rpcTower := marshallTower(tower, req.IncludeSessions) rpcTower := marshallTower(
tower, req.IncludeSessions, ackCounts,
committedUpdateCounts,
)
rpcTowers = append(rpcTowers, rpcTower) rpcTowers = append(rpcTowers, rpcTower)
} }
@@ -306,16 +313,59 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context,
return nil, err return nil, err
} }
opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
req.IncludeSessions,
)
var tower *wtclient.RegisteredTower var tower *wtclient.RegisteredTower
tower, err = c.cfg.Client.LookupTower(pubKey) tower, err = c.cfg.Client.LookupTower(pubKey, opts...)
if err == wtdb.ErrTowerNotFound { if err == wtdb.ErrTowerNotFound {
tower, err = c.cfg.AnchorClient.LookupTower(pubKey) tower, err = c.cfg.AnchorClient.LookupTower(pubKey, opts...)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
return marshallTower(tower, req.IncludeSessions), nil return marshallTower(
tower, req.IncludeSessions, ackCounts, committedUpdateCounts,
), nil
}
// constructFunctionalOptions is a helper function that constructs a list of
// functional options to be used when fetching a tower from the DB. It also
// returns a map of acked-update counts and one for un-acked-update counts that
// will be populated once the db call has been made.
func constructFunctionalOptions(includeSessions bool) (
[]wtdb.ClientSessionListOption, map[wtdb.SessionID]uint16,
map[wtdb.SessionID]uint16) {
var (
opts []wtdb.ClientSessionListOption
ackCounts = make(map[wtdb.SessionID]uint16)
committedUpdateCounts = make(map[wtdb.SessionID]uint16)
)
if !includeSessions {
return opts, ackCounts, committedUpdateCounts
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
_ wtdb.BackupID) {
ackCounts[s.ID]++
}
perCommittedUpdate := func(s *wtdb.ClientSession,
_ *wtdb.CommittedUpdate) {
committedUpdateCounts[s.ID]++
}
opts = []wtdb.ClientSessionListOption{
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
}
return opts, ackCounts, committedUpdateCounts
} }
// Stats returns the in-memory statistics of the client since startup. // Stats returns the in-memory statistics of the client since startup.
@@ -387,7 +437,9 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
// marshallTower converts a client registered watchtower into its corresponding // marshallTower converts a client registered watchtower into its corresponding
// RPC type. // RPC type.
func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower { func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool,
ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower {
rpcAddrs := make([]string, 0, len(tower.Addresses)) rpcAddrs := make([]string, 0, len(tower.Addresses))
for _, addr := range tower.Addresses { for _, addr := range tower.Addresses {
rpcAddrs = append(rpcAddrs, addr.String()) rpcAddrs = append(rpcAddrs, addr.String())
@@ -399,8 +451,8 @@ func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower
for _, session := range tower.Sessions { for _, session := range tower.Sessions {
satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000 satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000
rpcSessions = append(rpcSessions, &TowerSession{ rpcSessions = append(rpcSessions, &TowerSession{
NumBackups: uint32(len(session.AckedUpdates)), NumBackups: uint32(ackCounts[session.ID]),
NumPendingBackups: uint32(len(session.CommittedUpdates)), NumPendingBackups: uint32(pendingCounts[session.ID]),
MaxBackups: uint32(session.Policy.MaxUpdates), MaxBackups: uint32(session.Policy.MaxUpdates),
SweepSatPerVbyte: uint32(satPerVByte), SweepSatPerVbyte: uint32(satPerVByte),

View File

@@ -289,12 +289,67 @@ func New(config *Config) (*TowerClient, error) {
} }
plog := build.NewPrefixLog(prefix, log) plog := build.NewPrefixLog(prefix, log)
// Next, load all candidate towers and sessions from the database into // Load the sweep pkscripts that have been generated for all previously
// the client. We will use any of these sessions if their policies match // registered channels.
// the current policy of the client, otherwise they will be ignored and chanSummaries, err := cfg.DB.FetchChanSummaries()
// new sessions will be requested. if err != nil {
return nil, err
}
c := &TowerClient{
cfg: cfg,
log: plog,
pipeline: newTaskPipeline(plog),
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
statTicker: time.NewTicker(DefaultStatInterval),
stats: new(ClientStats),
newTowers: make(chan *newTowerMsg),
staleTowers: make(chan *staleTowerMsg),
forceQuit: make(chan struct{}),
}
// perUpdate is a callback function that will be used to inspect the
// full set of candidate client sessions loaded from disk, and to
// determine the highest known commit height for each channel. This
// allows the client to reject backups that it has already processed for
// its active policy.
perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
if policy != c.cfg.Policy {
return
}
// Take the highest commit height found in the session's acked
// updates.
height, ok := c.chanCommitHeights[id.ChanID]
if !ok || id.CommitHeight > height {
c.chanCommitHeights[id.ChanID] = id.CommitHeight
}
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
id wtdb.BackupID) {
perUpdate(s.Policy, id)
}
perCommittedUpdate := func(s *wtdb.ClientSession,
u *wtdb.CommittedUpdate) {
perUpdate(s.Policy, u.BackupID)
}
// Load all candidate sessions and towers from the database into the
// client. We will use any of these sessions if their policies match the
// current policy of the client, otherwise they will be ignored and new
// sessions will be requested.
isAnchorClient := cfg.Policy.IsAnchorChannel() isAnchorClient := cfg.Policy.IsAnchorChannel()
activeSessionFilter := genActiveSessionFilter(isAnchorClient) activeSessionFilter := genActiveSessionFilter(isAnchorClient)
candidateTowers := newTowerListIterator() candidateTowers := newTowerListIterator()
perActiveTower := func(tower *wtdb.Tower) { perActiveTower := func(tower *wtdb.Tower) {
// If the tower has already been marked as active, then there is // If the tower has already been marked as active, then there is
@@ -309,34 +364,19 @@ func New(config *Config) (*TowerClient, error) {
// Add the tower to the set of candidate towers. // Add the tower to the set of candidate towers.
candidateTowers.AddCandidate(tower) candidateTowers.AddCandidate(tower)
} }
candidateSessions, err := getTowerAndSessionCandidates( candidateSessions, err := getTowerAndSessionCandidates(
cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Load the sweep pkscripts that have been generated for all previously c.candidateTowers = candidateTowers
// registered channels. c.candidateSessions = candidateSessions
chanSummaries, err := cfg.DB.FetchChanSummaries()
if err != nil {
return nil, err
}
c := &TowerClient{
cfg: cfg,
log: plog,
pipeline: newTaskPipeline(plog),
candidateTowers: candidateTowers,
candidateSessions: candidateSessions,
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
statTicker: time.NewTicker(DefaultStatInterval),
stats: new(ClientStats),
newTowers: make(chan *newTowerMsg),
staleTowers: make(chan *staleTowerMsg),
forceQuit: make(chan struct{}),
}
c.negotiator = newSessionNegotiator(&NegotiatorConfig{ c.negotiator = newSessionNegotiator(&NegotiatorConfig{
DB: cfg.DB, DB: cfg.DB,
SecretKeyRing: cfg.SecretKeyRing, SecretKeyRing: cfg.SecretKeyRing,
@@ -351,10 +391,6 @@ func New(config *Config) (*TowerClient, error) {
Log: plog, Log: plog,
}) })
// Reconstruct the highest commit height processed for each channel
// under the client's current policy.
c.buildHighestCommitHeights()
return c, nil return c, nil
} }
@@ -450,44 +486,6 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
return sessions, nil return sessions, nil
} }
// buildHighestCommitHeights inspects the full set of candidate client sessions
// loaded from disk, and determines the highest known commit height for each
// channel. This allows the client to reject backups that it has already
// processed for it's active policy.
func (c *TowerClient) buildHighestCommitHeights() {
chanCommitHeights := make(map[lnwire.ChannelID]uint64)
for _, s := range c.candidateSessions {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
if s.Policy != c.cfg.Policy {
continue
}
// Take the highest commit height found in the session's
// committed updates.
for _, committedUpdate := range s.CommittedUpdates {
bid := committedUpdate.BackupID
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
// Take the heights commit height found in the session's acked
// updates.
for _, bid := range s.AckedUpdates {
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
}
c.chanCommitHeights = chanCommitHeights
}
// Start initializes the watchtower client by loading or negotiating an active // Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline. // session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error { func (c *TowerClient) Start() error {

View File

@@ -1239,17 +1239,15 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
return nil, err return nil, err
} }
// Fetch the acked updates for this session. // Pass the session's acked updates through the call-back if one is
ackedUpdates, err := getClientSessionAcks( // provided.
sessionBkt, session, cfg.PerAckedUpdate, err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate)
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
session.Tower = tower session.Tower = tower
session.CommittedUpdates = commitedUpdates session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates
return session, nil return session, nil
} }
@@ -1292,18 +1290,19 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
return committedUpdates, nil return committedUpdates, nil
} }
// getClientSessionAcks retrieves all acked updates for the session identified // filterClientSessionAcks retrieves all acked updates for the session
// by the serialized session id. // identified by the serialized session id and passes them to the provided
func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession, // call back if one is provided.
cb PerAckedUpdateCB) (map[uint16]BackupID, error) { func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerAckedUpdateCB) error {
// Initialize ackedUpdates so that we can return an initialized map if if cb == nil {
// no acked updates exist. return nil
ackedUpdates := make(map[uint16]BackupID) }
sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
if sessionAcks == nil { if sessionAcks == nil {
return ackedUpdates, nil return nil
} }
err := sessionAcks.ForEach(func(k, v []byte) error { err := sessionAcks.ForEach(func(k, v []byte) error {
@@ -1315,19 +1314,14 @@ func getClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
return err return err
} }
ackedUpdates[seqNum] = backupID
if cb != nil {
cb(s, seqNum, backupID) cb(s, seqNum, backupID)
}
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return err
} }
return ackedUpdates, nil return nil
} }
// putClientSessionBody stores the body of the ClientSession (everything but the // putClientSessionBody stores the body of the ClientSession (everything but the

View File

@@ -48,12 +48,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
require.ErrorIs(h.t, err, expErr) require.ErrorIs(h.t, err, expErr)
} }
func (h *clientDBHarness) listSessions( func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper() h.t.Helper()
sessions, err := h.db.ListClientSessions(id) sessions, err := h.db.ListClientSessions(id, opts...)
require.NoError(h.t, err, "unable to list client sessions") require.NoError(h.t, err, "unable to list client sessions")
return sessions return sessions
@@ -520,11 +520,7 @@ func testCommitUpdate(h *clientDBHarness) {
// Assert that the committed update appears in the client session's // Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no // CommittedUpdates map when loaded from disk and that there are no
// AckedUpdates. // AckedUpdates.
dbSession := h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Try to commit the same update, which should succeed due to // Try to commit the same update, which should succeed due to
// idempotency (which is preserved when the breach hint is identical to // idempotency (which is preserved when the breach hint is identical to
@@ -534,11 +530,7 @@ func testCommitUpdate(h *clientDBHarness) {
require.Equal(h.t, lastApplied, lastApplied2) require.Equal(h.t, lastApplied, lastApplied2)
// Assert that the loaded ClientSession is the same as before. // Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Generate another random update and try to commit it at the identical // Generate another random update and try to commit it at the identical
// sequence number. Since the breach hint has changed, this should fail. // sequence number. Since the breach hint has changed, this should fail.
@@ -553,12 +545,10 @@ func testCommitUpdate(h *clientDBHarness) {
// Check that both updates now appear as committed on the ClientSession // Check that both updates now appear as committed on the ClientSession
// loaded from disk. // loaded from disk.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
}) }, nil)
checkAckedUpdates(h.t, dbSession, nil)
// Finally, create one more random update and try to commit it at index // Finally, create one more random update and try to commit it at index
// 4, which should be rejected since 3 is the next slot the database // 4, which should be rejected since 3 is the next slot the database
@@ -567,12 +557,20 @@ func testCommitUpdate(h *clientDBHarness) {
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
// Assert that the ClientSession loaded from disk remains unchanged. // Assert that the ClientSession loaded from disk remains unchanged.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
}) }, nil)
checkAckedUpdates(h.t, dbSession, nil) }
func perAckedUpdate(updates map[uint16]wtdb.BackupID) func(
_ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) {
return func(_ *wtdb.ClientSession, seq uint16,
id wtdb.BackupID) {
updates[seq] = id
}
} }
// testAckUpdate asserts the behavior of AckUpdate. // testAckUpdate asserts the behavior of AckUpdate.
@@ -628,9 +626,7 @@ func testAckUpdate(h *clientDBHarness) {
// Assert that the ClientSession loaded from disk has one update in it's // Assert that the ClientSession loaded from disk has one update in it's
// AckedUpdates map, and that the committed update has been removed. // AckedUpdates map, and that the committed update has been removed.
dbSession := h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
}) })
@@ -645,9 +641,7 @@ func testAckUpdate(h *clientDBHarness) {
h.ackUpdate(&session.ID, 2, 2, nil) h.ackUpdate(&session.ID, 2, 2, nil)
// Assert that both updates exist as AckedUpdates when loaded from disk. // Assert that both updates exist as AckedUpdates when loaded from disk.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
2: update2.BackupID, 2: update2.BackupID,
}) })
@@ -663,6 +657,19 @@ func testAckUpdate(h *clientDBHarness) {
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
} }
func (h *clientDBHarness) assertUpdates(id wtdb.SessionID,
expectedPending []wtdb.CommittedUpdate,
expectedAcked map[uint16]wtdb.BackupID) {
ackedUpdates := make(map[uint16]wtdb.BackupID)
_ = h.listSessions(
nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)),
)
dbSession := h.listSessions(nil)[id]
checkCommittedUpdates(h.t, dbSession, expectedPending)
checkAckedUpdates(h.t, ackedUpdates, expectedAcked)
}
// checkCommittedUpdates asserts that the CommittedUpdates on session match the // checkCommittedUpdates asserts that the CommittedUpdates on session match the
// expUpdates provided. // expUpdates provided.
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
@@ -682,7 +689,7 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
// checkAckedUpdates asserts that the AckedUpdates on a session match the // checkAckedUpdates asserts that the AckedUpdates on a session match the
// expUpdates provided. // expUpdates provided.
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, func checkAckedUpdates(t *testing.T, actualUpdates,
expUpdates map[uint16]wtdb.BackupID) { expUpdates map[uint16]wtdb.BackupID) {
// We promote nil expUpdates to an initialized map since the database // We promote nil expUpdates to an initialized map since the database
@@ -692,7 +699,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make(map[uint16]wtdb.BackupID) expUpdates = make(map[uint16]wtdb.BackupID)
} }
require.Equal(t, expUpdates, session.AckedUpdates) require.Equal(t, expUpdates, actualUpdates)
} }
// TestClientDB asserts the behavior of a fresh client db, a reopened client db, // TestClientDB asserts the behavior of a fresh client db, a reopened client db,

View File

@@ -47,13 +47,6 @@ type ClientSession struct {
// insertion and retrieval. // insertion and retrieval.
CommittedUpdates []CommittedUpdate CommittedUpdates []CommittedUpdate
// AckedUpdates is a map from sequence number to backup id to record
// which revoked states were uploaded via this session.
//
// NOTE: This map is serialized in it's own bucket, separate from the
// body of the ClientSession.
AckedUpdates map[uint16]BackupID
// Tower holds the pubkey and address of the watchtower. // Tower holds the pubkey and address of the watchtower.
// //
// NOTE: This value is not serialized. It is recovered by looking up the // NOTE: This value is not serialized. It is recovered by looking up the

View File

@@ -26,6 +26,7 @@ type ClientDB struct {
mu sync.Mutex mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
towerIndex map[towerPK]wtdb.TowerID towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower towers map[wtdb.TowerID]*wtdb.Tower
@@ -39,6 +40,7 @@ func NewClientDB() *ClientDB {
return &ClientDB{ return &ClientDB{
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
towerIndex: make(map[towerPK]wtdb.TowerID), towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower), towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32), indexes: make(map[keyIndexKey]uint32),
@@ -75,7 +77,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
} else { } else {
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
tower = &wtdb.Tower{ tower = &wtdb.Tower{
ID: wtdb.TowerID(towerID), ID: towerID,
IdentityKey: lnAddr.IdentityKey, IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address}, Addresses: []net.Addr{lnAddr.Address},
} }
@@ -193,7 +195,7 @@ func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
// MarkBackupIneligible records that particular commit height is ineligible for // MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt // backup. This allows the client to track which updates it should not attempt
// to retry after startup. // to retry after startup.
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
return nil return nil
} }
@@ -213,9 +215,14 @@ func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
_ ...wtdb.ClientSessionListOption) ( opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions { for _, session := range m.activeSessions {
session := session session := session
@@ -224,6 +231,12 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
} }
session.Tower = m.towers[session.TowerID] session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session sessions[session.ID] = &session
if cfg.PerAckedUpdate != nil {
for seq, id := range m.ackedUpdates[session.ID] {
cfg.PerAckedUpdate(&session, seq, id)
}
}
} }
return sessions, nil return sessions, nil
@@ -274,8 +287,8 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
RewardPkScript: cloneBytes(session.RewardPkScript), RewardPkScript: cloneBytes(session.RewardPkScript),
}, },
CommittedUpdates: make([]wtdb.CommittedUpdate, 0), CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
AckedUpdates: make(map[uint16]wtdb.BackupID),
} }
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
return nil return nil
} }
@@ -402,7 +415,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
updates[len(updates)-1] = wtdb.CommittedUpdate{} updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1] session.CommittedUpdates = updates[:len(updates)-1]
session.AckedUpdates[seqNum] = update.BackupID m.ackedUpdates[*id][seqNum] = update.BackupID
session.TowerLastApplied = lastApplied session.TowerLastApplied = lastApplied
m.activeSessions[*id] = session m.activeSessions[*id] = session