Merge pull request #6928 from ellemouton/wtclientMemPerf

multi: remove AckedUpdates & CommittedUpdates from ClientSession struct
This commit is contained in:
Oliver Gugger 2022-10-13 15:34:35 +02:00 committed by GitHub
commit d55f861107
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 505 additions and 229 deletions

View File

@ -111,6 +111,10 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019).
closer coupling of Towers and Sessions and ensures that a session cannot be closer coupling of Towers and Sessions and ensures that a session cannot be
added if the tower it is referring to does not exist. added if the tower it is referring to does not exist.
* [Remove `AckedUpdates` & `CommittedUpdates` from the `ClientSession`
struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to
improve the performance of fetching a `ClientSession` from the DB.
* [Create a helper function to wait for peer to come * [Create a helper function to wait for peer to come
online](https://github.com/lightningnetwork/lnd/pull/6931). online](https://github.com/lightningnetwork/lnd/pull/6931).

View File

@ -265,12 +265,16 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
return nil, err return nil, err
} }
anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers() opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
req.IncludeSessions,
)
anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
legacyTowers, err := c.cfg.Client.RegisteredTowers() legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -286,7 +290,10 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
rpcTowers := make([]*Tower, 0, len(towers)) rpcTowers := make([]*Tower, 0, len(towers))
for _, tower := range towers { for _, tower := range towers {
rpcTower := marshallTower(tower, req.IncludeSessions) rpcTower := marshallTower(
tower, req.IncludeSessions, ackCounts,
committedUpdateCounts,
)
rpcTowers = append(rpcTowers, rpcTower) rpcTowers = append(rpcTowers, rpcTower)
} }
@ -306,16 +313,59 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context,
return nil, err return nil, err
} }
opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
req.IncludeSessions,
)
var tower *wtclient.RegisteredTower var tower *wtclient.RegisteredTower
tower, err = c.cfg.Client.LookupTower(pubKey) tower, err = c.cfg.Client.LookupTower(pubKey, opts...)
if err == wtdb.ErrTowerNotFound { if err == wtdb.ErrTowerNotFound {
tower, err = c.cfg.AnchorClient.LookupTower(pubKey) tower, err = c.cfg.AnchorClient.LookupTower(pubKey, opts...)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
return marshallTower(tower, req.IncludeSessions), nil return marshallTower(
tower, req.IncludeSessions, ackCounts, committedUpdateCounts,
), nil
}
// constructFunctionalOptions is a helper function that constructs a list of
// functional options to be used when fetching a tower from the DB. It also
// returns a map of acked-update counts and one for un-acked-update counts that
// will be populated once the db call has been made.
func constructFunctionalOptions(includeSessions bool) (
[]wtdb.ClientSessionListOption, map[wtdb.SessionID]uint16,
map[wtdb.SessionID]uint16) {
var (
opts []wtdb.ClientSessionListOption
ackCounts = make(map[wtdb.SessionID]uint16)
committedUpdateCounts = make(map[wtdb.SessionID]uint16)
)
if !includeSessions {
return opts, ackCounts, committedUpdateCounts
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
_ wtdb.BackupID) {
ackCounts[s.ID]++
}
perCommittedUpdate := func(s *wtdb.ClientSession,
_ *wtdb.CommittedUpdate) {
committedUpdateCounts[s.ID]++
}
opts = []wtdb.ClientSessionListOption{
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
}
return opts, ackCounts, committedUpdateCounts
} }
// Stats returns the in-memory statistics of the client since startup. // Stats returns the in-memory statistics of the client since startup.
@ -387,7 +437,9 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
// marshallTower converts a client registered watchtower into its corresponding // marshallTower converts a client registered watchtower into its corresponding
// RPC type. // RPC type.
func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower { func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool,
ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower {
rpcAddrs := make([]string, 0, len(tower.Addresses)) rpcAddrs := make([]string, 0, len(tower.Addresses))
for _, addr := range tower.Addresses { for _, addr := range tower.Addresses {
rpcAddrs = append(rpcAddrs, addr.String()) rpcAddrs = append(rpcAddrs, addr.String())
@ -399,8 +451,8 @@ func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower
for _, session := range tower.Sessions { for _, session := range tower.Sessions {
satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000 satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000
rpcSessions = append(rpcSessions, &TowerSession{ rpcSessions = append(rpcSessions, &TowerSession{
NumBackups: uint32(len(session.AckedUpdates)), NumBackups: uint32(ackCounts[session.ID]),
NumPendingBackups: uint32(len(session.CommittedUpdates)), NumPendingBackups: uint32(pendingCounts[session.ID]),
MaxBackups: uint32(session.Policy.MaxUpdates), MaxBackups: uint32(session.Policy.MaxUpdates),
SweepSatPerVbyte: uint32(satPerVByte), SweepSatPerVbyte: uint32(satPerVByte),

View File

@ -83,10 +83,12 @@ type Client interface {
// RegisteredTowers retrieves the list of watchtowers registered with // RegisteredTowers retrieves the list of watchtowers registered with
// the client. // the client.
RegisteredTowers() ([]*RegisteredTower, error) RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
error)
// LookupTower retrieves a registered watchtower through its public key. // LookupTower retrieves a registered watchtower through its public key.
LookupTower(*btcec.PublicKey) (*RegisteredTower, error) LookupTower(*btcec.PublicKey,
...wtdb.ClientSessionListOption) (*RegisteredTower, error)
// Stats returns the in-memory statistics of the client since startup. // Stats returns the in-memory statistics of the client since startup.
Stats() ClientStats Stats() ClientStats
@ -287,12 +289,67 @@ func New(config *Config) (*TowerClient, error) {
} }
plog := build.NewPrefixLog(prefix, log) plog := build.NewPrefixLog(prefix, log)
// Next, load all candidate towers and sessions from the database into // Load the sweep pkscripts that have been generated for all previously
// the client. We will use any of these sessions if their policies match // registered channels.
// the current policy of the client, otherwise they will be ignored and chanSummaries, err := cfg.DB.FetchChanSummaries()
// new sessions will be requested. if err != nil {
return nil, err
}
c := &TowerClient{
cfg: cfg,
log: plog,
pipeline: newTaskPipeline(plog),
chanCommitHeights: make(map[lnwire.ChannelID]uint64),
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
statTicker: time.NewTicker(DefaultStatInterval),
stats: new(ClientStats),
newTowers: make(chan *newTowerMsg),
staleTowers: make(chan *staleTowerMsg),
forceQuit: make(chan struct{}),
}
// perUpdate is a callback function that will be used to inspect the
// full set of candidate client sessions loaded from disk, and to
// determine the highest known commit height for each channel. This
// allows the client to reject backups that it has already processed for
// its active policy.
perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
if policy != c.cfg.Policy {
return
}
// Take the highest commit height found in the session's acked
// updates.
height, ok := c.chanCommitHeights[id.ChanID]
if !ok || id.CommitHeight > height {
c.chanCommitHeights[id.ChanID] = id.CommitHeight
}
}
perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
id wtdb.BackupID) {
perUpdate(s.Policy, id)
}
perCommittedUpdate := func(s *wtdb.ClientSession,
u *wtdb.CommittedUpdate) {
perUpdate(s.Policy, u.BackupID)
}
// Load all candidate sessions and towers from the database into the
// client. We will use any of these sessions if their policies match the
// current policy of the client, otherwise they will be ignored and new
// sessions will be requested.
isAnchorClient := cfg.Policy.IsAnchorChannel() isAnchorClient := cfg.Policy.IsAnchorChannel()
activeSessionFilter := genActiveSessionFilter(isAnchorClient) activeSessionFilter := genActiveSessionFilter(isAnchorClient)
candidateTowers := newTowerListIterator() candidateTowers := newTowerListIterator()
perActiveTower := func(tower *wtdb.Tower) { perActiveTower := func(tower *wtdb.Tower) {
// If the tower has already been marked as active, then there is // If the tower has already been marked as active, then there is
@ -307,34 +364,19 @@ func New(config *Config) (*TowerClient, error) {
// Add the tower to the set of candidate towers. // Add the tower to the set of candidate towers.
candidateTowers.AddCandidate(tower) candidateTowers.AddCandidate(tower)
} }
candidateSessions, err := getTowerAndSessionCandidates( candidateSessions, err := getTowerAndSessionCandidates(
cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower, cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
wtdb.WithPerAckedUpdate(perAckedUpdate),
wtdb.WithPerCommittedUpdate(perCommittedUpdate),
) )
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Load the sweep pkscripts that have been generated for all previously c.candidateTowers = candidateTowers
// registered channels. c.candidateSessions = candidateSessions
chanSummaries, err := cfg.DB.FetchChanSummaries()
if err != nil {
return nil, err
}
c := &TowerClient{
cfg: cfg,
log: plog,
pipeline: newTaskPipeline(plog),
candidateTowers: candidateTowers,
candidateSessions: candidateSessions,
activeSessions: make(sessionQueueSet),
summaries: chanSummaries,
statTicker: time.NewTicker(DefaultStatInterval),
stats: new(ClientStats),
newTowers: make(chan *newTowerMsg),
staleTowers: make(chan *staleTowerMsg),
forceQuit: make(chan struct{}),
}
c.negotiator = newSessionNegotiator(&NegotiatorConfig{ c.negotiator = newSessionNegotiator(&NegotiatorConfig{
DB: cfg.DB, DB: cfg.DB,
SecretKeyRing: cfg.SecretKeyRing, SecretKeyRing: cfg.SecretKeyRing,
@ -349,10 +391,6 @@ func New(config *Config) (*TowerClient, error) {
Log: plog, Log: plog,
}) })
// Reconstruct the highest commit height processed for each channel
// under the client's current policy.
c.buildHighestCommitHeights()
return c, nil return c, nil
} }
@ -363,7 +401,8 @@ func New(config *Config) (*TowerClient, error) {
// tower. // tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter func(*wtdb.ClientSession) bool, sessionFilter func(*wtdb.ClientSession) bool,
perActiveTower func(tower *wtdb.Tower)) ( perActiveTower func(tower *wtdb.Tower),
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
towers, err := db.ListTowers() towers, err := db.ListTowers()
@ -373,7 +412,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, tower := range towers { for _, tower := range towers {
sessions, err := db.ListClientSessions(&tower.ID) sessions, err := db.ListClientSessions(&tower.ID, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -413,10 +452,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing // ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used. // ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool) ( passesFilter func(*wtdb.ClientSession) bool,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*wtdb.ClientSession, error) {
sessions, err := db.ListClientSessions(forTower) sessions, err := db.ListClientSessions(forTower, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -446,48 +486,10 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
return sessions, nil return sessions, nil
} }
// buildHighestCommitHeights inspects the full set of candidate client sessions
// loaded from disk, and determines the highest known commit height for each
// channel. This allows the client to reject backups that it has already
// processed for it's active policy.
func (c *TowerClient) buildHighestCommitHeights() {
chanCommitHeights := make(map[lnwire.ChannelID]uint64)
for _, s := range c.candidateSessions {
// We only want to consider accepted updates that have been
// accepted under an identical policy to the client's current
// policy.
if s.Policy != c.cfg.Policy {
continue
}
// Take the highest commit height found in the session's
// committed updates.
for _, committedUpdate := range s.CommittedUpdates {
bid := committedUpdate.BackupID
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
// Take the heights commit height found in the session's acked
// updates.
for _, bid := range s.AckedUpdates {
height, ok := chanCommitHeights[bid.ChanID]
if !ok || bid.CommitHeight > height {
chanCommitHeights[bid.ChanID] = bid.CommitHeight
}
}
}
c.chanCommitHeights = chanCommitHeights
}
// Start initializes the watchtower client by loading or negotiating an active // Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline. // session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error { func (c *TowerClient) Start() error {
var err error var returnErr error
c.started.Do(func() { c.started.Do(func() {
c.log.Infof("Watchtower client starting") c.log.Infof("Watchtower client starting")
@ -496,19 +498,27 @@ func (c *TowerClient) Start() error {
// sessions will be able to flush the committed updates after a // sessions will be able to flush the committed updates after a
// restart. // restart.
for _, session := range c.candidateSessions { for _, session := range c.candidateSessions {
if len(session.CommittedUpdates) > 0 { committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
if err != nil {
returnErr = err
return
}
if len(committedUpdates) > 0 {
c.log.Infof("Starting session=%s to process "+ c.log.Infof("Starting session=%s to process "+
"%d committed backups", session.ID, "%d committed backups", session.ID,
len(session.CommittedUpdates)) len(committedUpdates))
c.initActiveQueue(session)
c.initActiveQueue(session, committedUpdates)
} }
} }
// Now start the session negotiator, which will allow us to // Now start the session negotiator, which will allow us to
// request new session as soon as the backupDispatcher starts // request new session as soon as the backupDispatcher starts
// up. // up.
err = c.negotiator.Start() err := c.negotiator.Start()
if err != nil { if err != nil {
returnErr = err
return return
} }
@ -521,7 +531,7 @@ func (c *TowerClient) Start() error {
c.log.Infof("Watchtower client started successfully") c.log.Infof("Watchtower client started successfully")
}) })
return err return returnErr
} }
// Stop idempotently initiates a graceful shutdown of the watchtower client. // Stop idempotently initiates a graceful shutdown of the watchtower client.
@ -697,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
// active client's advertised policy will be ignored, but may be resumed if the // active client's advertised policy will be ignored, but may be resumed if the
// client is restarted with a matching policy. If no candidates were found, nil // client is restarted with a matching policy. If no candidates were found, nil
// is returned to signal that we need to request a new policy. // is returned to signal that we need to request a new policy.
func (c *TowerClient) nextSessionQueue() *sessionQueue { func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
// Select any candidate session at random, and remove it from the set of // Select any candidate session at random, and remove it from the set of
// candidate sessions. // candidate sessions.
var candidateSession *wtdb.ClientSession var candidateSession *wtdb.ClientSession
@ -719,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
// If none of the sessions could be used or none were found, we'll // If none of the sessions could be used or none were found, we'll
// return nil to signal that we need another session to be negotiated. // return nil to signal that we need another session to be negotiated.
if candidateSession == nil { if candidateSession == nil {
return nil return nil, nil
}
updates, err := c.cfg.DB.FetchSessionCommittedUpdates(
&candidateSession.ID,
)
if err != nil {
return nil, err
} }
// Initialize the session queue and spin it up so it can begin handling // Initialize the session queue and spin it up so it can begin handling
// updates. If the queue was already made active on startup, this will // updates. If the queue was already made active on startup, this will
// simply return the existing session queue from the set. // simply return the existing session queue from the set.
return c.getOrInitActiveQueue(candidateSession) return c.getOrInitActiveQueue(candidateSession, updates), nil
} }
// backupDispatcher processes events coming from the taskPipeline and is // backupDispatcher processes events coming from the taskPipeline and is
@ -798,7 +815,13 @@ func (c *TowerClient) backupDispatcher() {
// We've exhausted the prior session, we'll pop another // We've exhausted the prior session, we'll pop another
// from the remaining sessions and continue processing // from the remaining sessions and continue processing
// backup tasks. // backup tasks.
c.sessionQueue = c.nextSessionQueue() var err error
c.sessionQueue, err = c.nextSessionQueue()
if err != nil {
c.log.Errorf("error fetching next session "+
"queue: %v", err)
}
if c.sessionQueue != nil { if c.sessionQueue != nil {
c.log.Debugf("Loaded next candidate session "+ c.log.Debugf("Loaded next candidate session "+
"queue id=%s", c.sessionQueue.ID()) "queue id=%s", c.sessionQueue.ID())
@ -1046,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the // newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client. // database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{ return newSessionQueue(&sessionQueueConfig{
ClientSession: s, ClientSession: s,
ChainHash: c.cfg.ChainHash, ChainHash: c.cfg.ChainHash,
@ -1058,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
MinBackoff: c.cfg.MinBackoff, MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff, MaxBackoff: c.cfg.MaxBackoff,
Log: c.log, Log: c.log,
}) }, updates)
} }
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned. // passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise a new sessionQueue is initialized and added to the set. // Otherwise a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
if sq, ok := c.activeSessions[s.ID]; ok { if sq, ok := c.activeSessions[s.ID]; ok {
return sq return sq
} }
return c.initActiveQueue(s) return c.initActiveQueue(s, updates)
} }
// initActiveQueue creates a new sessionQueue from the passed ClientSession, // initActiveQueue creates a new sessionQueue from the passed ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue // adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly // so that it can deliver any committed updates or begin accepting newly
// assigned tasks. // assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue { func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue {
// Initialize the session queue, providing it with all of the resources // Initialize the session queue, providing it with all of the resources
// it requires from the client instance. // it requires from the client instance.
sq := c.newSessionQueue(s) sq := c.newSessionQueue(s, updates)
// Add the session queue as an active session so that we remember to // Add the session queue as an active session so that we remember to
// stop it on shutdown. // stop it on shutdown.
@ -1233,13 +1262,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// RegisteredTowers retrieves the list of watchtowers registered with the // RegisteredTowers retrieves the list of watchtowers registered with the
// client. // client.
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) { func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
[]*RegisteredTower, error) {
// Retrieve all of our towers along with all of our sessions. // Retrieve all of our towers along with all of our sessions.
towers, err := c.cfg.DB.ListTowers() towers, err := c.cfg.DB.ListTowers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
clientSessions, err := c.cfg.DB.ListClientSessions(nil) clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -1272,13 +1303,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
} }
// LookupTower retrieves a registered watchtower through its public key. // LookupTower retrieves a registered watchtower through its public key.
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) { func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {
tower, err := c.cfg.DB.LoadTower(pubKey) tower, err := c.cfg.DB.LoadTower(pubKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }
towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID) towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -62,9 +62,14 @@ type DB interface {
// still be able to accept state updates. An optional tower ID can be // still be able to accept state updates. An optional tower ID can be
// used to filter out any client sessions in the response that do not // used to filter out any client sessions in the response that do not
// correspond to this tower. // correspond to this tower.
ListClientSessions(*wtdb.TowerID) ( ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) map[wtdb.SessionID]*wtdb.ClientSession, error)
// FetchSessionCommittedUpdates retrieves the current set of un-acked
// updates of the given session.
FetchSessionCommittedUpdates(id *wtdb.SessionID) (
[]wtdb.CommittedUpdate, error)
// FetchChanSummaries loads a mapping from all registered channels to // FetchChanSummaries loads a mapping from all registered channels to
// their channel summaries. // their channel summaries.
FetchChanSummaries() (wtdb.ChannelSummaries, error) FetchChanSummaries() (wtdb.ChannelSummaries, error)

View File

@ -109,7 +109,9 @@ type sessionQueue struct {
} }
// newSessionQueue intiializes a fresh sessionQueue. // newSessionQueue intiializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue { func newSessionQueue(cfg *sessionQueueConfig,
updates []wtdb.CommittedUpdate) *sessionQueue {
localInit := wtwire.NewInitMessage( localInit := wtwire.NewInitMessage(
lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired), lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
cfg.ChainHash, cfg.ChainHash,
@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
// The database should return them in sorted order, and session queue's // The database should return them in sorted order, and session queue's
// sequence number will be equal to that of the last committed update. // sequence number will be equal to that of the last committed update.
for _, update := range sq.cfg.ClientSession.CommittedUpdates { for _, update := range updates {
sq.commitQueue.PushBack(update) sq.commitQueue.PushBack(update)
} }

View File

@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
return ErrUninitializedDB return ErrUninitializedDB
} }
towerID := TowerIDFromBytes(towerIDBytes) towerID := TowerIDFromBytes(towerIDBytes)
committedUpdateCount := make(map[SessionID]uint16)
perCommittedUpdate := func(s *ClientSession,
_ *CommittedUpdate) {
committedUpdateCount[s.ID]++
}
towerSessions, err := listTowerSessions( towerSessions, err := listTowerSessions(
towerID, sessions, towers, towersToSessionsIndex, towerID, sessions, towers, towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate),
) )
if err != nil { if err != nil {
return err return err
@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
// have any pending updates to ensure we don't load them upon // have any pending updates to ensure we don't load them upon
// restarts. // restarts.
for _, session := range towerSessions { for _, session := range towerSessions {
if len(session.CommittedUpdates) > 0 { if committedUpdateCount[session.ID] > 0 {
return ErrTowerUnackedUpdates return ErrTowerUnackedUpdates
} }
err := markSessionStatus( err := markSessionStatus(
@ -736,8 +745,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// ListClientSessions returns the set of all client sessions known to the db. An // ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) ( func (c *ClientDB) ListClientSessions(id *TowerID,
map[SessionID]*ClientSession, error) { opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
var clientSessions map[SessionID]*ClientSession var clientSessions map[SessionID]*ClientSession
err := kvdb.View(c.db, func(tx kvdb.RTx) error { err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@ -757,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
// known to the db. // known to the db.
if id == nil { if id == nil {
clientSessions, err = listClientAllSessions( clientSessions, err = listClientAllSessions(
sessions, towers, sessions, towers, opts...,
) )
return err return err
} }
@ -769,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
} }
clientSessions, err = listTowerSessions( clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex, *id, sessions, towers, towerToSessionIndex, opts...,
) )
return err return err
}, func() { }, func() {
@ -783,8 +792,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
} }
// listClientAllSessions returns the set of all client sessions known to the db. // listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions, func listClientAllSessions(sessions, towers kvdb.RBucket,
towers kvdb.RBucket) (map[SessionID]*ClientSession, error) { opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
err := sessions.ForEach(func(k, _ []byte) error { err := sessions.ForEach(func(k, _ []byte) error {
@ -792,7 +801,7 @@ func listClientAllSessions(sessions,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessions, towers, k) session, err := getClientSession(sessions, towers, k, opts...)
if err != nil { if err != nil {
return err return err
} }
@ -811,8 +820,8 @@ func listClientAllSessions(sessions,
// listTowerSessions returns the set of all client sessions known to the db // listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id. // that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt, func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
error) { map[SessionID]*ClientSession, error) {
towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes()) towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
if towerIndexBkt == nil { if towerIndexBkt == nil {
@ -825,7 +834,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessionsBkt, towersBkt, k) session, err := getClientSession(
sessionsBkt, towersBkt, k, opts...,
)
if err != nil { if err != nil {
return err return err
} }
@ -840,6 +851,36 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
return clientSessions, nil return clientSessions, nil
} }
// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) (
[]CommittedUpdate, error) {
var committedUpdates []CommittedUpdate
err := kvdb.View(c.db, func(tx kvdb.RTx) error {
sessions := tx.ReadBucket(cSessionBkt)
if sessions == nil {
return ErrUninitializedDB
}
sessionBkt := sessions.NestedReadBucket(id[:])
if sessionBkt == nil {
return ErrClientSessionNotFound
}
var err error
committedUpdates, err = getClientSessionCommits(
sessionBkt, nil, nil,
)
return err
}, func() {})
if err != nil {
return nil, err
}
return committedUpdates, nil
}
// FetchChanSummaries loads a mapping from all registered channels to their // FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries. // channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) { func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
@ -1157,11 +1198,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
return &session, nil return &session, nil
} }
// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
// has not yet received an ACK for).
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)
// ClientSessionListOption describes the signature of a functional option that
// can be used when listing client sessions in order to provide any extra
// instruction to the query.
type ClientSessionListOption func(cfg *ClientSessionListCfg)
// ClientSessionListCfg defines various query parameters that will be used when
// querying the DB for client sessions.
type ClientSessionListCfg struct {
// PerAckedUpdate will, if set, be called for each of the session's
// acked updates.
PerAckedUpdate PerAckedUpdateCB
// PerCommittedUpdate will, if set, be called for each of the session's
// committed (un-acked) updates.
PerCommittedUpdate PerCommittedUpdateCB
}
// NewClientSessionCfg constructs a new ClientSessionListCfg.
func NewClientSessionCfg() *ClientSessionListCfg {
return &ClientSessionListCfg{}
}
// WithPerAckedUpdate constructs a functional option that will set a call-back
// function to be called for each of a client's acked updates.
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerAckedUpdate = cb
}
}
// WithPerCommittedUpdate constructs a functional option that will set a
// call-back function to be called for each of a client's un-acked updates.
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
return func(cfg *ClientSessionListCfg) {
cfg.PerCommittedUpdate = cb
}
}
// getClientSession loads the full ClientSession associated with the serialized // getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower // session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body. // in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket, func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
idBytes []byte) (*ClientSession, error) { opts ...ClientSessionListOption) (*ClientSession, error) {
cfg := NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
session, err := getClientSessionBody(sessions, idBytes) session, err := getClientSessionBody(sessions, idBytes)
if err != nil { if err != nil {
@ -1173,35 +1266,37 @@ func getClientSession(sessions, towers kvdb.RBucket,
if err != nil { if err != nil {
return nil, err return nil, err
} }
// Fetch the committed updates for this session.
commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
if err != nil {
return nil, err
}
// Fetch the acked updates for this session.
ackedUpdates, err := getClientSessionAcks(sessions, idBytes)
if err != nil {
return nil, err
}
session.Tower = tower session.Tower = tower
session.CommittedUpdates = commitedUpdates
session.AckedUpdates = ackedUpdates // Can't fail because client session body has already been read.
sessionBkt := sessions.NestedReadBucket(idBytes)
// Pass the session's committed (un-acked) updates through the call-back
// if one is provided.
err = filterClientSessionCommits(
sessionBkt, session, cfg.PerCommittedUpdate,
)
if err != nil {
return nil, err
}
// Pass the session's acked updates through the call-back if one is
// provided.
err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate)
if err != nil {
return nil, err
}
return session, nil return session, nil
} }
// getClientSessionCommits retrieves all committed updates for the session // getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id. // identified by the serialized session id. If a PerCommittedUpdateCB is
func getClientSessionCommits(sessions kvdb.RBucket, // provided, then it will be called for each of the session's committed updates.
idBytes []byte) ([]CommittedUpdate, error) { func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {
// Can't fail because client session body has already been read. // Initialize committedUpdates so that we can return an initialized map
sessionBkt := sessions.NestedReadBucket(idBytes)
// Initialize commitedUpdates so that we can return an initialized map
// if no committed updates exist. // if no committed updates exist.
committedUpdates := make([]CommittedUpdate, 0) committedUpdates := make([]CommittedUpdate, 0)
@ -1220,6 +1315,10 @@ func getClientSessionCommits(sessions kvdb.RBucket,
committedUpdates = append(committedUpdates, committedUpdate) committedUpdates = append(committedUpdates, committedUpdate)
if cb != nil {
cb(s, &committedUpdate)
}
return nil return nil
}) })
if err != nil { if err != nil {
@ -1229,21 +1328,19 @@ func getClientSessionCommits(sessions kvdb.RBucket,
return committedUpdates, nil return committedUpdates, nil
} }
// getClientSessionAcks retrieves all acked updates for the session identified // filterClientSessionAcks retrieves all acked updates for the session
// by the serialized session id. // identified by the serialized session id and passes them to the provided
func getClientSessionAcks(sessions kvdb.RBucket, // call back if one is provided.
idBytes []byte) (map[uint16]BackupID, error) { func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerAckedUpdateCB) error {
// Can't fail because client session body has already been read. if cb == nil {
sessionBkt := sessions.NestedReadBucket(idBytes) return nil
}
// Initialize ackedUpdates so that we can return an initialized map if
// no acked updates exist.
ackedUpdates := make(map[uint16]BackupID)
sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks) sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
if sessionAcks == nil { if sessionAcks == nil {
return ackedUpdates, nil return nil
} }
err := sessionAcks.ForEach(func(k, v []byte) error { err := sessionAcks.ForEach(func(k, v []byte) error {
@ -1255,15 +1352,47 @@ func getClientSessionAcks(sessions kvdb.RBucket,
return err return err
} }
ackedUpdates[seqNum] = backupID cb(s, seqNum, backupID)
return nil return nil
}) })
if err != nil { if err != nil {
return nil, err return err
} }
return ackedUpdates, nil return nil
}
// filterClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id and passes them to the given
// PerCommittedUpdateCB callback.
func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
cb PerCommittedUpdateCB) error {
if cb == nil {
return nil
}
sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
if sessionCommits == nil {
return nil
}
err := sessionCommits.ForEach(func(k, v []byte) error {
var committedUpdate CommittedUpdate
err := committedUpdate.Decode(bytes.NewReader(v))
if err != nil {
return err
}
committedUpdate.SeqNum = byteOrder.Uint16(k)
cb(s, &committedUpdate)
return nil
})
if err != nil {
return err
}
return nil
} }
// putClientSessionBody stores the body of the ClientSession (everything but the // putClientSessionBody stores the body of the ClientSession (everything but the

