Merge pull request #6928 from ellemouton/wtclientMemPerf
multi: remove AckedUpdates & CommittedUpdates from ClientSession struct
Commit: d55f861107
@@ -111,6 +111,10 @@ crash](https://github.com/lightningnetwork/lnd/pull/7019).
  closer coupling of Towers and Sessions and ensures that a session cannot be
  added if the tower it is referring to does not exist.

* [Remove `AckedUpdates` & `CommittedUpdates` from the `ClientSession`
  struct](https://github.com/lightningnetwork/lnd/pull/6928) in order to
  improve the performance of fetching a `ClientSession` from the DB.

* [Create a helper function to wait for peer to come
  online](https://github.com/lightningnetwork/lnd/pull/6931).
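The gist of the change is that per-update data is no longer materialized on `ClientSession` when sessions are loaded; callers that need it register callbacks which the DB invokes while iterating its buckets. A minimal, self-contained sketch of that callback style (toy types only, not lnd's actual API) shows why a count can be produced without ever building the full `AckedUpdates` map in memory:

```go
package main

import "fmt"

// Simplified stand-ins for the wtdb types; names are illustrative only.
type SessionID [4]byte
type BackupID struct{ CommitHeight uint64 }

// PerAckedUpdateCB mirrors the callback style introduced by the PR: the DB
// layer streams each acked update to the caller instead of returning a map.
type PerAckedUpdateCB func(id SessionID, seq uint16, b BackupID)

// forEachAckedUpdate plays the role of the DB iteration. In the real DB the
// updates live in their own bucket, so they can be visited without being
// accumulated on the ClientSession struct.
func forEachAckedUpdate(updates map[SessionID][]BackupID, cb PerAckedUpdateCB) {
	for id, ups := range updates {
		for seq, b := range ups {
			cb(id, uint16(seq), b)
		}
	}
}

func main() {
	db := map[SessionID][]BackupID{
		{1}: {{CommitHeight: 10}, {CommitHeight: 11}},
		{2}: {{CommitHeight: 7}},
	}

	// Only a count per session is needed (as in the ListTowers RPC), so we
	// never build the full per-session update map in memory.
	counts := make(map[SessionID]uint16)
	forEachAckedUpdate(db, func(id SessionID, _ uint16, _ BackupID) {
		counts[id]++
	})
	fmt.Println(counts)
}
```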
@@ -265,12 +265,16 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,
        return nil, err
    }

    anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers()
    opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
        req.IncludeSessions,
    )

    anchorTowers, err := c.cfg.AnchorClient.RegisteredTowers(opts...)
    if err != nil {
        return nil, err
    }

    legacyTowers, err := c.cfg.Client.RegisteredTowers()
    legacyTowers, err := c.cfg.Client.RegisteredTowers(opts...)
    if err != nil {
        return nil, err
    }
@@ -286,7 +290,10 @@ func (c *WatchtowerClient) ListTowers(ctx context.Context,

    rpcTowers := make([]*Tower, 0, len(towers))
    for _, tower := range towers {
        rpcTower := marshallTower(tower, req.IncludeSessions)
        rpcTower := marshallTower(
            tower, req.IncludeSessions, ackCounts,
            committedUpdateCounts,
        )
        rpcTowers = append(rpcTowers, rpcTower)
    }

@@ -306,16 +313,59 @@ func (c *WatchtowerClient) GetTowerInfo(ctx context.Context,
        return nil, err
    }

    opts, ackCounts, committedUpdateCounts := constructFunctionalOptions(
        req.IncludeSessions,
    )

    var tower *wtclient.RegisteredTower
    tower, err = c.cfg.Client.LookupTower(pubKey)
    tower, err = c.cfg.Client.LookupTower(pubKey, opts...)
    if err == wtdb.ErrTowerNotFound {
        tower, err = c.cfg.AnchorClient.LookupTower(pubKey)
        tower, err = c.cfg.AnchorClient.LookupTower(pubKey, opts...)
    }
    if err != nil {
        return nil, err
    }

    return marshallTower(tower, req.IncludeSessions), nil
    return marshallTower(
        tower, req.IncludeSessions, ackCounts, committedUpdateCounts,
    ), nil
}

// constructFunctionalOptions is a helper function that constructs a list of
// functional options to be used when fetching a tower from the DB. It also
// returns a map of acked-update counts and one for un-acked-update counts that
// will be populated once the DB call has been made.
func constructFunctionalOptions(includeSessions bool) (
    []wtdb.ClientSessionListOption, map[wtdb.SessionID]uint16,
    map[wtdb.SessionID]uint16) {

    var (
        opts                  []wtdb.ClientSessionListOption
        ackCounts             = make(map[wtdb.SessionID]uint16)
        committedUpdateCounts = make(map[wtdb.SessionID]uint16)
    )
    if !includeSessions {
        return opts, ackCounts, committedUpdateCounts
    }

    perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
        _ wtdb.BackupID) {

        ackCounts[s.ID]++
    }

    perCommittedUpdate := func(s *wtdb.ClientSession,
        _ *wtdb.CommittedUpdate) {

        committedUpdateCounts[s.ID]++
    }

    opts = []wtdb.ClientSessionListOption{
        wtdb.WithPerAckedUpdate(perAckedUpdate),
        wtdb.WithPerCommittedUpdate(perCommittedUpdate),
    }

    return opts, ackCounts, committedUpdateCounts
}
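Note that the two count maps returned by `constructFunctionalOptions` start out empty and are only filled while the subsequent `RegisteredTowers` or `LookupTower` call iterates the sessions, because the counting happens inside the registered callbacks. A small stand-alone sketch of that flow, using toy names rather than the wtdb types:

```go
package main

import "fmt"

type listOption func(cfg *listCfg)

type listCfg struct {
	perItem func(item string)
}

// withPerItem mirrors wtdb.WithPerAckedUpdate: it registers a callback that
// the "DB" invokes for every item it visits.
func withPerItem(cb func(string)) listOption {
	return func(cfg *listCfg) { cfg.perItem = cb }
}

// list plays the role of ListClientSessions: it applies the options and then
// streams items through the callback while iterating.
func list(items []string, opts ...listOption) {
	cfg := &listCfg{}
	for _, o := range opts {
		o(cfg)
	}
	for _, it := range items {
		if cfg.perItem != nil {
			cfg.perItem(it)
		}
	}
}

func main() {
	counts := make(map[string]int)
	opt := withPerItem(func(it string) { counts[it]++ })

	fmt.Println(len(counts)) // 0: nothing counted until the listing call runs.
	list([]string{"a", "b", "a"}, opt)
	fmt.Println(counts) // map[a:2 b:1]
}
```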

// Stats returns the in-memory statistics of the client since startup.
@@ -387,7 +437,9 @@ func (c *WatchtowerClient) Policy(ctx context.Context,

// marshallTower converts a client registered watchtower into its corresponding
// RPC type.
func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower {
func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool,
    ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower {

    rpcAddrs := make([]string, 0, len(tower.Addresses))
    for _, addr := range tower.Addresses {
        rpcAddrs = append(rpcAddrs, addr.String())
@@ -399,8 +451,8 @@ func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool) *Tower
    for _, session := range tower.Sessions {
        satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000
        rpcSessions = append(rpcSessions, &TowerSession{
            NumBackups:        uint32(len(session.AckedUpdates)),
            NumPendingBackups: uint32(len(session.CommittedUpdates)),
            NumBackups:        uint32(ackCounts[session.ID]),
            NumPendingBackups: uint32(pendingCounts[session.ID]),
            MaxBackups:        uint32(session.Policy.MaxUpdates),
            SweepSatPerVbyte:  uint32(satPerVByte),
@@ -83,10 +83,12 @@ type Client interface {

    // RegisteredTowers retrieves the list of watchtowers registered with
    // the client.
    RegisteredTowers() ([]*RegisteredTower, error)
    RegisteredTowers(...wtdb.ClientSessionListOption) ([]*RegisteredTower,
        error)

    // LookupTower retrieves a registered watchtower through its public key.
    LookupTower(*btcec.PublicKey) (*RegisteredTower, error)
    LookupTower(*btcec.PublicKey,
        ...wtdb.ClientSessionListOption) (*RegisteredTower, error)

    // Stats returns the in-memory statistics of the client since startup.
    Stats() ClientStats
@@ -287,12 +289,67 @@ func New(config *Config) (*TowerClient, error) {
    }
    plog := build.NewPrefixLog(prefix, log)

    // Next, load all candidate towers and sessions from the database into
    // the client. We will use any of these sessions if their policies match
    // the current policy of the client, otherwise they will be ignored and
    // new sessions will be requested.
    // Load the sweep pkscripts that have been generated for all previously
    // registered channels.
    chanSummaries, err := cfg.DB.FetchChanSummaries()
    if err != nil {
        return nil, err
    }

    c := &TowerClient{
        cfg:               cfg,
        log:               plog,
        pipeline:          newTaskPipeline(plog),
        chanCommitHeights: make(map[lnwire.ChannelID]uint64),
        activeSessions:    make(sessionQueueSet),
        summaries:         chanSummaries,
        statTicker:        time.NewTicker(DefaultStatInterval),
        stats:             new(ClientStats),
        newTowers:         make(chan *newTowerMsg),
        staleTowers:       make(chan *staleTowerMsg),
        forceQuit:         make(chan struct{}),
    }

    // perUpdate is a callback function that will be used to inspect the
    // full set of candidate client sessions loaded from disk, and to
    // determine the highest known commit height for each channel. This
    // allows the client to reject backups that it has already processed for
    // its active policy.
    perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) {
        // We only want to consider accepted updates that have been
        // accepted under an identical policy to the client's current
        // policy.
        if policy != c.cfg.Policy {
            return
        }

        // Take the highest commit height found in the session's acked
        // updates.
        height, ok := c.chanCommitHeights[id.ChanID]
        if !ok || id.CommitHeight > height {
            c.chanCommitHeights[id.ChanID] = id.CommitHeight
        }
    }

    perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
        id wtdb.BackupID) {

        perUpdate(s.Policy, id)
    }

    perCommittedUpdate := func(s *wtdb.ClientSession,
        u *wtdb.CommittedUpdate) {

        perUpdate(s.Policy, u.BackupID)
    }

    // Load all candidate sessions and towers from the database into the
    // client. We will use any of these sessions if their policies match the
    // current policy of the client, otherwise they will be ignored and new
    // sessions will be requested.
    isAnchorClient := cfg.Policy.IsAnchorChannel()
    activeSessionFilter := genActiveSessionFilter(isAnchorClient)

    candidateTowers := newTowerListIterator()
    perActiveTower := func(tower *wtdb.Tower) {
        // If the tower has already been marked as active, then there is
@@ -307,34 +364,19 @@ func New(config *Config) (*TowerClient, error) {
        // Add the tower to the set of candidate towers.
        candidateTowers.AddCandidate(tower)
    }

    candidateSessions, err := getTowerAndSessionCandidates(
        cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
        wtdb.WithPerAckedUpdate(perAckedUpdate),
        wtdb.WithPerCommittedUpdate(perCommittedUpdate),
    )
    if err != nil {
        return nil, err
    }

    // Load the sweep pkscripts that have been generated for all previously
    // registered channels.
    chanSummaries, err := cfg.DB.FetchChanSummaries()
    if err != nil {
        return nil, err
    }
    c.candidateTowers = candidateTowers
    c.candidateSessions = candidateSessions

    c := &TowerClient{
        cfg:               cfg,
        log:               plog,
        pipeline:          newTaskPipeline(plog),
        candidateTowers:   candidateTowers,
        candidateSessions: candidateSessions,
        activeSessions:    make(sessionQueueSet),
        summaries:         chanSummaries,
        statTicker:        time.NewTicker(DefaultStatInterval),
        stats:             new(ClientStats),
        newTowers:         make(chan *newTowerMsg),
        staleTowers:       make(chan *staleTowerMsg),
        forceQuit:         make(chan struct{}),
    }
    c.negotiator = newSessionNegotiator(&NegotiatorConfig{
        DB:            cfg.DB,
        SecretKeyRing: cfg.SecretKeyRing,
@@ -349,10 +391,6 @@ func New(config *Config) (*TowerClient, error) {
        Log: plog,
    })

    // Reconstruct the highest commit height processed for each channel
    // under the client's current policy.
    c.buildHighestCommitHeights()

    return c, nil
}
@@ -363,7 +401,8 @@ func New(config *Config) (*TowerClient, error) {
// tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
    sessionFilter func(*wtdb.ClientSession) bool,
    perActiveTower func(tower *wtdb.Tower)) (
    perActiveTower func(tower *wtdb.Tower),
    opts ...wtdb.ClientSessionListOption) (
    map[wtdb.SessionID]*wtdb.ClientSession, error) {

    towers, err := db.ListTowers()
@@ -373,7 +412,7 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,

    candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
    for _, tower := range towers {
        sessions, err := db.ListClientSessions(&tower.ID)
        sessions, err := db.ListClientSessions(&tower.ID, opts...)
        if err != nil {
            return nil, err
        }
@@ -413,10 +452,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
    passesFilter func(*wtdb.ClientSession) bool) (
    passesFilter func(*wtdb.ClientSession) bool,
    opts ...wtdb.ClientSessionListOption) (
    map[wtdb.SessionID]*wtdb.ClientSession, error) {

    sessions, err := db.ListClientSessions(forTower)
    sessions, err := db.ListClientSessions(forTower, opts...)
    if err != nil {
        return nil, err
    }
@@ -446,48 +486,10 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
    return sessions, nil
}

// buildHighestCommitHeights inspects the full set of candidate client sessions
// loaded from disk, and determines the highest known commit height for each
// channel. This allows the client to reject backups that it has already
// processed for its active policy.
func (c *TowerClient) buildHighestCommitHeights() {
    chanCommitHeights := make(map[lnwire.ChannelID]uint64)
    for _, s := range c.candidateSessions {
        // We only want to consider accepted updates that have been
        // accepted under an identical policy to the client's current
        // policy.
        if s.Policy != c.cfg.Policy {
            continue
        }

        // Take the highest commit height found in the session's
        // committed updates.
        for _, committedUpdate := range s.CommittedUpdates {
            bid := committedUpdate.BackupID

            height, ok := chanCommitHeights[bid.ChanID]
            if !ok || bid.CommitHeight > height {
                chanCommitHeights[bid.ChanID] = bid.CommitHeight
            }
        }

        // Take the highest commit height found in the session's acked
        // updates.
        for _, bid := range s.AckedUpdates {
            height, ok := chanCommitHeights[bid.ChanID]
            if !ok || bid.CommitHeight > height {
                chanCommitHeights[bid.ChanID] = bid.CommitHeight
            }
        }
    }

    c.chanCommitHeights = chanCommitHeights
}

// Start initializes the watchtower client by loading or negotiating an active
// session and then begins processing backup tasks from the request pipeline.
func (c *TowerClient) Start() error {
    var err error
    var returnErr error
    c.started.Do(func() {
        c.log.Infof("Watchtower client starting")

@@ -496,19 +498,27 @@ func (c *TowerClient) Start() error {
        // sessions will be able to flush the committed updates after a
        // restart.
        for _, session := range c.candidateSessions {
            if len(session.CommittedUpdates) > 0 {
            committedUpdates, err := c.cfg.DB.FetchSessionCommittedUpdates(&session.ID)
            if err != nil {
                returnErr = err
                return
            }

            if len(committedUpdates) > 0 {
                c.log.Infof("Starting session=%s to process "+
                    "%d committed backups", session.ID,
                    len(session.CommittedUpdates))
                c.initActiveQueue(session)
                    len(committedUpdates))

                c.initActiveQueue(session, committedUpdates)
            }
        }

        // Now start the session negotiator, which will allow us to
        // request new sessions as soon as the backupDispatcher starts
        // up.
        err = c.negotiator.Start()
        err := c.negotiator.Start()
        if err != nil {
            returnErr = err
            return
        }

@@ -521,7 +531,7 @@ func (c *TowerClient) Start() error {

        c.log.Infof("Watchtower client started successfully")
    })
    return err
    return returnErr
}

// Stop idempotently initiates a graceful shutdown of the watchtower client.
@@ -697,7 +707,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
// active client's advertised policy will be ignored, but may be resumed if the
// client is restarted with a matching policy. If no candidates were found, nil
// is returned to signal that we need to request a new policy.
func (c *TowerClient) nextSessionQueue() *sessionQueue {
func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
    // Select any candidate session at random, and remove it from the set of
    // candidate sessions.
    var candidateSession *wtdb.ClientSession
@@ -719,13 +729,20 @@ func (c *TowerClient) nextSessionQueue() *sessionQueue {
    // If none of the sessions could be used or none were found, we'll
    // return nil to signal that we need another session to be negotiated.
    if candidateSession == nil {
        return nil
        return nil, nil
    }

    updates, err := c.cfg.DB.FetchSessionCommittedUpdates(
        &candidateSession.ID,
    )
    if err != nil {
        return nil, err
    }

    // Initialize the session queue and spin it up so it can begin handling
    // updates. If the queue was already made active on startup, this will
    // simply return the existing session queue from the set.
    return c.getOrInitActiveQueue(candidateSession)
    return c.getOrInitActiveQueue(candidateSession, updates), nil
}

// backupDispatcher processes events coming from the taskPipeline and is
@@ -798,7 +815,13 @@ func (c *TowerClient) backupDispatcher() {
            // We've exhausted the prior session, we'll pop another
            // from the remaining sessions and continue processing
            // backup tasks.
            c.sessionQueue = c.nextSessionQueue()
            var err error
            c.sessionQueue, err = c.nextSessionQueue()
            if err != nil {
                c.log.Errorf("error fetching next session "+
                    "queue: %v", err)
            }

            if c.sessionQueue != nil {
                c.log.Debugf("Loaded next candidate session "+
                    "queue id=%s", c.sessionQueue.ID())
@@ -1046,7 +1069,9 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error

// newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
    updates []wtdb.CommittedUpdate) *sessionQueue {

    return newSessionQueue(&sessionQueueConfig{
        ClientSession: s,
        ChainHash:     c.cfg.ChainHash,
@@ -1058,28 +1083,32 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession) *sessionQueue {
        MinBackoff:    c.cfg.MinBackoff,
        MaxBackoff:    c.cfg.MaxBackoff,
        Log:           c.log,
    })
    }, updates)
}

// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
    updates []wtdb.CommittedUpdate) *sessionQueue {

    if sq, ok := c.activeSessions[s.ID]; ok {
        return sq
    }

    return c.initActiveQueue(s)
    return c.initActiveQueue(s, updates)
}

// initActiveQueue creates a new sessionQueue from the passed ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly
// assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession) *sessionQueue {
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession,
    updates []wtdb.CommittedUpdate) *sessionQueue {

    // Initialize the session queue, providing it with all of the resources
    // it requires from the client instance.
    sq := c.newSessionQueue(s)
    sq := c.newSessionQueue(s, updates)

    // Add the session queue as an active session so that we remember to
    // stop it on shutdown.
@@ -1233,13 +1262,15 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {

// RegisteredTowers retrieves the list of watchtowers registered with the
// client.
func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
func (c *TowerClient) RegisteredTowers(opts ...wtdb.ClientSessionListOption) (
    []*RegisteredTower, error) {

    // Retrieve all of our towers along with all of our sessions.
    towers, err := c.cfg.DB.ListTowers()
    if err != nil {
        return nil, err
    }
    clientSessions, err := c.cfg.DB.ListClientSessions(nil)
    clientSessions, err := c.cfg.DB.ListClientSessions(nil, opts...)
    if err != nil {
        return nil, err
    }
@@ -1272,13 +1303,15 @@ func (c *TowerClient) RegisteredTowers() ([]*RegisteredTower, error) {
}

// LookupTower retrieves a registered watchtower through its public key.
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey) (*RegisteredTower, error) {
func (c *TowerClient) LookupTower(pubKey *btcec.PublicKey,
    opts ...wtdb.ClientSessionListOption) (*RegisteredTower, error) {

    tower, err := c.cfg.DB.LoadTower(pubKey)
    if err != nil {
        return nil, err
    }

    towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
    towerSessions, err := c.cfg.DB.ListClientSessions(&tower.ID, opts...)
    if err != nil {
        return nil, err
    }
@@ -62,9 +62,14 @@ type DB interface {
    // still be able to accept state updates. An optional tower ID can be
    // used to filter out any client sessions in the response that do not
    // correspond to this tower.
    ListClientSessions(*wtdb.TowerID) (
    ListClientSessions(*wtdb.TowerID, ...wtdb.ClientSessionListOption) (
        map[wtdb.SessionID]*wtdb.ClientSession, error)

    // FetchSessionCommittedUpdates retrieves the current set of un-acked
    // updates of the given session.
    FetchSessionCommittedUpdates(id *wtdb.SessionID) (
        []wtdb.CommittedUpdate, error)

    // FetchChanSummaries loads a mapping from all registered channels to
    // their channel summaries.
    FetchChanSummaries() (wtdb.ChannelSummaries, error)
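Taken together, the two methods split what used to be a single heavy read: `ListClientSessions` stays cheap, and `FetchSessionCommittedUpdates` is called only for sessions whose committed updates actually need to be flushed, which is what `TowerClient.Start` does above. A rough stand-alone sketch of that flow with toy types (not the real wtclient/wtdb interfaces):

```go
package main

import "fmt"

// Toy stand-ins for the wtdb types used by the interface above.
type SessionID uint32
type CommittedUpdate struct{ SeqNum uint16 }

// DB captures just the two methods discussed here.
type DB interface {
	ListClientSessions() ([]SessionID, error)
	FetchSessionCommittedUpdates(id SessionID) ([]CommittedUpdate, error)
}

type memDB struct {
	committed map[SessionID][]CommittedUpdate
}

func (m *memDB) ListClientSessions() ([]SessionID, error) {
	ids := make([]SessionID, 0, len(m.committed))
	for id := range m.committed {
		ids = append(ids, id)
	}
	return ids, nil
}

func (m *memDB) FetchSessionCommittedUpdates(id SessionID) ([]CommittedUpdate, error) {
	return m.committed[id], nil
}

func main() {
	var db DB = &memDB{committed: map[SessionID][]CommittedUpdate{
		1: {{SeqNum: 1}, {SeqNum: 2}},
		2: {},
	}}

	ids, _ := db.ListClientSessions()
	for _, id := range ids {
		// Committed updates are no longer attached to the session, so
		// they are fetched on demand, and only non-empty sets lead to
		// a session queue being "started".
		updates, _ := db.FetchSessionCommittedUpdates(id)
		if len(updates) > 0 {
			fmt.Printf("session %d: flush %d committed updates\n",
				id, len(updates))
		}
	}
}
```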
@@ -109,7 +109,9 @@ type sessionQueue struct {
}

// newSessionQueue initializes a fresh sessionQueue.
func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {
func newSessionQueue(cfg *sessionQueueConfig,
    updates []wtdb.CommittedUpdate) *sessionQueue {

    localInit := wtwire.NewInitMessage(
        lnwire.NewRawFeatureVector(wtwire.AltruistSessionsRequired),
        cfg.ChainHash,
@@ -137,7 +139,7 @@ func newSessionQueue(cfg *sessionQueueConfig) *sessionQueue {

    // The database should return them in sorted order, and the session queue's
    // sequence number will be equal to that of the last committed update.
    for _, update := range sq.cfg.ClientSession.CommittedUpdates {
    for _, update := range updates {
        sq.commitQueue.PushBack(update)
    }
@@ -420,8 +420,17 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
        return ErrUninitializedDB
    }
    towerID := TowerIDFromBytes(towerIDBytes)

    committedUpdateCount := make(map[SessionID]uint16)
    perCommittedUpdate := func(s *ClientSession,
        _ *CommittedUpdate) {

        committedUpdateCount[s.ID]++
    }

    towerSessions, err := listTowerSessions(
        towerID, sessions, towers, towersToSessionsIndex,
        WithPerCommittedUpdate(perCommittedUpdate),
    )
    if err != nil {
        return err
@@ -447,7 +456,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
    // have any pending updates to ensure we don't load them upon
    // restarts.
    for _, session := range towerSessions {
        if len(session.CommittedUpdates) > 0 {
        if committedUpdateCount[session.ID] > 0 {
            return ErrTowerUnackedUpdates
        }
        err := markSessionStatus(
@@ -736,8 +745,8 @@ func getSessionKeyIndex(keyIndexes kvdb.RwBucket, towerID TowerID,
// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (c *ClientDB) ListClientSessions(id *TowerID) (
    map[SessionID]*ClientSession, error) {
func (c *ClientDB) ListClientSessions(id *TowerID,
    opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {

    var clientSessions map[SessionID]*ClientSession
    err := kvdb.View(c.db, func(tx kvdb.RTx) error {
@@ -757,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
        // known to the db.
        if id == nil {
            clientSessions, err = listClientAllSessions(
                sessions, towers,
                sessions, towers, opts...,
            )
            return err
        }
@@ -769,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
        }

        clientSessions, err = listTowerSessions(
            *id, sessions, towers, towerToSessionIndex,
            *id, sessions, towers, towerToSessionIndex, opts...,
        )
        return err
    }, func() {
@@ -783,8 +792,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID) (
}

// listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions,
    towers kvdb.RBucket) (map[SessionID]*ClientSession, error) {
func listClientAllSessions(sessions, towers kvdb.RBucket,
    opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {

    clientSessions := make(map[SessionID]*ClientSession)
    err := sessions.ForEach(func(k, _ []byte) error {
@@ -792,7 +801,7 @@ func listClientAllSessions(sessions,
        // the CommittedUpdates and AckedUpdates on startup to resume
        // committed updates and compute the highest known commit height
        // for each channel.
        session, err := getClientSession(sessions, towers, k)
        session, err := getClientSession(sessions, towers, k, opts...)
        if err != nil {
            return err
        }
@@ -811,8 +820,8 @@ func listClientAllSessions(sessions,
// listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
    towerToSessionIndex kvdb.RBucket) (map[SessionID]*ClientSession,
    error) {
    towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
    map[SessionID]*ClientSession, error) {

    towerIndexBkt := towerToSessionIndex.NestedReadBucket(id.Bytes())
    if towerIndexBkt == nil {
@@ -825,7 +834,9 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
        // the CommittedUpdates and AckedUpdates on startup to resume
        // committed updates and compute the highest known commit height
        // for each channel.
        session, err := getClientSession(sessionsBkt, towersBkt, k)
        session, err := getClientSession(
            sessionsBkt, towersBkt, k, opts...,
        )
        if err != nil {
            return err
        }
@@ -840,6 +851,36 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
    return clientSessions, nil
}
// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) (
    []CommittedUpdate, error) {

    var committedUpdates []CommittedUpdate
    err := kvdb.View(c.db, func(tx kvdb.RTx) error {
        sessions := tx.ReadBucket(cSessionBkt)
        if sessions == nil {
            return ErrUninitializedDB
        }

        sessionBkt := sessions.NestedReadBucket(id[:])
        if sessionBkt == nil {
            return ErrClientSessionNotFound
        }

        var err error
        committedUpdates, err = getClientSessionCommits(
            sessionBkt, nil, nil,
        )
        return err
    }, func() {})
    if err != nil {
        return nil, err
    }

    return committedUpdates, nil
}

// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries.
func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
@@ -1157,11 +1198,63 @@ func getClientSessionBody(sessions kvdb.RBucket,
    return &session, nil
}

// PerAckedUpdateCB describes the signature of a callback function that can be
// called for each of a session's acked updates.
type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)

// PerCommittedUpdateCB describes the signature of a callback function that can
// be called for each of a session's committed updates (updates that the client
// has not yet received an ACK for).
type PerCommittedUpdateCB func(*ClientSession, *CommittedUpdate)

// ClientSessionListOption describes the signature of a functional option that
// can be used when listing client sessions in order to provide any extra
// instruction to the query.
type ClientSessionListOption func(cfg *ClientSessionListCfg)

// ClientSessionListCfg defines various query parameters that will be used when
// querying the DB for client sessions.
type ClientSessionListCfg struct {
    // PerAckedUpdate will, if set, be called for each of the session's
    // acked updates.
    PerAckedUpdate PerAckedUpdateCB

    // PerCommittedUpdate will, if set, be called for each of the session's
    // committed (un-acked) updates.
    PerCommittedUpdate PerCommittedUpdateCB
}

// NewClientSessionCfg constructs a new ClientSessionListCfg.
func NewClientSessionCfg() *ClientSessionListCfg {
    return &ClientSessionListCfg{}
}

// WithPerAckedUpdate constructs a functional option that will set a call-back
// function to be called for each of a client's acked updates.
func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
    return func(cfg *ClientSessionListCfg) {
        cfg.PerAckedUpdate = cb
    }
}

// WithPerCommittedUpdate constructs a functional option that will set a
// call-back function to be called for each of a client's un-acked updates.
func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
    return func(cfg *ClientSessionListCfg) {
        cfg.PerCommittedUpdate = cb
    }
}
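A caller that still needs the old per-session view can rebuild it by combining both options in a single `ListClientSessions` call, which is essentially what the updated test harness later in this diff does. A hedged sketch of that pattern against the wtdb API shown here (assumes an already-open `*wtdb.ClientDB`; error handling and session filtering are trimmed):

```go
package example

import (
	"fmt"

	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

// dumpSessionUpdates rebuilds the per-session acked/committed views that used
// to live on ClientSession directly, using the new callback options instead.
func dumpSessionUpdates(db *wtdb.ClientDB) error {
	acked := make(map[wtdb.SessionID]map[uint16]wtdb.BackupID)
	committed := make(map[wtdb.SessionID][]wtdb.CommittedUpdate)

	_, err := db.ListClientSessions(
		nil,
		wtdb.WithPerAckedUpdate(func(s *wtdb.ClientSession,
			seq uint16, id wtdb.BackupID) {

			if acked[s.ID] == nil {
				acked[s.ID] = make(map[uint16]wtdb.BackupID)
			}
			acked[s.ID][seq] = id
		}),
		wtdb.WithPerCommittedUpdate(func(s *wtdb.ClientSession,
			u *wtdb.CommittedUpdate) {

			committed[s.ID] = append(committed[s.ID], *u)
		}),
	)
	if err != nil {
		return err
	}

	for id, ups := range acked {
		fmt.Printf("session %s: %d acked updates\n", id, len(ups))
	}
	for id, ups := range committed {
		fmt.Printf("session %s: %d committed updates\n", id, len(ups))
	}
	return nil
}
```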
// getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckedUpdates and
// Tower in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket,
    idBytes []byte) (*ClientSession, error) {
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
    opts ...ClientSessionListOption) (*ClientSession, error) {

    cfg := NewClientSessionCfg()
    for _, o := range opts {
        o(cfg)
    }

    session, err := getClientSessionBody(sessions, idBytes)
    if err != nil {
@@ -1173,35 +1266,37 @@ func getClientSession(sessions, towers kvdb.RBucket,
    if err != nil {
        return nil, err
    }

    // Fetch the committed updates for this session.
    commitedUpdates, err := getClientSessionCommits(sessions, idBytes)
    if err != nil {
        return nil, err
    }

    // Fetch the acked updates for this session.
    ackedUpdates, err := getClientSessionAcks(sessions, idBytes)
    if err != nil {
        return nil, err
    }

    session.Tower = tower
    session.CommittedUpdates = commitedUpdates
    session.AckedUpdates = ackedUpdates

    // Can't fail because client session body has already been read.
    sessionBkt := sessions.NestedReadBucket(idBytes)

    // Pass the session's committed (un-acked) updates through the call-back
    // if one is provided.
    err = filterClientSessionCommits(
        sessionBkt, session, cfg.PerCommittedUpdate,
    )
    if err != nil {
        return nil, err
    }

    // Pass the session's acked updates through the call-back if one is
    // provided.
    err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate)
    if err != nil {
        return nil, err
    }

    return session, nil
}

// getClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id.
func getClientSessionCommits(sessions kvdb.RBucket,
    idBytes []byte) ([]CommittedUpdate, error) {
// identified by the serialized session id. If a PerCommittedUpdateCB is
// provided, then it will be called for each of the session's committed updates.
func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
    cb PerCommittedUpdateCB) ([]CommittedUpdate, error) {

    // Can't fail because client session body has already been read.
    sessionBkt := sessions.NestedReadBucket(idBytes)

    // Initialize commitedUpdates so that we can return an initialized map
    // Initialize committedUpdates so that we can return an initialized map
    // if no committed updates exist.
    committedUpdates := make([]CommittedUpdate, 0)

@@ -1220,6 +1315,10 @@ func getClientSessionCommits(sessions kvdb.RBucket,

        committedUpdates = append(committedUpdates, committedUpdate)

        if cb != nil {
            cb(s, &committedUpdate)
        }

        return nil
    })
    if err != nil {
@@ -1229,21 +1328,19 @@ func getClientSessionCommits(sessions kvdb.RBucket,
    return committedUpdates, nil
}

// getClientSessionAcks retrieves all acked updates for the session identified
// by the serialized session id.
func getClientSessionAcks(sessions kvdb.RBucket,
    idBytes []byte) (map[uint16]BackupID, error) {
// filterClientSessionAcks retrieves all acked updates for the session
// identified by the serialized session id and passes them to the provided
// call-back, if one is provided.
func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
    cb PerAckedUpdateCB) error {

    // Can't fail because client session body has already been read.
    sessionBkt := sessions.NestedReadBucket(idBytes)

    // Initialize ackedUpdates so that we can return an initialized map if
    // no acked updates exist.
    ackedUpdates := make(map[uint16]BackupID)
    if cb == nil {
        return nil
    }

    sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
    if sessionAcks == nil {
        return ackedUpdates, nil
        return nil
    }

    err := sessionAcks.ForEach(func(k, v []byte) error {
@@ -1255,15 +1352,47 @@ func getClientSessionAcks(sessions kvdb.RBucket,
            return err
        }

        ackedUpdates[seqNum] = backupID

        cb(s, seqNum, backupID)
        return nil
    })
    if err != nil {
        return nil, err
        return err
    }

    return ackedUpdates, nil
    return nil
}

// filterClientSessionCommits retrieves all committed updates for the session
// identified by the serialized session id and passes them to the given
// PerCommittedUpdateCB callback.
func filterClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
    cb PerCommittedUpdateCB) error {

    if cb == nil {
        return nil
    }

    sessionCommits := sessionBkt.NestedReadBucket(cSessionCommits)
    if sessionCommits == nil {
        return nil
    }

    err := sessionCommits.ForEach(func(k, v []byte) error {
        var committedUpdate CommittedUpdate
        err := committedUpdate.Decode(bytes.NewReader(v))
        if err != nil {
            return err
        }
        committedUpdate.SeqNum = byteOrder.Uint16(k)

        cb(s, &committedUpdate)
        return nil
    })
    if err != nil {
        return err
    }

    return nil
}

// putClientSessionBody stores the body of the ClientSession (everything but the
@@ -48,12 +48,12 @@ func (h *clientDBHarness) insertSession(session *wtdb.ClientSession,
    require.ErrorIs(h.t, err, expErr)
}

func (h *clientDBHarness) listSessions(
    id *wtdb.TowerID) map[wtdb.SessionID]*wtdb.ClientSession {
func (h *clientDBHarness) listSessions(id *wtdb.TowerID,
    opts ...wtdb.ClientSessionListOption) map[wtdb.SessionID]*wtdb.ClientSession {

    h.t.Helper()

    sessions, err := h.db.ListClientSessions(id)
    sessions, err := h.db.ListClientSessions(id, opts...)
    require.NoError(h.t, err, "unable to list client sessions")

    return sessions
@@ -207,6 +207,20 @@ func (h *clientDBHarness) newTower() *wtdb.Tower {
    }, nil)
}

func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID,
    expErr error) []wtdb.CommittedUpdate {

    h.t.Helper()

    updates, err := h.db.FetchSessionCommittedUpdates(id)
    if err != expErr {
        h.t.Fatalf("expected fetch session committed updates error: "+
            "%v, got: %v", expErr, err)
    }

    return updates
}

// testCreateClientSession asserts various conditions regarding the creation of
// a new ClientSession. The test asserts:
// - client sessions can only be created if a session key index is reserved.
@@ -506,6 +520,9 @@ func testCommitUpdate(h *clientDBHarness) {
    // session, which should fail.
    update1 := randCommittedUpdate(h.t, 1)
    h.commitUpdate(&session.ID, update1, wtdb.ErrClientSessionNotFound)
    h.fetchSessionCommittedUpdates(
        &session.ID, wtdb.ErrClientSessionNotFound,
    )

    // Reserve a session key index and insert the session.
    session.KeyIndex = h.nextKeyIndex(session.TowerID, blobType)
@@ -520,11 +537,7 @@ func testCommitUpdate(h *clientDBHarness) {
    // Assert that the committed update appears in the client session's
    // CommittedUpdates map when loaded from disk and that there are no
    // AckedUpdates.
    dbSession := h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
        *update1,
    })
    checkAckedUpdates(h.t, dbSession, nil)
    h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)

    // Try to commit the same update, which should succeed due to
    // idempotency (which is preserved when the breach hint is identical to
@@ -534,11 +547,7 @@ func testCommitUpdate(h *clientDBHarness) {
    require.Equal(h.t, lastApplied, lastApplied2)

    // Assert that the loaded ClientSession is the same as before.
    dbSession = h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
        *update1,
    })
    checkAckedUpdates(h.t, dbSession, nil)
    h.assertUpdates(session.ID, []wtdb.CommittedUpdate{*update1}, nil)

    // Generate another random update and try to commit it at the identical
    // sequence number. Since the breach hint has changed, this should fail.
@@ -553,12 +562,10 @@ func testCommitUpdate(h *clientDBHarness) {

    // Check that both updates now appear as committed on the ClientSession
    // loaded from disk.
    dbSession = h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
    h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
        *update1,
        *update2,
    })
    checkAckedUpdates(h.t, dbSession, nil)
    }, nil)

    // Finally, create one more random update and try to commit it at index
    // 4, which should be rejected since 3 is the next slot the database
@@ -567,12 +574,20 @@ func testCommitUpdate(h *clientDBHarness) {
    h.commitUpdate(&session.ID, update4, wtdb.ErrCommitUnorderedUpdate)

    // Assert that the ClientSession loaded from disk remains unchanged.
    dbSession = h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, []wtdb.CommittedUpdate{
    h.assertUpdates(session.ID, []wtdb.CommittedUpdate{
        *update1,
        *update2,
    })
    checkAckedUpdates(h.t, dbSession, nil)
    }, nil)
}

func perAckedUpdate(updates map[uint16]wtdb.BackupID) func(
    _ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) {

    return func(_ *wtdb.ClientSession, seq uint16,
        id wtdb.BackupID) {

        updates[seq] = id
    }
}

// testAckUpdate asserts the behavior of AckUpdate.
@@ -628,9 +643,7 @@ func testAckUpdate(h *clientDBHarness) {

    // Assert that the ClientSession loaded from disk has one update in its
    // AckedUpdates map, and that the committed update has been removed.
    dbSession := h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, nil)
    checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
    h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
        1: update1.BackupID,
    })

@@ -645,9 +658,7 @@ func testAckUpdate(h *clientDBHarness) {
    h.ackUpdate(&session.ID, 2, 2, nil)

    // Assert that both updates exist as AckedUpdates when loaded from disk.
    dbSession = h.listSessions(nil)[session.ID]
    checkCommittedUpdates(h.t, dbSession, nil)
    checkAckedUpdates(h.t, dbSession, map[uint16]wtdb.BackupID{
    h.assertUpdates(session.ID, nil, map[uint16]wtdb.BackupID{
        1: update1.BackupID,
        2: update2.BackupID,
    })
@@ -663,9 +674,22 @@ func testAckUpdate(h *clientDBHarness) {
    h.ackUpdate(&session.ID, 4, 3, wtdb.ErrUnallocatedLastApplied)
}

func (h *clientDBHarness) assertUpdates(id wtdb.SessionID,
    expectedPending []wtdb.CommittedUpdate,
    expectedAcked map[uint16]wtdb.BackupID) {

    ackedUpdates := make(map[uint16]wtdb.BackupID)
    _ = h.listSessions(
        nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)),
    )
    committedUpates := h.fetchSessionCommittedUpdates(&id, nil)
    checkCommittedUpdates(h.t, committedUpates, expectedPending)
    checkAckedUpdates(h.t, ackedUpdates, expectedAcked)
}

// checkCommittedUpdates asserts that the CommittedUpdates on session match the
// expUpdates provided.
func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
func checkCommittedUpdates(t *testing.T, actualUpdates,
    expUpdates []wtdb.CommittedUpdate) {

    t.Helper()
@@ -677,12 +701,12 @@ func checkCommittedUpdates(t *testing.T, session *wtdb.ClientSession,
        expUpdates = make([]wtdb.CommittedUpdate, 0)
    }

    require.Equal(t, expUpdates, session.CommittedUpdates)
    require.Equal(t, expUpdates, actualUpdates)
}

// checkAckedUpdates asserts that the AckedUpdates on a session match the
// expUpdates provided.
func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
func checkAckedUpdates(t *testing.T, actualUpdates,
    expUpdates map[uint16]wtdb.BackupID) {

    // We promote nil expUpdates to an initialized map since the database
@@ -692,7 +716,7 @@ func checkAckedUpdates(t *testing.T, session *wtdb.ClientSession,
        expUpdates = make(map[uint16]wtdb.BackupID)
    }

    require.Equal(t, expUpdates, session.AckedUpdates)
    require.Equal(t, expUpdates, actualUpdates)
}

// TestClientDB asserts the behavior of a fresh client db, a reopened client db,
@@ -37,23 +37,6 @@ type ClientSession struct {

    ClientSessionBody

    // CommittedUpdates is a sorted list of unacked updates. These updates
    // can be resent after a restart if the updates failed to send or
    // receive an acknowledgment.
    //
    // NOTE: This list is serialized in its own bucket, separate from the
    // body of the ClientSession. The representation on disk is a key value
    // map from sequence number to CommittedUpdateBody to allow efficient
    // insertion and retrieval.
    CommittedUpdates []CommittedUpdate

    // AckedUpdates is a map from sequence number to backup id to record
    // which revoked states were uploaded via this session.
    //
    // NOTE: This map is serialized in its own bucket, separate from the
    // body of the ClientSession.
    AckedUpdates map[uint16]BackupID

    // Tower holds the pubkey and address of the watchtower.
    //
    // NOTE: This value is not serialized. It is recovered by looking up the
@@ -23,11 +23,13 @@ type keyIndexKey struct {
type ClientDB struct {
    nextTowerID uint64 // to be used atomically

    mu             sync.Mutex
    summaries      map[lnwire.ChannelID]wtdb.ClientChanSummary
    activeSessions map[wtdb.SessionID]wtdb.ClientSession
    towerIndex     map[towerPK]wtdb.TowerID
    towers         map[wtdb.TowerID]*wtdb.Tower
    mu               sync.Mutex
    summaries        map[lnwire.ChannelID]wtdb.ClientChanSummary
    activeSessions   map[wtdb.SessionID]wtdb.ClientSession
    ackedUpdates     map[wtdb.SessionID]map[uint16]wtdb.BackupID
    committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
    towerIndex       map[towerPK]wtdb.TowerID
    towers           map[wtdb.TowerID]*wtdb.Tower

    nextIndex uint32
    indexes   map[keyIndexKey]uint32
@@ -37,12 +39,14 @@ type ClientDB struct {
// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
    return &ClientDB{
        summaries:      make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
        activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
        towerIndex:     make(map[towerPK]wtdb.TowerID),
        towers:         make(map[wtdb.TowerID]*wtdb.Tower),
        indexes:        make(map[keyIndexKey]uint32),
        legacyIndexes:  make(map[wtdb.TowerID]uint32),
        summaries:        make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
        activeSessions:   make(map[wtdb.SessionID]wtdb.ClientSession),
        ackedUpdates:     make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
        committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
        towerIndex:       make(map[towerPK]wtdb.TowerID),
        towers:           make(map[wtdb.TowerID]*wtdb.Tower),
        indexes:          make(map[keyIndexKey]uint32),
        legacyIndexes:    make(map[wtdb.TowerID]uint32),
    }
}

@@ -75,7 +79,7 @@ func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
    } else {
        towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
        tower = &wtdb.Tower{
            ID:          wtdb.TowerID(towerID),
            ID:          towerID,
            IdentityKey: lnAddr.IdentityKey,
            Addresses:   []net.Addr{lnAddr.Address},
        }
@@ -129,7 +133,7 @@ func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
    }

    for id, session := range towerSessions {
        if len(session.CommittedUpdates) > 0 {
        if len(m.committedUpdates[session.ID]) > 0 {
            return wtdb.ErrTowerUnackedUpdates
        }
        session.Status = wtdb.CSessionInactive
@@ -193,26 +197,33 @@ func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
// MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt
// to retry after startup.
func (m *ClientDB) MarkBackupIneligible(chanID lnwire.ChannelID, commitHeight uint64) error {
func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
    return nil
}

// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(
    tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
    opts ...wtdb.ClientSessionListOption) (
    map[wtdb.SessionID]*wtdb.ClientSession, error) {

    m.mu.Lock()
    defer m.mu.Unlock()
    return m.listClientSessions(tower)
    return m.listClientSessions(tower, opts...)
}

// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(
    tower *wtdb.TowerID) (map[wtdb.SessionID]*wtdb.ClientSession, error) {
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
    opts ...wtdb.ClientSessionListOption) (
    map[wtdb.SessionID]*wtdb.ClientSession, error) {

    cfg := wtdb.NewClientSessionCfg()
    for _, o := range opts {
        o(cfg)
    }

    sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
    for _, session := range m.activeSessions {
@@ -222,11 +233,40 @@ func (m *ClientDB) listClientSessions(
        }
        session.Tower = m.towers[session.TowerID]
        sessions[session.ID] = &session

        if cfg.PerAckedUpdate != nil {
            for seq, id := range m.ackedUpdates[session.ID] {
                cfg.PerAckedUpdate(&session, seq, id)
            }
        }

        if cfg.PerCommittedUpdate != nil {
            for _, update := range m.committedUpdates[session.ID] {
                update := update
                cfg.PerCommittedUpdate(&session, &update)
            }
        }
    }

    return sessions, nil
}

// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
    []wtdb.CommittedUpdate, error) {

    m.mu.Lock()
    defer m.mu.Unlock()

    updates, ok := m.committedUpdates[*id]
    if !ok {
        return nil, wtdb.ErrClientSessionNotFound
    }

    return updates, nil
}
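The mock follows the real DB by moving update bookkeeping off the session value and into maps keyed by session ID, so acking an update is a move from one map to the other. A stripped-down, self-contained illustration of that bookkeeping (toy types, not the wtmock API):

```go
package main

import "fmt"

type sessionID uint32
type backupID struct{ commitHeight uint64 }
type committedUpdate struct {
	seqNum uint16
	backup backupID
}

// store keeps per-session update state in side maps, the way the updated mock
// does, rather than on a session struct.
type store struct {
	committed map[sessionID][]committedUpdate
	acked     map[sessionID]map[uint16]backupID
}

func newStore() *store {
	return &store{
		committed: make(map[sessionID][]committedUpdate),
		acked:     make(map[sessionID]map[uint16]backupID),
	}
}

func (s *store) commit(id sessionID, u committedUpdate) {
	s.committed[id] = append(s.committed[id], u)
}

// ack removes the committed update with the given sequence number and records
// it as acked, mirroring AckUpdate's move from one map to the other.
func (s *store) ack(id sessionID, seq uint16) {
	updates := s.committed[id]
	for i, u := range updates {
		if u.seqNum != seq {
			continue
		}
		s.committed[id] = append(updates[:i], updates[i+1:]...)
		if s.acked[id] == nil {
			s.acked[id] = make(map[uint16]backupID)
		}
		s.acked[id][seq] = u.backup
		return
	}
}

func main() {
	s := newStore()
	s.commit(1, committedUpdate{seqNum: 1, backup: backupID{10}})
	s.commit(1, committedUpdate{seqNum: 2, backup: backupID{11}})
	s.ack(1, 1)
	fmt.Println(len(s.committed[1]), len(s.acked[1])) // 1 1
}
```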
// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
@@ -271,9 +311,9 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
            Policy:         session.Policy,
            RewardPkScript: cloneBytes(session.RewardPkScript),
        },
        CommittedUpdates: make([]wtdb.CommittedUpdate, 0),
        AckedUpdates:     make(map[uint16]wtdb.BackupID),
    }
    m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
    m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)

    return nil
}
@@ -334,7 +374,7 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
    }

    // Check if an update has already been committed for this state.
    for _, dbUpdate := range session.CommittedUpdates {
    for _, dbUpdate := range m.committedUpdates[session.ID] {
        if dbUpdate.SeqNum == update.SeqNum {
            // If the breach hint matches, we'll just return the
            // last applied value so the client can retransmit.
@@ -353,7 +393,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
    }

    // Save the update and increment the sequence number.
    session.CommittedUpdates = append(session.CommittedUpdates, *update)
    m.committedUpdates[session.ID] = append(
        m.committedUpdates[session.ID], *update,
    )
    session.SeqNum++
    m.activeSessions[*id] = session

@@ -363,7 +405,9 @@ func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) error {
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
    lastApplied uint16) error {

    m.mu.Lock()
    defer m.mu.Unlock()

@@ -387,7 +431,7 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err

    // Retrieve the committed update, failing if none is found. We should
    // only receive acks for state updates that we send.
    updates := session.CommittedUpdates
    updates := m.committedUpdates[session.ID]
    for i, update := range updates {
        if update.SeqNum != seqNum {
            continue
@@ -398,9 +442,9 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum, lastApplied uint16) err
        // along with the next update.
        copy(updates[:i], updates[i+1:])
        updates[len(updates)-1] = wtdb.CommittedUpdate{}
        session.CommittedUpdates = updates[:len(updates)-1]
        m.committedUpdates[session.ID] = updates[:len(updates)-1]

        session.AckedUpdates[seqNum] = update.BackupID
        m.ackedUpdates[*id][seqNum] = update.BackupID
        session.TowerLastApplied = lastApplied

        m.activeSessions[*id] = session