Mirror of https://github.com/lightningnetwork/lnd.git (synced 2025-06-28 17:53:30 +02:00)
multi: migrate towers to use RangeIndex for AckedUpdates
In this commit, a migration is added that takes the AckedUpdates of all sessions, stores them using the RangeIndex pattern instead, and deletes each session's old AckedUpdates bucket. All the logic in the code is also updated to write to and read from this new structure.
parent 50ad10666c
commit c3a2368f46
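For context on the data layout this commit moves to, below is a minimal, self-contained sketch of the idea behind the range-index storage. It is not lnd's actual wtdb.RangeIndex API; the names and the merging logic here are illustrative only (for example, this toy version does not merge two ranges that become adjacent, which the real implementation handles). The point it shows is that per-channel acked commit heights collapse into start -> end ranges, so a session that acked heights 0..999 needs one entry instead of a thousand.

package main

import "fmt"

// addHeight records a single acked commit height h in a start->end range map,
// merging it into an existing range when it directly extends or is already
// covered by one.
func addHeight(ranges map[uint64]uint64, h uint64) {
	for start, end := range ranges {
		// Already covered by an existing range.
		if h >= start && h <= end {
			return
		}
		// Extends an existing range upwards by one.
		if h == end+1 {
			ranges[start] = h
			return
		}
		// Extends an existing range downwards by one.
		if start > 0 && h == start-1 {
			delete(ranges, start)
			ranges[h] = end
			return
		}
	}

	// Otherwise start a new single-height range.
	ranges[h] = h
}

func main() {
	ranges := make(map[uint64]uint64)
	for h := uint64(0); h < 1000; h++ {
		addHeight(ranges, h)
	}
	addHeight(ranges, 2000)

	// Prints two ranges: 0->999 and 2000->2000.
	fmt.Println(ranges)
}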
@@ -341,27 +341,27 @@ func constructFunctionalOptions(includeSessions bool) (

 	var (
 		opts []wtdb.ClientSessionListOption
-		ackCounts = make(map[wtdb.SessionID]uint16)
 		committedUpdateCounts = make(map[wtdb.SessionID]uint16)
+		ackCounts = make(map[wtdb.SessionID]uint16)
 	)
 	if !includeSessions {
 		return opts, ackCounts, committedUpdateCounts
 	}

-	perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
-		_ wtdb.BackupID) {
+	perNumAckedUpdates := func(s *wtdb.ClientSession, id lnwire.ChannelID,
+		numUpdates uint16) {

-		ackCounts[s.ID]++
+		ackCounts[s.ID] += numUpdates
 	}

 	perCommittedUpdate := func(s *wtdb.ClientSession,
-		_ *wtdb.CommittedUpdate) {
+		u *wtdb.CommittedUpdate) {

 		committedUpdateCounts[s.ID]++
 	}

 	opts = []wtdb.ClientSessionListOption{
-		wtdb.WithPerAckedUpdate(perAckedUpdate),
+		wtdb.WithPerNumAckedUpdates(perNumAckedUpdates),
 		wtdb.WithPerCommittedUpdate(perCommittedUpdate),
 	}

@@ -438,7 +438,8 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
 // marshallTower converts a client registered watchtower into its corresponding
 // RPC type.
 func marshallTower(tower *wtclient.RegisteredTower, includeSessions bool,
-	ackCounts, pendingCounts map[wtdb.SessionID]uint16) *Tower {
+	ackCounts map[wtdb.SessionID]uint16,
+	pendingCounts map[wtdb.SessionID]uint16) *Tower {

 	rpcAddrs := make([]string, 0, len(tower.Addresses))
 	for _, addr := range tower.Addresses {
@@ -4,6 +4,8 @@ import (
 	"github.com/btcsuite/btclog"
 	"github.com/lightningnetwork/lnd/build"
 	"github.com/lightningnetwork/lnd/watchtower/lookout"
+	"github.com/lightningnetwork/lnd/watchtower/wtclient"
+	"github.com/lightningnetwork/lnd/watchtower/wtdb"
 	"github.com/lightningnetwork/lnd/watchtower/wtserver"
 )

@@ -30,4 +32,6 @@ func UseLogger(logger btclog.Logger) {
 	log = logger
 	lookout.UseLogger(logger)
 	wtserver.UseLogger(logger)
+	wtclient.UseLogger(logger)
+	wtdb.UseLogger(logger)
 }
@@ -314,7 +314,9 @@ func New(config *Config) (*TowerClient, error) {
 	// determine the highest known commit height for each channel. This
 	// allows the client to reject backups that it has already processed for
 	// its active policy.
-	perUpdate := func(policy wtpolicy.Policy, id wtdb.BackupID) {
+	perUpdate := func(policy wtpolicy.Policy, chanID lnwire.ChannelID,
+		commitHeight uint64) {
+
 		// We only want to consider accepted updates that have been
 		// accepted under an identical policy to the client's current
 		// policy.
@@ -324,22 +326,22 @@ func New(config *Config) (*TowerClient, error) {

 		// Take the highest commit height found in the session's acked
 		// updates.
-		height, ok := c.chanCommitHeights[id.ChanID]
-		if !ok || id.CommitHeight > height {
-			c.chanCommitHeights[id.ChanID] = id.CommitHeight
+		height, ok := c.chanCommitHeights[chanID]
+		if !ok || commitHeight > height {
+			c.chanCommitHeights[chanID] = commitHeight
 		}
 	}

-	perAckedUpdate := func(s *wtdb.ClientSession, _ uint16,
-		id wtdb.BackupID) {
+	perMaxHeight := func(s *wtdb.ClientSession, chanID lnwire.ChannelID,
+		height uint64) {

-		perUpdate(s.Policy, id)
+		perUpdate(s.Policy, chanID, height)
 	}

 	perCommittedUpdate := func(s *wtdb.ClientSession,
 		u *wtdb.CommittedUpdate) {

-		perUpdate(s.Policy, u.BackupID)
+		perUpdate(s.Policy, u.BackupID.ChanID, u.BackupID.CommitHeight)
 	}

 	// Load all candidate sessions and towers from the database into the
@@ -366,7 +368,7 @@ func New(config *Config) (*TowerClient, error) {

 	candidateSessions, err := getTowerAndSessionCandidates(
 		cfg.DB, cfg.SecretKeyRing, activeSessionFilter, perActiveTower,
-		wtdb.WithPerAckedUpdate(perAckedUpdate),
+		wtdb.WithPerMaxHeight(perMaxHeight),
 		wtdb.WithPerCommittedUpdate(perCommittedUpdate),
 	)
 	if err != nil {
@@ -68,6 +68,14 @@ type DB interface {
 	FetchSessionCommittedUpdates(id *wtdb.SessionID) (
 		[]wtdb.CommittedUpdate, error)

+	// IsAcked returns true if the given backup has been backed up using
+	// the given session.
+	IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool, error)
+
+	// NumAckedUpdates returns the number of backups that have been
+	// successfully backed up using the given session.
+	NumAckedUpdates(id *wtdb.SessionID) (uint64, error)
+
 	// FetchChanSummaries loads a mapping from all registered channels to
 	// their channel summaries.
 	FetchChanSummaries() (wtdb.ChannelSummaries, error)
@@ -36,7 +36,7 @@ var (
 	// cSessionBkt is a top-level bucket storing:
 	//   session-id => cSessionBody -> encoded ClientSessionBody
 	//              => cSessionCommits => seqnum -> encoded CommittedUpdate
-	//              => cSessionAcks => seqnum -> encoded BackupID
+	//              => cSessionAckRangeIndex => db-chan-id => start -> end
 	cSessionBkt = []byte("client-session-bucket")

 	// cSessionBody is a sub-bucket of cSessionBkt storing only the body of
@@ -47,9 +47,9 @@ var (
 	//   seqnum -> encoded CommittedUpdate.
 	cSessionCommits = []byte("client-session-commits")

-	// cSessionAcks is a sub-bucket of cSessionBkt storing:
-	//   seqnum -> encoded BackupID.
-	cSessionAcks = []byte("client-session-acks")
+	// cSessionAckRangeIndex is a sub-bucket of cSessionBkt storing
+	//   chan-id => start -> end
+	cSessionAckRangeIndex = []byte("client-session-ack-range-index")

 	// cChanIDIndexBkt is a top-level bucket storing:
 	//   db-assigned-id -> channel-ID
@@ -422,6 +422,11 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
 			return ErrUninitializedDB
 		}

+		chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
+		if chanIDIndexBkt == nil {
+			return ErrUninitializedDB
+		}
+
 		// Don't return an error if the watchtower doesn't exist to act
 		// as a NOP.
 		pubKeyBytes := pubKey.SerializeCompressed()
@@ -463,7 +468,8 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
 		}

 		towerSessions, err := c.listTowerSessions(
-			towerID, sessions, towersToSessionsIndex,
+			towerID, sessions, chanIDIndexBkt,
+			towersToSessionsIndex,
 			WithPerCommittedUpdate(perCommittedUpdate),
 		)
 		if err != nil {
@@ -763,6 +769,149 @@ func readRangeIndex(rangesBkt kvdb.RBucket) (*RangeIndex, error) {
 	return NewRangeIndex(ranges, WithSerializeUint64Fn(writeBigSize))
 }

+// getRangeIndex checks the ClientDB's in-memory range index map to see if it
+// has an entry for the given session and channel ID. If it does, this is
+// returned, otherwise the range index is loaded from the DB. An optional db
+// transaction parameter may be provided. If one is provided then it will be
+// used to query the DB for the range index, otherwise, a new transaction will
+// be created and used.
+func (c *ClientDB) getRangeIndex(tx kvdb.RTx, sID SessionID,
+	chanID lnwire.ChannelID) (*RangeIndex, error) {
+
+	c.ackedRangeIndexMu.Lock()
+	defer c.ackedRangeIndexMu.Unlock()
+
+	if _, ok := c.ackedRangeIndex[sID]; !ok {
+		c.ackedRangeIndex[sID] = make(map[lnwire.ChannelID]*RangeIndex)
+	}
+
+	// If the in-memory range-index map already includes an entry for this
+	// session ID and channel ID pair, then return it.
+	if index, ok := c.ackedRangeIndex[sID][chanID]; ok {
+		return index, nil
+	}
+
+	// readRangeIndexFromBkt is a helper that is used to read in a
+	// RangeIndex structure from the passed in bucket and store it in the
+	// ackedRangeIndex map.
+	readRangeIndexFromBkt := func(rangesBkt kvdb.RBucket) (*RangeIndex,
+		error) {
+
+		// Create a new in-memory RangeIndex by reading in ranges from
+		// the DB.
+		rangeIndex, err := readRangeIndex(rangesBkt)
+		if err != nil {
+			return nil, err
+		}
+
+		c.ackedRangeIndex[sID][chanID] = rangeIndex
+
+		return rangeIndex, nil
+	}
+
+	// If a DB transaction is provided then use it to fetch the ranges
+	// bucket from the DB.
+	if tx != nil {
+		rangesBkt, err := getRangesReadBucket(tx, sID, chanID)
+		if err != nil {
+			return nil, err
+		}
+
+		return readRangeIndexFromBkt(rangesBkt)
+	}
+
+	// No DB transaction was provided. So create and use a new one.
+	var index *RangeIndex
+	err := kvdb.View(c.db, func(tx kvdb.RTx) error {
+		rangesBkt, err := getRangesReadBucket(tx, sID, chanID)
+		if err != nil {
+			return err
+		}

+		index, err = readRangeIndexFromBkt(rangesBkt)
+
+		return err
+	}, func() {})
+	if err != nil {
+		return nil, err
+	}
+
+	return index, nil
+}
+
+// getRangesReadBucket gets the range index bucket where the range index for the
+// given session-channel pair is stored. If any sub-buckets along the way do not
+// exist, then an error is returned. If the sub-buckets should be created
+// instead, then use getRangesWriteBucket.
+func getRangesReadBucket(tx kvdb.RTx, sID SessionID, chanID lnwire.ChannelID) (
+	kvdb.RBucket, error) {
+
+	sessions := tx.ReadBucket(cSessionBkt)
+	if sessions == nil {
+		return nil, ErrUninitializedDB
+	}
+
+	chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
+	if chanDetailsBkt == nil {
+		return nil, ErrUninitializedDB
+	}
+
+	sessionBkt := sessions.NestedReadBucket(sID[:])
+	if sessionsBkt == nil {
+		return nil, ErrNoRangeIndexFound
+	}
+
+	// Get the DB representation of the channel-ID.
+	_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
+	if err != nil {
+		return nil, err
+	}
+
+	sessionAckRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
+	if sessionAckRanges == nil {
+		return nil, ErrNoRangeIndexFound
+	}
+
+	return sessionAckRanges.NestedReadBucket(dbChanIDBytes), nil
+}
+
+// getRangesWriteBucket gets the range index bucket where the range index for
+// the given session-channel pair is stored. If any sub-buckets along the way do
+// not exist, then they are created.
+func getRangesWriteBucket(tx kvdb.RwTx, sID SessionID,
+	chanID lnwire.ChannelID) (kvdb.RwBucket, error) {
+
+	sessions := tx.ReadWriteBucket(cSessionBkt)
+	if sessions == nil {
+		return nil, ErrUninitializedDB
+	}
+
+	chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
+	if chanDetailsBkt == nil {
+		return nil, ErrUninitializedDB
+	}
+
+	sessionBkt, err := sessions.CreateBucketIfNotExists(sID[:])
+	if err != nil {
+		return nil, err
+	}
+
+	// Get the DB representation of the channel-ID.
+	_, dbChanIDBytes, err := getDBChanID(chanDetailsBkt, chanID)
+	if err != nil {
+		return nil, err
+	}
+
+	sessionAckRanges, err := sessionBkt.CreateBucketIfNotExists(
+		cSessionAckRangeIndex,
+	)
+	if err != nil {
+		return nil, err
+	}
+
+	return sessionAckRanges.CreateBucketIfNotExists(dbChanIDBytes)
+}
+
 // createSessionKeyIndexKey returns the identifier used in the
 // session-key-index index, created as tower-id||blob-type.
 //
@@ -825,13 +974,18 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
 			return ErrUninitializedDB
 		}

+		chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
+		if chanIDIndexBkt == nil {
+			return ErrUninitializedDB
+		}
+
 		var err error

 		// If no tower ID is specified, then fetch all the sessions
 		// known to the db.
 		if id == nil {
 			clientSessions, err = c.listClientAllSessions(
-				sessions, opts...,
+				sessions, chanIDIndexBkt, opts...,
 			)
 			return err
 		}
@@ -843,7 +997,8 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
 		}

 		clientSessions, err = c.listTowerSessions(
-			*id, sessions, towerToSessionIndex, opts...,
+			*id, sessions, chanIDIndexBkt, towerToSessionIndex,
+			opts...,
 		)
 		return err
 	}, func() {
@@ -857,7 +1012,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
 }

 // listClientAllSessions returns the set of all client sessions known to the db.
-func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,
+func (c *ClientDB) listClientAllSessions(sessions, chanIDIndexBkt kvdb.RBucket,
 	opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {

 	clientSessions := make(map[SessionID]*ClientSession)
@@ -866,7 +1021,9 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,
 		// the CommittedUpdates and AckedUpdates on startup to resume
 		// committed updates and compute the highest known commit height
 		// for each channel.
-		session, err := c.getClientSession(sessions, k, opts...)
+		session, err := c.getClientSession(
+			sessions, chanIDIndexBkt, k, opts...,
+		)
 		if err != nil {
 			return err
 		}
@@ -884,7 +1041,7 @@ func (c *ClientDB) listClientAllSessions(sessions kvdb.RBucket,

 // listTowerSessions returns the set of all client sessions known to the db
 // that are associated with the given tower id.
-func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt,
+func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt, chanIDIndexBkt,
 	towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
 	map[SessionID]*ClientSession, error) {

@@ -899,7 +1056,9 @@ func (c *ClientDB) listTowerSessions(id TowerID, sessionsBkt,
 		// the CommittedUpdates and AckedUpdates on startup to resume
 		// committed updates and compute the highest known commit height
 		// for each channel.
-		session, err := c.getClientSession(sessionsBkt, k, opts...)
+		session, err := c.getClientSession(
+			sessionsBkt, chanIDIndexBkt, k, opts...,
+		)
 		if err != nil {
 			return err
 		}
@@ -944,6 +1103,73 @@ func (c *ClientDB) FetchSessionCommittedUpdates(id *SessionID) (
 	return committedUpdates, nil
 }

+// IsAcked returns true if the given backup has been backed up using the given
+// session.
+func (c *ClientDB) IsAcked(id *SessionID, backupID *BackupID) (bool, error) {
+	index, err := c.getRangeIndex(nil, *id, backupID.ChanID)
+	if errors.Is(err, ErrNoRangeIndexFound) {
+		return false, nil
+	} else if err != nil {
+		return false, err
+	}
+
+	return index.IsInIndex(backupID.CommitHeight), nil
+}
+
+// NumAckedUpdates returns the number of backups that have been successfully
+// backed up using the given session.
+func (c *ClientDB) NumAckedUpdates(id *SessionID) (uint64, error) {
+	var numAcked uint64
+	err := kvdb.View(c.db, func(tx kvdb.RTx) error {
+		sessions := tx.ReadBucket(cSessionBkt)
+		if sessions == nil {
+			return ErrUninitializedDB
+		}
+
+		chanIDIndexBkt := tx.ReadBucket(cChanIDIndexBkt)
+		if chanIDIndexBkt == nil {
+			return ErrUninitializedDB
+		}
+
+		sessionBkt := sessions.NestedReadBucket(id[:])
+		if sessionsBkt == nil {
+			return nil
+		}
+
+		sessionAckRanges := sessionBkt.NestedReadBucket(
+			cSessionAckRangeIndex,
+		)
+		if sessionAckRanges == nil {
+			return nil
+		}
+
+		// Iterate over the channel ID's in the sessionAckRanges
+		// bucket.
+		return sessionAckRanges.ForEach(func(dbChanID, _ []byte) error {
+			// Get the range index for the session-channel pair.
+			chanIDBytes := chanIDIndexBkt.Get(dbChanID)
+			var chanID lnwire.ChannelID
+			copy(chanID[:], chanIDBytes)
+
+			index, err := c.getRangeIndex(tx, *id, chanID)
+			if err != nil {
+				return err
+			}
+
+			numAcked += index.NumInSet()
+
+			return nil
+		})
+	}, func() {
+		numAcked = 0
+	})
+	if err != nil {
+		return 0, err
+	}
+
+	return numAcked, nil
+}
+
 // FetchChanSummaries loads a mapping from all registered channels to their
 // channel summaries.
 func (c *ClientDB) FetchChanSummaries() (ChannelSummaries, error) {
@@ -1174,6 +1400,11 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
 			return ErrUninitializedDB
 		}

+		chanDetailsBkt := tx.ReadBucket(cChanDetailsBkt)
+		if chanDetailsBkt == nil {
+			return ErrUninitializedDB
+		}
+
 		// We'll only load the ClientSession body for performance, since
 		// we primarily need to inspect its SeqNum and TowerLastApplied
 		// fields. The CommittedUpdates and AckedUpdates will be
@@ -1242,25 +1473,24 @@ func (c *ClientDB) AckUpdate(id *SessionID, seqNum uint16,
 			return err
 		}

-		// Ensure that the session acks sub-bucket is initialized, so we
-		// can insert an entry.
-		sessionAcks, err := sessionBkt.CreateBucketIfNotExists(
-			cSessionAcks,
-		)
+		chanID := committedUpdate.BackupID.ChanID
+		height := committedUpdate.BackupID.CommitHeight
+
+		// Get the ranges write bucket before getting the range index to
+		// ensure that the session acks sub-bucket is initialized, so
+		// that we can insert an entry.
+		rangesBkt, err := getRangesWriteBucket(tx, *id, chanID)
 		if err != nil {
 			return err
 		}

-		// The session acks only need to track the backup id of the
-		// update, so we can discard the blob and hint.
-		var b bytes.Buffer
-		err = committedUpdate.BackupID.Encode(&b)
+		// Get the range index for the given session-channel pair.
+		index, err := c.getRangeIndex(tx, *id, chanID)
 		if err != nil {
 			return err
 		}

-		// Finally, insert the ack into the sessionAcks sub-bucket.
-		return sessionAcks.Put(seqNumBuf[:], b.Bytes())
+		return index.Add(height, rangesBkt)
 	}, func() {})
 }

@@ -1293,9 +1523,15 @@ func getClientSessionBody(sessions kvdb.RBucket,
 	return &session, nil
 }

-// PerAckedUpdateCB describes the signature of a callback function that can be
-// called for each of a session's acked updates.
-type PerAckedUpdateCB func(*ClientSession, uint16, BackupID)
+// PerMaxHeightCB describes the signature of a callback function that can be
+// called for each channel that a session has updates for to communicate the
+// maximum commitment height that the session has backed up for the channel.
+type PerMaxHeightCB func(*ClientSession, lnwire.ChannelID, uint64)
+
+// PerNumAckedUpdatesCB describes the signature of a callback function that can
+// be called for each channel that a session has updates for to communicate the
+// number of updates that the session has for the channel.
+type PerNumAckedUpdatesCB func(*ClientSession, lnwire.ChannelID, uint16)

 // PerCommittedUpdateCB describes the signature of a callback function that can
 // be called for each of a session's committed updates (updates that the client
@@ -1310,9 +1546,15 @@ type ClientSessionListOption func(cfg *ClientSessionListCfg)
 // ClientSessionListCfg defines various query parameters that will be used when
 // querying the DB for client sessions.
 type ClientSessionListCfg struct {
-	// PerAckedUpdate will, if set, be called for each of the session's
-	// acked updates.
-	PerAckedUpdate PerAckedUpdateCB
+	// PerNumAckedUpdates will, if set, be called for each of the session's
+	// channels to communicate the number of updates stored for that
+	// channel.
+	PerNumAckedUpdates PerNumAckedUpdatesCB
+
+	// PerMaxHeight will, if set, be called for each of the session's
+	// channels to communicate the highest commit height of updates stored
+	// for that channel.
+	PerMaxHeight PerMaxHeightCB

 	// PerCommittedUpdate will, if set, be called for each of the session's
 	// committed (un-acked) updates.
@@ -1324,11 +1566,22 @@ func NewClientSessionCfg() *ClientSessionListCfg {
 	return &ClientSessionListCfg{}
 }

-// WithPerAckedUpdate constructs a functional option that will set a call-back
-// function to be called for each of a client's acked updates.
-func WithPerAckedUpdate(cb PerAckedUpdateCB) ClientSessionListOption {
+// WithPerMaxHeight constructs a functional option that will set a call-back
+// function to be called for each of a session's channels to communicate the
+// maximum commitment height that the session has stored for the channel.
+func WithPerMaxHeight(cb PerMaxHeightCB) ClientSessionListOption {
 	return func(cfg *ClientSessionListCfg) {
-		cfg.PerAckedUpdate = cb
+		cfg.PerMaxHeight = cb
+	}
+}
+
+// WithPerNumAckedUpdates constructs a functional option that will set a
+// call-back function to be called for each of a session's channels to
+// communicate the number of updates that the session has stored for the
+// channel.
+func WithPerNumAckedUpdates(cb PerNumAckedUpdatesCB) ClientSessionListOption {
+	return func(cfg *ClientSessionListCfg) {
+		cfg.PerNumAckedUpdates = cb
 	}
 }

@@ -1343,21 +1596,22 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
 // getClientSession loads the full ClientSession associated with the serialized
 // session id. This method populates the CommittedUpdates, AckUpdates and Tower
 // in addition to the ClientSession's body.
-func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte,
-	opts ...ClientSessionListOption) (*ClientSession, error) {
+func (c *ClientDB) getClientSession(sessionsBkt, chanIDIndexBkt kvdb.RBucket,
+	idBytes []byte, opts ...ClientSessionListOption) (*ClientSession,
+	error) {

 	cfg := NewClientSessionCfg()
 	for _, o := range opts {
 		o(cfg)
 	}

-	session, err := getClientSessionBody(sessions, idBytes)
+	session, err := getClientSessionBody(sessionsBkt, idBytes)
 	if err != nil {
 		return nil, err
 	}

 	// Can't fail because client session body has already been read.
-	sessionBkt := sessions.NestedReadBucket(idBytes)
+	sessionBkt := sessionsBkt.NestedReadBucket(idBytes)

 	// Pass the session's committed (un-acked) updates through the call-back
 	// if one is provided.
@@ -1370,7 +1624,10 @@ func (c *ClientDB) getClientSession(sessions kvdb.RBucket, idBytes []byte,

 	// Pass the session's acked updates through the call-back if one is
 	// provided.
-	err = filterClientSessionAcks(sessionBkt, session, cfg.PerAckedUpdate)
+	err = c.filterClientSessionAcks(
+		sessionBkt, chanIDIndexBkt, session, cfg.PerMaxHeight,
+		cfg.PerNumAckedUpdates,
+	)
 	if err != nil {
 		return nil, err
 	}
@@ -1419,35 +1676,43 @@ func getClientSessionCommits(sessionBkt kvdb.RBucket, s *ClientSession,
 // filterClientSessionAcks retrieves all acked updates for the session
 // identified by the serialized session id and passes them to the provided
 // call back if one is provided.
-func filterClientSessionAcks(sessionBkt kvdb.RBucket, s *ClientSession,
-	cb PerAckedUpdateCB) error {
+func (c *ClientDB) filterClientSessionAcks(sessionBkt,
+	chanIDIndexBkt kvdb.RBucket, s *ClientSession, perMaxCb PerMaxHeightCB,
+	perNumAckedUpdates PerNumAckedUpdatesCB) error {

-	if cb == nil {
+	if perMaxCb == nil && perNumAckedUpdates == nil {
 		return nil
 	}

-	sessionAcks := sessionBkt.NestedReadBucket(cSessionAcks)
-	if sessionAcks == nil {
+	sessionAcksRanges := sessionBkt.NestedReadBucket(cSessionAckRangeIndex)
+	if sessionAcksRanges == nil {
 		return nil
 	}

-	err := sessionAcks.ForEach(func(k, v []byte) error {
-		seqNum := byteOrder.Uint16(k)
+	return sessionAcksRanges.ForEach(func(dbChanID, _ []byte) error {
+		rangeBkt := sessionAcksRanges.NestedReadBucket(dbChanID)
+		if rangeBkt == nil {
+			return nil
+		}

-		var backupID BackupID
-		err := backupID.Decode(bytes.NewReader(v))
+		index, err := readRangeIndex(rangeBkt)
 		if err != nil {
 			return err
 		}

-		cb(s, seqNum, backupID)
+		chanIDBytes := chanIDIndexBkt.Get(dbChanID)
+		var chanID lnwire.ChannelID
+		copy(chanID[:], chanIDBytes)
+
+		if perMaxCb != nil {
+			perMaxCb(s, chanID, index.MaxHeight())
+		}
+
+		if perNumAckedUpdates != nil {
+			perNumAckedUpdates(s, chanID, uint16(index.NumInSet()))
+		}
 		return nil
 	})
-	if err != nil {
-		return err
-	}
-
-	return nil
 }

 // filterClientSessionCommits retrieves all committed updates for the session
@@ -221,6 +221,26 @@ func (h *clientDBHarness) fetchSessionCommittedUpdates(id *wtdb.SessionID,
 	return updates
 }

+func (h *clientDBHarness) isAcked(id *wtdb.SessionID, backupID *wtdb.BackupID,
+	expErr error) bool {
+
+	h.t.Helper()
+
+	isAcked, err := h.db.IsAcked(id, backupID)
+	require.ErrorIs(h.t, err, expErr)
+
+	return isAcked
+}
+
+func (h *clientDBHarness) numAcked(id *wtdb.SessionID, expErr error) uint64 {
+	h.t.Helper()
+
+	numAcked, err := h.db.NumAckedUpdates(id)
+	require.ErrorIs(h.t, err, expErr)
+
+	return numAcked
+}
+
 // testCreateClientSession asserts various conditions regarding the creation of
 // a new ClientSession. The test asserts:
 //   - client sessions can only be created if a session key index is reserved.
@@ -453,6 +473,7 @@ func testRemoveTower(h *clientDBHarness) {
 	}
 	h.insertSession(session, nil)
 	update := randCommittedUpdate(h.t, 1)
+	h.registerChan(update.BackupID.ChanID, nil, nil)
 	h.commitUpdate(&session.ID, update, nil)

 	// We should not be able to fully remove it from the database since
@@ -583,16 +604,6 @@ func testCommitUpdate(h *clientDBHarness) {
 	}, nil)
 }

-func perAckedUpdate(updates map[uint16]wtdb.BackupID) func(
-	_ *wtdb.ClientSession, seq uint16, id wtdb.BackupID) {
-
-	return func(_ *wtdb.ClientSession, seq uint16,
-		id wtdb.BackupID) {
-
-		updates[seq] = id
-	}
-}
-
 // testAckUpdate asserts the behavior of AckUpdate.
 func testAckUpdate(h *clientDBHarness) {
 	const blobType = blob.TypeAltruistCommit
@@ -628,6 +639,8 @@ func testAckUpdate(h *clientDBHarness) {

 	// Commit to a random update at seqnum 1.
 	update1 := randCommittedUpdate(h.t, 1)
+
+	h.registerChan(update1.BackupID.ChanID, nil, nil)
 	lastApplied := h.commitUpdate(&session.ID, update1, nil)
 	require.Zero(h.t, lastApplied)

@@ -654,6 +667,7 @@ func testAckUpdate(h *clientDBHarness) {
 	// value is 1, since this was what was provided in the last successful
 	// ack.
 	update2 := randCommittedUpdate(h.t, 2)
+	h.registerChan(update2.BackupID.ChanID, nil, nil)
 	lastApplied = h.commitUpdate(&session.ID, update2, nil)
 	require.EqualValues(h.t, 1, lastApplied)

@@ -681,13 +695,16 @@ func (h *clientDBHarness) assertUpdates(id wtdb.SessionID,
 	expectedPending []wtdb.CommittedUpdate,
 	expectedAcked map[uint16]wtdb.BackupID) {

-	ackedUpdates := make(map[uint16]wtdb.BackupID)
-	_ = h.listSessions(
-		nil, wtdb.WithPerAckedUpdate(perAckedUpdate(ackedUpdates)),
-	)
-	committedUpates := h.fetchSessionCommittedUpdates(&id, nil)
-	checkCommittedUpdates(h.t, committedUpates, expectedPending)
-	checkAckedUpdates(h.t, ackedUpdates, expectedAcked)
+	committedUpdates := h.fetchSessionCommittedUpdates(&id, nil)
+	checkCommittedUpdates(h.t, committedUpdates, expectedPending)
+
+	// Check acked updates.
+	numAcked := h.numAcked(&id, nil)
+	require.EqualValues(h.t, len(expectedAcked), numAcked)
+	for _, backupID := range expectedAcked {
+		isAcked := h.isAcked(&id, &backupID, nil)
+		require.True(h.t, isAcked)
+	}
 }

 // checkCommittedUpdates asserts that the CommittedUpdates on session match the
@@ -707,21 +724,6 @@ func checkCommittedUpdates(t *testing.T, actualUpdates,
 	require.Equal(t, expUpdates, actualUpdates)
 }

-// checkAckedUpdates asserts that the AckedUpdates on a session match the
-// expUpdates provided.
-func checkAckedUpdates(t *testing.T, actualUpdates,
-	expUpdates map[uint16]wtdb.BackupID) {
-
-	// We promote nil expUpdates to an initialized map since the database
-	// should never return a nil map. This promotion is done purely out of
-	// convenience for the testing framework.
-	if expUpdates == nil {
-		expUpdates = make(map[uint16]wtdb.BackupID)
-	}
-
-	require.Equal(t, expUpdates, actualUpdates)
-}
-
 // TestClientDB asserts the behavior of a fresh client db, a reopened client db,
 // and the mock implementation. This ensures that all databases function
 // identically, especially in the negative paths.
@@ -6,6 +6,7 @@ import (
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration2"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
+	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
 )

 // log is a logger that is initialized with no output filters. This
@@ -32,6 +33,7 @@ func UseLogger(logger btclog.Logger) {
 	migration1.UseLogger(logger)
 	migration2.UseLogger(logger)
 	migration3.UseLogger(logger)
+	migration4.UseLogger(logger)
 }

 // logClosure is used to provide a closure over expensive logging operations so
@@ -8,6 +8,7 @@ import (
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration1"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration2"
 	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration3"
+	"github.com/lightningnetwork/lnd/watchtower/wtdb/migration4"
 )

 // txMigration is a function which takes a prior outdated version of the
@@ -49,6 +50,11 @@ var clientDBVersions = []version{
 	{
 		txMigration: migration3.MigrateChannelIDIndex,
 	},
+	{
+		dbMigration: migration4.MigrateAckedUpdates(
+			migration4.DefaultSessionsPerTx,
+		),
+	},
 }

 // getLatestDBVersion returns the last known database version.
@@ -1,6 +1,7 @@
 package wtmock

 import (
+	"encoding/binary"
 	"net"
 	"sync"
 	"sync/atomic"
@@ -11,6 +12,8 @@ import (
 	"github.com/lightningnetwork/lnd/watchtower/wtdb"
 )

+var byteOrder = binary.BigEndian
+
 type towerPK [33]byte

 type keyIndexKey struct {
@@ -18,18 +21,23 @@ type keyIndexKey struct {
 	blobType blob.Type
 }

+type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex
+
+type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore
+
 // ClientDB is a mock, in-memory database or testing the watchtower client
 // behavior.
 type ClientDB struct {
 	nextTowerID uint64 // to be used atomically

 	mu sync.Mutex
 	summaries map[lnwire.ChannelID]wtdb.ClientChanSummary
 	activeSessions map[wtdb.SessionID]wtdb.ClientSession
-	ackedUpdates map[wtdb.SessionID]map[uint16]wtdb.BackupID
-	committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
-	towerIndex map[towerPK]wtdb.TowerID
-	towers map[wtdb.TowerID]*wtdb.Tower
+	ackedUpdates rangeIndexArrayMap
+	persistedAckedUpdates rangeIndexKVStore
+	committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
+	towerIndex map[towerPK]wtdb.TowerID
+	towers map[wtdb.TowerID]*wtdb.Tower

 	nextIndex uint32
 	indexes map[keyIndexKey]uint32
@@ -39,14 +47,21 @@ type ClientDB struct {
 // NewClientDB initializes a new mock ClientDB.
 func NewClientDB() *ClientDB {
 	return &ClientDB{
-		summaries: make(map[lnwire.ChannelID]wtdb.ClientChanSummary),
-		activeSessions: make(map[wtdb.SessionID]wtdb.ClientSession),
-		ackedUpdates: make(map[wtdb.SessionID]map[uint16]wtdb.BackupID),
-		committedUpdates: make(map[wtdb.SessionID][]wtdb.CommittedUpdate),
-		towerIndex: make(map[towerPK]wtdb.TowerID),
-		towers: make(map[wtdb.TowerID]*wtdb.Tower),
-		indexes: make(map[keyIndexKey]uint32),
-		legacyIndexes: make(map[wtdb.TowerID]uint32),
+		summaries: make(
+			map[lnwire.ChannelID]wtdb.ClientChanSummary,
+		),
+		activeSessions: make(
+			map[wtdb.SessionID]wtdb.ClientSession,
+		),
+		ackedUpdates: make(rangeIndexArrayMap),
+		persistedAckedUpdates: make(rangeIndexKVStore),
+		committedUpdates: make(
+			map[wtdb.SessionID][]wtdb.CommittedUpdate,
+		),
+		towerIndex: make(map[towerPK]wtdb.TowerID),
+		towers: make(map[wtdb.TowerID]*wtdb.Tower),
+		indexes: make(map[keyIndexKey]uint32),
+		legacyIndexes: make(map[wtdb.TowerID]uint32),
 	}
 }

@@ -233,9 +248,20 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
 		}
 		sessions[session.ID] = &session

-		if cfg.PerAckedUpdate != nil {
-			for seq, id := range m.ackedUpdates[session.ID] {
-				cfg.PerAckedUpdate(&session, seq, id)
+		if cfg.PerMaxHeight != nil {
+			for chanID, index := range m.ackedUpdates[session.ID] {
+				cfg.PerMaxHeight(
+					&session, chanID, index.MaxHeight(),
+				)
+			}
+		}
+
+		if cfg.PerNumAckedUpdates != nil {
+			for chanID, index := range m.ackedUpdates[session.ID] {
+				cfg.PerNumAckedUpdates(
+					&session, chanID,
+					uint16(index.NumInSet()),
+				)
 			}
 		}

@@ -266,6 +292,37 @@ func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
 	return updates, nil
 }

+// IsAcked returns true if the given backup has been backed up using the given
+// session.
+func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool,
+	error) {
+
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	index, ok := m.ackedUpdates[*id][backupID.ChanID]
+	if !ok {
+		return false, nil
+	}
+
+	return index.IsInIndex(backupID.CommitHeight), nil
+}
+
+// NumAckedUpdates returns the number of backups that have been successfully
+// backed up using the given session.
+func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	var numAcked uint64
+
+	for _, index := range m.ackedUpdates[*id] {
+		numAcked += index.NumInSet()
+	}
+
+	return numAcked, nil
+}
+
 // CreateClientSession records a newly negotiated client session in the set of
 // active sessions. The session can be identified by its SessionID.
 func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
@@ -311,7 +368,10 @@ func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
 			RewardPkScript: cloneBytes(session.RewardPkScript),
 		},
 	}
-	m.ackedUpdates[session.ID] = make(map[uint16]wtdb.BackupID)
+	m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex)
+	m.persistedAckedUpdates[session.ID] = make(
+		map[lnwire.ChannelID]*mockKVStore,
+	)
 	m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)

 	return nil
@@ -443,7 +503,25 @@ func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
 	updates[len(updates)-1] = wtdb.CommittedUpdate{}
 	m.committedUpdates[session.ID] = updates[:len(updates)-1]

-	m.ackedUpdates[*id][seqNum] = update.BackupID
+	chanID := update.BackupID.ChanID
+	if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok {
+		index, err := wtdb.NewRangeIndex(nil)
+		if err != nil {
+			return err
+		}
+
+		m.ackedUpdates[*id][chanID] = index
+		m.persistedAckedUpdates[*id][chanID] = newMockKVStore()
+	}
+
+	err := m.ackedUpdates[*id][chanID].Add(
+		update.BackupID.CommitHeight,
+		m.persistedAckedUpdates[*id][chanID],
+	)
+	if err != nil {
+		return err
+	}
+
 	session.TowerLastApplied = lastApplied

 	m.activeSessions[*id] = session
@@ -512,3 +590,39 @@ func copyTower(tower *wtdb.Tower) *wtdb.Tower {

 	return t
 }
+
+type mockKVStore struct {
+	kv map[uint64]uint64
+
+	err error
+}
+
+func newMockKVStore() *mockKVStore {
+	return &mockKVStore{
+		kv: make(map[uint64]uint64),
+	}
+}
+
+func (m *mockKVStore) Put(key, value []byte) error {
+	if m.err != nil {
+		return m.err
+	}
+
+	k := byteOrder.Uint64(key)
+	v := byteOrder.Uint64(value)
+
+	m.kv[k] = v
+
+	return nil
+}
+
+func (m *mockKVStore) Delete(key []byte) error {
+	if m.err != nil {
+		return m.err
+	}
+
+	k := byteOrder.Uint64(key)
+	delete(m.kv, k)
+
+	return nil
+}
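As a usage note on the new ClientSessionListOption callbacks introduced in the diff above, here is a hedged, self-contained sketch of how a caller might tally acked updates per session with WithPerNumAckedUpdates. It mirrors the constructFunctionalOptions hunk at the top of this commit; only the types and functions shown in the diff are used, and the local variable names are illustrative.

package main

import (
	"fmt"

	"github.com/lightningnetwork/lnd/lnwire"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

func main() {
	// Tally acked updates per session using the per-channel callback from
	// this commit: one callback per channel carrying the per-channel
	// total, instead of one callback per acked update.
	ackCounts := make(map[wtdb.SessionID]uint16)

	perNumAckedUpdates := func(s *wtdb.ClientSession, _ lnwire.ChannelID,
		numUpdates uint16) {

		ackCounts[s.ID] += numUpdates
	}

	opts := []wtdb.ClientSessionListOption{
		wtdb.WithPerNumAckedUpdates(perNumAckedUpdates),
	}

	// In real code, opts would be passed to ClientDB.ListClientSessions;
	// printed here only so the sketch compiles on its own.
	fmt.Println(len(opts), len(ackCounts))
}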