View File

@ -48,12 +48,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
require.ErrorIs(h.t, err, expErr) require.ErrorIs(h.t, err, expErr)
} }
func (h *clientDBHarness) listSessions( func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession { opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {
h.t.Helper() h.t.Helper()
sessions, err := h.db.ListClientSessions(id) sessions, err := h.db.ListClientSessions(id, opts...)
require.NoError(h.t, err, "unable to list client sessions") require.NoError(h.t, err, "unable to list client sessions")
return sessions return sessions
@ -207,6 +207,20 @@ func (h *clientDBHarness) newTower() *wtdb.Tower {
}, nil) }, nil)
} }
func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID,
expErr error) []wtdb.CommittedUpdate {
h.t.Helper()
updates, err := h.db.FetchSessionCommittedUpdates(id)
if err != expErr {
h.t.Fatalf("expected fetch session committed updates error: "+
"%v, got: %v", expErr, err)
}
return updates
}
// testCreateClientSession asserts various conditions regarding the creation of // testCreateClientSession asserts various conditions regarding the creation of
// a new ClientSession. The test asserts: // a new ClientSession. The test asserts:
// - client sessions can only be created if a session key index is reserved. // - client sessions can only be created if a session key index is reserved.
@ -506,6 +520,9 @@ func testCommitUpdate(h *clientDBHarness) {
// session, which should fail. // session, which should fail.
update1 := randCommittedUpdate(h.t, 1) update1 := randCommittedUpdate(h.t, 1)
h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound) h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
h.fetchSessionCommittedUpdates(
&session.ID, wtdb.ErrClientSessionNotFound,
)
// Reserve a session key index and insert the session. // Reserve a session key index and insert the session.
session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType) session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
@ -520,11 +537,7 @@ func testCommitUpdate(h *clientDBHarness) {
// Assert that the committed update appears in the client session's // Assert that the committed update appears in the client session's
// CommittedUpdates map when loaded from disk and that there are no // CommittedUpdates map when loaded from disk and that there are no
// AckedUpdates. // AckedUpdates.
dbSession := h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Try to commit the same update, which should succeed due to // Try to commit the same update, which should succeed due to
// idempotency (which is preserved when the breach hint is identical to // idempotency (which is preserved when the breach hint is identical to
@ -534,11 +547,7 @@ func testCommitUpdate(h *clientDBHarness) {
require.Equal(h.t, lastApplied, lastApplied2) require.Equal(h.t, lastApplied, lastApplied2)
// Assert that the loaded ClientSession is the same as before. // Assert that the loaded ClientSession is the same as before.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1,
})
checkAckedUpdates(h.t, dbSession, nil)
// Generate another random update and try to commit it at the identical // Generate another random update and try to commit it at the identical
// sequence number. Since the breach hint has changed, this should fail. // sequence number. Since the breach hint has changed, this should fail.
@ -553,12 +562,10 @@ func testCommitUpdate(h *clientDBHarness) {
// Check that both updates now appear as committed on the ClientSession // Check that both updates now appear as committed on the ClientSession
// loaded from disk. // loaded from disk.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
}) }, nil)
checkAckedUpdates(h.t, dbSession, nil)
// Finally, create one more random update and try to commit it at index // Finally, create one more random update and try to commit it at index
// 4, which should be rejected since 3 is the next slot the database // 4, which should be rejected since 3 is the next slot the database
@ -567,12 +574,20 @@ func testCommitUpdate(h *clientDBHarness) {
h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate) h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)
// Assert that the ClientSession loaded from disk remains unchanged. // Assert that the ClientSession loaded from disk remains unchanged.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
*update1, *update1,
*update2, *update2,
}) }, nil)
checkAckedUpdates(h.t, dbSession, nil) }
func perAckedUpdate(updates map[uint16]wtdb.BackupID) func(
_ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) {
return func(_ *wtdb.ClientSession, seq uint16,
id wtdb.BackupID) {
updates[seq] = id
}
} }
// testAckUpdate asserts the behavior of AckUpdate. // testAckUpdate asserts the behavior of AckUpdate.
@ -628,9 +643,7 @@ func testAckUpdate(h *clientDBHarness) {
// Assert that the ClientSession loaded from disk has one update in it's // Assert that the ClientSession loaded from disk has one update in it's
// AckedUpdates map, and that the committed update has been removed. // AckedUpdates map, and that the committed update has been removed.
dbSession := h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
}) })
@ -645,9 +658,7 @@ func testAckUpdate(h *clientDBHarness) {
h.ackUpdate(&session.ID, 2, 2, nil) h.ackUpdate(&session.ID, 2, 2, nil)
// Assert that both updates exist as AckedUpdates when loaded from disk. // Assert that both updates exist as AckedUpdates when loaded from disk.
dbSession = h.listSessions(nil)[session.ID] h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
checkCommittedUpdates(h.t, dbSession, nil)
checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
1: update1.BackupID, 1: update1.BackupID,
2: update2.BackupID, 2: update2.BackupID,
}) })
@ -663,9 +674,22 @@ func testAckUpdate(h *clientDBHarness) {
h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied) h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
} }
func (h *clientDBHarness) assertUpdates(id wtdb.SessionID,
expectedPending []wtdb.CommittedUpdate,
expectedAcked map[uint16]wtdb.BackupID) {
ackedUpdates := make(map[uint16]wtdb.BackupID)
_ = h.listSessions(
nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)),
)
committedUpates := h.fetchSessionCommittedUpdates(&id, nil)
checkCommittedUpdates(h.t, committedUpates, expectedPending)
checkAckedUpdates(h.t, ackedUpdates, expectedAcked)
}
// checkCommittedUpdates asserts that the CommittedUpdates on session match the // checkCommittedUpdates asserts that the CommittedUpdates on session match the
// expUpdates provided. // expUpdates provided.
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession, func checkCommittedUpdates(t *testing.T, actualUpdates,
expUpdates []wtdb.CommittedUpdate) { expUpdates []wtdb.CommittedUpdate) {
t.Helper() t.Helper()
@ -677,12 +701,12 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make([]wtdb.CommittedUpdate, 0) expUpdates = make([]wtdb.CommittedUpdate, 0)
} }
require.Equal(t, expUpdates, session.CommittedUpdates) require.Equal(t, expUpdates, actualUpdates)
} }
// checkAckedUpdates asserts that the AckedUpdates on a session match the // checkAckedUpdates asserts that the AckedUpdates on a session match the
// expUpdates provided. // expUpdates provided.
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession, func checkAckedUpdates(t *testing.T, actualUpdates,
expUpdates map[uint16]wtdb.BackupID) { expUpdates map[uint16]wtdb.BackupID) {
// We promote nil expUpdates to an initialized map since the database // We promote nil expUpdates to an initialized map since the database
@ -692,7 +716,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
expUpdates = make(map[uint16]wtdb.BackupID) expUpdates = make(map[uint16]wtdb.BackupID)
} }
require.Equal(t, expUpdates, session.AckedUpdates) require.Equal(t, expUpdates, actualUpdates)
} }
// TestClientDB asserts the behavior of a fresh client db, a reopened client db, // TestClientDB asserts the behavior of a fresh client db, a reopened client db,

View File

@ -37,23 +37,6 @@ type ClientSession struct {
ClientSessionBody ClientSessionBody
// CommittedUpdates is a sorted list of unacked updates. These updates
// can be resent after a restart if the updates failed to send or
// receive an acknowledgment.
//
// NOTE: This list is serialized in it's own bucket, separate from the
// body of the ClientSession. The representation on disk is a key value
// map from sequence number to CommittedUpdateBody to allow efficient
// insertion and retrieval.
CommittedUpdates []CommittedUpdate
// AckedUpdates is a map from sequence number to backup id to record
// which revoked states were uploaded via this session.
//
// NOTE: This map is serialized in it's own bucket, separate from the
// body of the ClientSession.
AckedUpdates map[uint16]BackupID
// Tower holds the pubkey and address of the watchtower. // Tower holds the pubkey and address of the watchtower.
// //
// NOTE: This value is not serialized. It is recovered by looking up the // NOTE: This value is not serialized. It is recovered by looking up the

View File

@ -26,6 +26,8 @@ type ClientDB struct {
mu sync.Mutex mu sync.Mutex
summaries map[lnwire.ChannelID]wtdb.ClientChanSummary summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
activeSessions map[wtdb.SessionID]wtdb.ClientSession activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower towers map[wtdb.TowerID]*wtdb.Tower
@ -39,6 +41,8 @@ func NewClientDB() *ClientDB {
return &ClientDB{ return &ClientDB{
summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary), summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession), activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
towerIndex: make(map[towerPK]wtdb.TowerID), towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower), towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32), indexes: make(map[keyIndexKey]uint32),
@ -75,7 +79,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
} else { } else {
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1)) towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
tower = &wtdb.Tower{ tower = &wtdb.Tower{
ID: wtdb.TowerID(towerID), ID: towerID,
IdentityKey: lnAddr.IdentityKey, IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address}, Addresses: []net.Addr{lnAddr.Address},
} }
@ -129,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
} }
for id, session := range towerSessions { for id, session := range towerSessions {
if len(session.CommittedUpdates) > 0 { if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates return wtdb.ErrTowerUnackedUpdates
} }
session.Status = wtdb.CSessionInactive session.Status = wtdb.CSessionInactive
@ -193,26 +197,33 @@ func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
// MarkBackupIneligible records that particular commit height is ineligible for // MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt // backup. This allows the client to track which updates it should not attempt
// to retry after startup. // to retry after startup.
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error { func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
return nil return nil
} }
// ListClientSessions returns the set of all client sessions known to the db. An // ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions( func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
return m.listClientSessions(tower) return m.listClientSessions(tower, opts...)
} }
// listClientSessions returns the set of all client sessions known to the db. An // listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the // optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower. // response that do not correspond to this tower.
func (m *ClientDB) listClientSessions( func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) { opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {
cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}
sessions := make(map[wtdb.SessionID]*wtdb.ClientSession) sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions { for _, session := range m.activeSessions {
@ -222,11 +233,40 @@ func (m *ClientDB) listClientSessions(
} }
session.Tower = m.towers[session.TowerID] session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session sessions[session.ID] = &session
if cfg.PerAckedUpdate != nil {
for seq, id := range m.ackedUpdates[session.ID] {
cfg.PerAckedUpdate(&session, seq, id)
}
}
if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}
} }
return sessions, nil return sessions, nil
} }
// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
[]wtdb.CommittedUpdate, error) {
m.mu.Lock()
defer m.mu.Unlock()
updates, ok := m.committedUpdates[*id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}
return updates, nil
}
// CreateClientSession records a newly negotiated client session in the set of // CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID. // active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error { func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
@ -271,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
Policy: session.Policy, Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript), RewardPkScript: cloneBytes(session.RewardPkScript),
}, },
CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
AckedUpdates: make(map[uint16]wtdb.BackupID),
} }
m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)
return nil return nil
} }
@ -334,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
} }
// Check if an update has already been committed for this state. // Check if an update has already been committed for this state.
for _, dbUpdate := range session.CommittedUpdates { for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum { if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the // If the breach hint matches, we'll just return the
// last applied value so the client can retransmit. // last applied value so the client can retransmit.
@ -353,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
} }
// Save the update and increment the sequence number. // Save the update and increment the sequence number.
session.CommittedUpdates = append(session.CommittedUpdates, *update) m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++ session.SeqNum++
m.activeSessions[*id] = session m.activeSessions[*id] = session
@ -363,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This // AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the // removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower. // lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error { func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {
m.mu.Lock() m.mu.Lock()
defer m.mu.Unlock() defer m.mu.Unlock()
@ -387,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// Retrieve the committed update, failing if none is found. We should // Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send. // only receive acks for state updates that we send.
updates := session.CommittedUpdates updates := m.committedUpdates[session.ID]
for i, update := range updates { for i, update := range updates {
if update.SeqNum != seqNum { if update.SeqNum != seqNum {
continue continue
@ -398,9 +442,9 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
// along with the next update. // along with the next update.
copy(updates[:i], updates[i+1:]) copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{} updates[len(updates)-1] = wtdb.CommittedUpdate{}
session.CommittedUpdates = updates[:len(updates)-1] m.committedUpdates[session.ID] = updates[:len(updates)-1]
session.AckedUpdates[seqNum] = update.BackupID m.ackedUpdates[*id][seqNum] = update.BackupID
session.TowerLastApplied = lastApplied session.TowerLastApplied = lastApplied
m.activeSessions[*id] = session m.activeSessions[*id] = session