Mirror of https://github.com/lightningnetwork/lnd.git
watchtower: completely remove the mock tower client DB
Remove the use of the mock tower client DB and use the actual bbolt DB everywhere instead.
Parent: f889c9b1cc
Commit: ff0d8fc619
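The pattern the tests standardize on is visible throughout the diff below: instead of wtmock.NewClientDB(), each test now opens a real bbolt-backed client DB in a temporary directory. The following is a minimal sketch of that setup, assuming the package layout of the test files in this diff; the package name and the helper name newBoltClientDB are illustrative, not part of the commit.

```go
package wtclient_test // hypothetical package name for this sketch

import (
	"testing"

	"github.com/lightningnetwork/lnd/kvdb"
	"github.com/lightningnetwork/lnd/watchtower/wtclient"
	"github.com/lightningnetwork/lnd/watchtower/wtdb"
	"github.com/stretchr/testify/require"
)

// newBoltClientDB (illustrative helper) opens a fresh bbolt-backed watchtower
// client DB rooted in a temporary directory and closes it when the test ends.
func newBoltClientDB(t *testing.T) wtclient.DB {
	dbCfg := &kvdb.BoltConfig{
		DBTimeout: kvdb.DefaultDBTimeout,
	}

	// Create the bbolt backend on disk, then open the watchtower client DB
	// on top of it, mirroring the setup repeated in the updated tests.
	bdb, err := wtdb.NewBoltBackendCreator(
		true, t.TempDir(), "wtclient.db",
	)(dbCfg)
	require.NoError(t, err)

	db, err := wtdb.OpenClientDB(bdb)
	require.NoError(t, err)

	t.Cleanup(func() {
		require.NoError(t, db.Close())
	})

	return db
}
```

A test can then call GetDBQueue([]byte("test-namespace")) on the returned DB to obtain the disk queue under test, as the updated hunks below do.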
@@ -18,51 +18,13 @@ const (
waitTime = time.Second * 2
)

type initQueue func(t *testing.T) wtdb.Queue[*wtdb.BackupID]

// TestDiskOverflowQueue tests that the DiskOverflowQueue behaves as expected.
func TestDiskOverflowQueue(t *testing.T) {
t.Parallel()

dbs := []struct {
name string
init initQueue
}{
{
name: "kvdb",
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}

bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)

db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)

t.Cleanup(func() {
db.Close()
})

return db.GetDBQueue([]byte("test-namespace"))
},
},
{
name: "mock",
init: func(t *testing.T) wtdb.Queue[*wtdb.BackupID] {
db := wtmock.NewClientDB()

return db.GetDBQueue([]byte("test-namespace"))
},
},
}

tests := []struct {
name string
run func(*testing.T, initQueue)
run func(*testing.T, wtdb.Queue[*wtdb.BackupID])
}{
{
name: "overflow to disk",
@@ -78,29 +40,43 @@ func TestDiskOverflowQueue(t *testing.T) {
},
}

for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
initDB := func() wtdb.Queue[*wtdb.BackupID] {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
test.run(t, db.init)
})
}
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)

db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)

t.Cleanup(func() {
require.NoError(t, db.Close())
})

return db.GetDBQueue([]byte("test-namespace"))
}

for _, test := range tests {
test := test

t.Run(test.name, func(tt *testing.T) {
tt.Parallel()

test.run(tt, initDB())
})
}
}

// testOverflowToDisk is a basic test that ensures that the queue correctly
// overflows items to disk and then correctly reloads them.
func testOverflowToDisk(t *testing.T, initQueue initQueue) {
func testOverflowToDisk(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) {
// Generate some backup IDs that we want to add to the queue.
tasks := genBackupIDs(10)

// Init the DB.
db := initQueue(t)

// New mock logger.
log := newMockLogger(t.Logf)

@@ -146,7 +122,9 @@ func testOverflowToDisk(t *testing.T, initQueue initQueue) {
// testRestartWithSmallerBufferSize tests that if the queue is restarted with
// a smaller in-memory buffer size that it was initially started with, then
// tasks are still loaded in the correct order.
func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
func testRestartWithSmallerBufferSize(t *testing.T,
db wtdb.Queue[*wtdb.BackupID]) {

const (
firstMaxInMemItems = 5
secondMaxInMemItems = 2
@@ -155,9 +133,6 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
// Generate some backup IDs that we want to add to the queue.
tasks := genBackupIDs(10)

// Create a db.
db := newQueue(t)

// New mock logger.
log := newMockLogger(t.Logf)

@@ -223,14 +198,11 @@ func testRestartWithSmallerBufferSize(t *testing.T, newQueue initQueue) {
// testStartStopQueue is a stress test that pushes a large number of tasks
// through the queue while also restarting the queue a couple of times
// throughout.
func testStartStopQueue(t *testing.T, newQueue initQueue) {
func testStartStopQueue(t *testing.T, db wtdb.Queue[*wtdb.BackupID]) {
// Generate a lot of backup IDs that we want to add to the
// queue one after the other.
tasks := genBackupIDs(200_000)

// Construct the ClientDB.
db := newQueue(t)

// New mock logger.
log := newMockLogger(t.Logf)

@@ -13,7 +13,6 @@ import (
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy"
"github.com/stretchr/testify/require"
)
@@ -964,12 +963,6 @@ func TestClientDB(t *testing.T) {
return db
},
},
{
name: "mock",
init: func(t *testing.T) wtclient.DB {
return wtmock.NewClientDB()
},
},
}

tests := []struct {
@@ -4,9 +4,7 @@ import (
"testing"

"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/watchtower/wtclient"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
"github.com/lightningnetwork/lnd/watchtower/wtmock"
"github.com/stretchr/testify/require"
)

@@ -15,53 +13,24 @@ import (
func TestDiskQueue(t *testing.T) {
t.Parallel()

dbs := []struct {
name string
init clientDBInit
}{
{
name: "bbolt",
init: func(t *testing.T) wtclient.DB {
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}

// Construct the ClientDB.
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)

db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)

t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})

return db
},
},
{
name: "mock",
init: func(t *testing.T) wtclient.DB {
return wtmock.NewClientDB()
},
},
dbCfg := &kvdb.BoltConfig{
DBTimeout: kvdb.DefaultDBTimeout,
}

for _, database := range dbs {
db := database
t.Run(db.name, func(t *testing.T) {
t.Parallel()
// Construct the ClientDB.
bdb, err := wtdb.NewBoltBackendCreator(
true, t.TempDir(), "wtclient.db",
)(dbCfg)
require.NoError(t, err)

testQueue(t, db.init(t))
})
}
}
db, err := wtdb.OpenClientDB(bdb)
require.NoError(t, err)

t.Cleanup(func() {
err = db.Close()
require.NoError(t, err)
})

func testQueue(t *testing.T, db wtclient.DB) {
namespace := []byte("test-namespace")
queue := db.GetDBQueue(namespace)

@@ -1,887 +0,0 @@
package wtmock

import (
"encoding/binary"
"net"
"sync"
"sync/atomic"

"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtdb"
)

var byteOrder = binary.BigEndian

type towerPK [33]byte

type keyIndexKey struct {
towerID wtdb.TowerID
blobType blob.Type
}

type rangeIndexArrayMap map[wtdb.SessionID]map[lnwire.ChannelID]*wtdb.RangeIndex

type rangeIndexKVStore map[wtdb.SessionID]map[lnwire.ChannelID]*mockKVStore

type channel struct {
summary *wtdb.ClientChanSummary
closedHeight uint32
sessions map[wtdb.SessionID]bool
}

// ClientDB is a mock, in-memory database or testing the watchtower client
// behavior.
type ClientDB struct {
nextTowerID uint64 // to be used atomically

mu sync.Mutex
channels map[lnwire.ChannelID]*channel
activeSessions map[wtdb.SessionID]wtdb.ClientSession
ackedUpdates rangeIndexArrayMap
persistedAckedUpdates rangeIndexKVStore
committedUpdates map[wtdb.SessionID][]wtdb.CommittedUpdate
towerIndex map[towerPK]wtdb.TowerID
towers map[wtdb.TowerID]*wtdb.Tower
closableSessions map[wtdb.SessionID]uint32

nextIndex uint32
indexes map[keyIndexKey]uint32
legacyIndexes map[wtdb.TowerID]uint32

queues map[string]wtdb.Queue[*wtdb.BackupID]
}

// NewClientDB initializes a new mock ClientDB.
func NewClientDB() *ClientDB {
return &ClientDB{
channels: make(map[lnwire.ChannelID]*channel),
activeSessions: make(
map[wtdb.SessionID]wtdb.ClientSession,
),
ackedUpdates: make(rangeIndexArrayMap),
persistedAckedUpdates: make(rangeIndexKVStore),
committedUpdates: make(
map[wtdb.SessionID][]wtdb.CommittedUpdate,
),
towerIndex: make(map[towerPK]wtdb.TowerID),
towers: make(map[wtdb.TowerID]*wtdb.Tower),
indexes: make(map[keyIndexKey]uint32),
legacyIndexes: make(map[wtdb.TowerID]uint32),
closableSessions: make(map[wtdb.SessionID]uint32),
queues: make(map[string]wtdb.Queue[*wtdb.BackupID]),
}
}

// CreateTower initialize an address record used to communicate with a
// watchtower. Each Tower is assigned a unique ID, that is used to amortize
// storage costs of the public key when used by multiple sessions. If the tower
// already exists, the address is appended to the list of all addresses used to
// that tower previously and its corresponding sessions are marked as active.
func (m *ClientDB) CreateTower(lnAddr *lnwire.NetAddress) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()

var towerPubKey towerPK
copy(towerPubKey[:], lnAddr.IdentityKey.SerializeCompressed())

var tower *wtdb.Tower
towerID, ok := m.towerIndex[towerPubKey]
if ok {
tower = m.towers[towerID]
tower.AddAddress(lnAddr.Address)

towerSessions, err := m.listClientSessions(&towerID)
if err != nil {
return nil, err
}
for id, session := range towerSessions {
session.Status = wtdb.CSessionActive
m.activeSessions[id] = *session
}
} else {
towerID = wtdb.TowerID(atomic.AddUint64(&m.nextTowerID, 1))
tower = &wtdb.Tower{
ID: towerID,
IdentityKey: lnAddr.IdentityKey,
Addresses: []net.Addr{lnAddr.Address},
}
}

m.towerIndex[towerPubKey] = towerID
m.towers[towerID] = tower

return copyTower(tower), nil
}

// RemoveTower modifies a tower's record within the database. If an address is
// provided, then _only_ the address record should be removed from the tower's
// persisted state. Otherwise, we'll attempt to mark the tower as inactive by
// marking all of its sessions inactive. If any of its sessions has unacked
// updates, then ErrTowerUnackedUpdates is returned. If the tower doesn't have
// any sessions at all, it'll be completely removed from the database.
//
// NOTE: An error is not returned if the tower doesn't exist.
func (m *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
m.mu.Lock()
defer m.mu.Unlock()

tower, err := m.loadTower(pubKey)
if err == wtdb.ErrTowerNotFound {
return nil
}
if err != nil {
return err
}

if addr != nil {
tower.RemoveAddress(addr)
if len(tower.Addresses) == 0 {
return wtdb.ErrLastTowerAddr
}
m.towers[tower.ID] = tower
return nil
}

towerSessions, err := m.listClientSessions(&tower.ID)
if err != nil {
return err
}
if len(towerSessions) == 0 {
var towerPK towerPK
copy(towerPK[:], pubKey.SerializeCompressed())
delete(m.towerIndex, towerPK)
delete(m.towers, tower.ID)
return nil
}

for id, session := range towerSessions {
if len(m.committedUpdates[session.ID]) > 0 {
return wtdb.ErrTowerUnackedUpdates
}
session.Status = wtdb.CSessionInactive
m.activeSessions[id] = *session
}

return nil
}

// LoadTower retrieves a tower by its public key.
func (m *ClientDB) LoadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()
return m.loadTower(pubKey)
}

// loadTower retrieves a tower by its public key.
//
// NOTE: This method requires the database's lock to be acquired.
func (m *ClientDB) loadTower(pubKey *btcec.PublicKey) (*wtdb.Tower, error) {
var towerPK towerPK
copy(towerPK[:], pubKey.SerializeCompressed())

towerID, ok := m.towerIndex[towerPK]
if !ok {
return nil, wtdb.ErrTowerNotFound
}
tower, ok := m.towers[towerID]
if !ok {
return nil, wtdb.ErrTowerNotFound
}

return copyTower(tower), nil
}

// LoadTowerByID retrieves a tower by its tower ID.
func (m *ClientDB) LoadTowerByID(towerID wtdb.TowerID) (*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()

if tower, ok := m.towers[towerID]; ok {
return copyTower(tower), nil
}

return nil, wtdb.ErrTowerNotFound
}

// ListTowers retrieves the list of towers available within the database.
func (m *ClientDB) ListTowers() ([]*wtdb.Tower, error) {
m.mu.Lock()
defer m.mu.Unlock()

towers := make([]*wtdb.Tower, 0, len(m.towers))
for _, tower := range m.towers {
towers = append(towers, copyTower(tower))
}

return towers, nil
}

// MarkBackupIneligible records that particular commit height is ineligible for
// backup. This allows the client to track which updates it should not attempt
// to retry after startup.
func (m *ClientDB) MarkBackupIneligible(_ lnwire.ChannelID, _ uint64) error {
return nil
}

// ListClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) ListClientSessions(tower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {

m.mu.Lock()
defer m.mu.Unlock()

return m.listClientSessions(tower, opts...)
}

// listClientSessions returns the set of all client sessions known to the db. An
// optional tower ID can be used to filter out any client sessions in the
// response that do not correspond to this tower.
func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) {

cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}

sessions := make(map[wtdb.SessionID]*wtdb.ClientSession)
for _, session := range m.activeSessions {
session := session
if tower != nil && *tower != session.TowerID {
continue
}

if cfg.PreEvaluateFilterFn != nil &&
!cfg.PreEvaluateFilterFn(&session) {

continue
}

if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(
&session, chanID, index.MaxHeight(),
)
}
}

if cfg.PerNumAckedUpdates != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerNumAckedUpdates(
&session, chanID,
uint16(index.NumInSet()),
)
}
}

if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}

if cfg.PostEvaluateFilterFn != nil &&
!cfg.PostEvaluateFilterFn(&session) {

continue
}

sessions[session.ID] = &session
}

return sessions, nil
}

// FetchSessionCommittedUpdates retrieves the current set of un-acked updates
// of the given session.
func (m *ClientDB) FetchSessionCommittedUpdates(id *wtdb.SessionID) (
[]wtdb.CommittedUpdate, error) {

m.mu.Lock()
defer m.mu.Unlock()

updates, ok := m.committedUpdates[*id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}

return updates, nil
}

// IsAcked returns true if the given backup has been backed up using the given
// session.
func (m *ClientDB) IsAcked(id *wtdb.SessionID, backupID *wtdb.BackupID) (bool,
error) {

m.mu.Lock()
defer m.mu.Unlock()

index, ok := m.ackedUpdates[*id][backupID.ChanID]
if !ok {
return false, nil
}

return index.IsInIndex(backupID.CommitHeight), nil
}

// NumAckedUpdates returns the number of backups that have been successfully
// backed up using the given session.
func (m *ClientDB) NumAckedUpdates(id *wtdb.SessionID) (uint64, error) {
m.mu.Lock()
defer m.mu.Unlock()

var numAcked uint64

for _, index := range m.ackedUpdates[*id] {
numAcked += index.NumInSet()
}

return numAcked, nil
}

// CreateClientSession records a newly negotiated client session in the set of
// active sessions. The session can be identified by its SessionID.
func (m *ClientDB) CreateClientSession(session *wtdb.ClientSession) error {
m.mu.Lock()
defer m.mu.Unlock()

// Ensure that we aren't overwriting an existing session.
if _, ok := m.activeSessions[session.ID]; ok {
return wtdb.ErrClientSessionAlreadyExists
}

key := keyIndexKey{
towerID: session.TowerID,
blobType: session.Policy.BlobType,
}

// Ensure that a session key index has been reserved for this tower.
keyIndex, err := m.getSessionKeyIndex(key)
if err != nil {
return err
}

// Ensure that the session's index matches the reserved index.
if keyIndex != session.KeyIndex {
return wtdb.ErrIncorrectKeyIndex
}

// Remove the key index reservation for this tower. Once committed, this
// permits us to create another session with this tower.
delete(m.indexes, key)
if key.blobType == blob.TypeAltruistCommit {
delete(m.legacyIndexes, key.towerID)
}

m.activeSessions[session.ID] = wtdb.ClientSession{
ID: session.ID,
ClientSessionBody: wtdb.ClientSessionBody{
SeqNum: session.SeqNum,
TowerLastApplied: session.TowerLastApplied,
TowerID: session.TowerID,
KeyIndex: session.KeyIndex,
Policy: session.Policy,
RewardPkScript: cloneBytes(session.RewardPkScript),
},
}
m.ackedUpdates[session.ID] = make(map[lnwire.ChannelID]*wtdb.RangeIndex)
m.persistedAckedUpdates[session.ID] = make(
map[lnwire.ChannelID]*mockKVStore,
)
m.committedUpdates[session.ID] = make([]wtdb.CommittedUpdate, 0)

return nil
}

// NextSessionKeyIndex reserves a new session key derivation index for a
// particular tower id. The index is reserved for that tower until
// CreateClientSession is invoked for that tower and index, at which point a new
// index for that tower can be reserved. Multiple calls to this method before
// CreateClientSession is invoked should return the same index unless forceNext
// is set to true.
func (m *ClientDB) NextSessionKeyIndex(towerID wtdb.TowerID, blobType blob.Type,
forceNext bool) (uint32, error) {

m.mu.Lock()
defer m.mu.Unlock()

key := keyIndexKey{
towerID: towerID,
blobType: blobType,
}

if !forceNext {
if index, err := m.getSessionKeyIndex(key); err == nil {
return index, nil
}
}

// By default, we use the next available bucket sequence as the key
// index. But if forceNext is true, then it is assumed that some data
// loss occurred and so the sequence is incremented a by a jump of 1000
// so that we can arrive at a brand new key index quicker.
nextIndex := m.nextIndex + 1
if forceNext {
nextIndex = m.nextIndex + 1000
}
m.nextIndex = nextIndex
m.indexes[key] = nextIndex

return nextIndex, nil
}

func (m *ClientDB) getSessionKeyIndex(key keyIndexKey) (uint32, error) {
if index, ok := m.indexes[key]; ok {
return index, nil
}

if key.blobType == blob.TypeAltruistCommit {
if index, ok := m.legacyIndexes[key.towerID]; ok {
return index, nil
}
}

return 0, wtdb.ErrNoReservedKeyIndex
}

// CommitUpdate persists the CommittedUpdate provided in the slot for (session,
// seqNum). This allows the client to retransmit this update on startup.
func (m *ClientDB) CommitUpdate(id *wtdb.SessionID,
update *wtdb.CommittedUpdate) (uint16, error) {

m.mu.Lock()
defer m.mu.Unlock()

// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return 0, wtdb.ErrClientSessionNotFound
}

// Check if an update has already been committed for this state.
for _, dbUpdate := range m.committedUpdates[session.ID] {
if dbUpdate.SeqNum == update.SeqNum {
// If the breach hint matches, we'll just return the
// last applied value so the client can retransmit.
if dbUpdate.Hint == update.Hint {
return session.TowerLastApplied, nil
}

// Otherwise, fail since the breach hint doesn't match.
return 0, wtdb.ErrUpdateAlreadyCommitted
}
}

// Sequence number must increment.
if update.SeqNum != session.SeqNum+1 {
return 0, wtdb.ErrCommitUnorderedUpdate
}

// Save the update and increment the sequence number.
m.committedUpdates[session.ID] = append(
m.committedUpdates[session.ID], *update,
)
session.SeqNum++
m.activeSessions[*id] = session

return session.TowerLastApplied, nil
}

// AckUpdate persists an acknowledgment for a given (session, seqnum) pair. This
// removes the update from the set of committed updates, and validates the
// lastApplied value returned from the tower.
func (m *ClientDB) AckUpdate(id *wtdb.SessionID, seqNum,
lastApplied uint16) error {

m.mu.Lock()
defer m.mu.Unlock()

// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return wtdb.ErrClientSessionNotFound
}

// Ensure the returned last applied value does not exceed the highest
// allocated sequence number.
if lastApplied > session.SeqNum {
return wtdb.ErrUnallocatedLastApplied
}

// Ensure the last applied value isn't lower than a previous one sent by
// the tower.
if lastApplied < session.TowerLastApplied {
return wtdb.ErrLastAppliedReversion
}

// Retrieve the committed update, failing if none is found. We should
// only receive acks for state updates that we send.
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
}

// Add sessionID to channel.
channel, ok := m.channels[update.BackupID.ChanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}
channel.sessions[*id] = true

// Remove the committed update from disk and mark the update as
// acked. The tower last applied value is also recorded to send
// along with the next update.
copy(updates[:i], updates[i+1:])
updates[len(updates)-1] = wtdb.CommittedUpdate{}
m.committedUpdates[session.ID] = updates[:len(updates)-1]

chanID := update.BackupID.ChanID
if _, ok := m.ackedUpdates[*id][update.BackupID.ChanID]; !ok {
index, err := wtdb.NewRangeIndex(nil)
if err != nil {
return err
}

m.ackedUpdates[*id][chanID] = index
m.persistedAckedUpdates[*id][chanID] = newMockKVStore()
}

err := m.ackedUpdates[*id][chanID].Add(
update.BackupID.CommitHeight,
m.persistedAckedUpdates[*id][chanID],
)
if err != nil {
return err
}

session.TowerLastApplied = lastApplied

m.activeSessions[*id] = session
return nil
}

return wtdb.ErrCommittedUpdateNotFound
}

// GetDBQueue returns a BackupID Queue instance under the given name space.
func (m *ClientDB) GetDBQueue(namespace []byte) wtdb.Queue[*wtdb.BackupID] {
m.mu.Lock()
defer m.mu.Unlock()

if q, ok := m.queues[string(namespace)]; ok {
return q
}

q := NewQueueDB[*wtdb.BackupID]()
m.queues[string(namespace)] = q

return q
}

// DeleteCommittedUpdate deletes the committed update with the given sequence
// number from the given session.
func (m *ClientDB) DeleteCommittedUpdate(id *wtdb.SessionID,
seqNum uint16) error {

m.mu.Lock()
defer m.mu.Unlock()

// Fail if session doesn't exist.
session, ok := m.activeSessions[*id]
if !ok {
return wtdb.ErrClientSessionNotFound
}

// Retrieve the committed update, failing if none is found.
updates := m.committedUpdates[session.ID]
for i, update := range updates {
if update.SeqNum != seqNum {
continue
}

// Remove the committed update from "disk".
updates = append(updates[:i], updates[i+1:]...)
m.committedUpdates[session.ID] = updates

return nil
}

return wtdb.ErrCommittedUpdateNotFound
}

// ListClosableSessions fetches and returns the IDs for all sessions marked as
// closable.
func (m *ClientDB) ListClosableSessions() (map[wtdb.SessionID]uint32, error) {
m.mu.Lock()
defer m.mu.Unlock()

cs := make(map[wtdb.SessionID]uint32, len(m.closableSessions))
for id, height := range m.closableSessions {
cs[id] = height
}

return cs, nil
}

// FetchChanSummaries loads a mapping from all registered channels to their
// channel summaries. Only the channels that have not yet been marked as closed
// will be loaded.
func (m *ClientDB) FetchChanSummaries() (wtdb.ChannelSummaries, error) {
m.mu.Lock()
defer m.mu.Unlock()

summaries := make(map[lnwire.ChannelID]wtdb.ClientChanSummary)
for chanID, channel := range m.channels {
// Don't load the channel if it has been marked as closed.
if channel.closedHeight > 0 {
continue
}

summaries[chanID] = wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(
channel.summary.SweepPkScript,
),
}
}

return summaries, nil
}

// MarkChannelClosed will mark a registered channel as closed by setting
// its closed-height as the given block height. It returns a list of
// session IDs for sessions that are now considered closable due to the
// close of this channel.
func (m *ClientDB) MarkChannelClosed(chanID lnwire.ChannelID,
blockHeight uint32) ([]wtdb.SessionID, error) {

m.mu.Lock()
defer m.mu.Unlock()

channel, ok := m.channels[chanID]
if !ok {
return nil, wtdb.ErrChannelNotRegistered
}

// If there are no sessions for this channel, the channel details can be
// deleted.
if len(channel.sessions) == 0 {
delete(m.channels, chanID)
return nil, nil
}

// Mark the channel as closed.
channel.closedHeight = blockHeight

// Now iterate through all the sessions of the channel to check if any
// of them are closeable.
var closableSessions []wtdb.SessionID
for sessID := range channel.sessions {
isClosable, err := m.isSessionClosable(sessID)
if err != nil {
return nil, err
}

if !isClosable {
continue
}

closableSessions = append(closableSessions, sessID)

// Add session to "closableSessions" list and add the block
// height that this last channel was closed in. This will be
// used in future to determine when we should delete the
// session.
m.closableSessions[sessID] = blockHeight
}

return closableSessions, nil
}

// isSessionClosable returns true if a session is considered closable. A session
// is considered closable only if:
// 1) It has no un-acked updates
// 2) It is exhausted (ie it cant accept any more updates)
// 3) All the channels that it has acked-updates for are closed.
func (m *ClientDB) isSessionClosable(id wtdb.SessionID) (bool, error) {
// The session is not closable if it has un-acked updates.
if len(m.committedUpdates[id]) > 0 {
return false, nil
}

sess, ok := m.activeSessions[id]
if !ok {
return false, wtdb.ErrClientSessionNotFound
}

// The session is not closable if it is not yet exhausted.
if sess.SeqNum != sess.Policy.MaxUpdates {
return false, nil
}

// Iterate over each of the channels that the session has acked-updates
// for. If any of those channels are not closed, then the session is
// not yet closable.
for chanID := range m.ackedUpdates[id] {
channel, ok := m.channels[chanID]
if !ok {
continue
}

// Channel is not yet closed, and so we can not yet delete the
// session.
if channel.closedHeight == 0 {
return false, nil
}
}

return true, nil
}

// GetClientSession loads the ClientSession with the given ID from the DB.
func (m *ClientDB) GetClientSession(id wtdb.SessionID,
opts ...wtdb.ClientSessionListOption) (*wtdb.ClientSession, error) {

cfg := wtdb.NewClientSessionCfg()
for _, o := range opts {
o(cfg)
}

session, ok := m.activeSessions[id]
if !ok {
return nil, wtdb.ErrClientSessionNotFound
}

if cfg.PerMaxHeight != nil {
for chanID, index := range m.ackedUpdates[session.ID] {
cfg.PerMaxHeight(&session, chanID, index.MaxHeight())
}
}

if cfg.PerCommittedUpdate != nil {
for _, update := range m.committedUpdates[session.ID] {
update := update
cfg.PerCommittedUpdate(&session, &update)
}
}

return &session, nil
}

// DeleteSession can be called when a session should be deleted from the DB.
// All references to the session will also be deleted from the DB. Note that a
// session will only be deleted if it is considered closable.
func (m *ClientDB) DeleteSession(id wtdb.SessionID) error {
m.mu.Lock()
defer m.mu.Unlock()

_, ok := m.closableSessions[id]
if !ok {
return wtdb.ErrSessionNotClosable
}

// For each of the channels, delete the session ID entry.
for chanID := range m.ackedUpdates[id] {
c, ok := m.channels[chanID]
if !ok {
return wtdb.ErrChannelNotRegistered
}

delete(c.sessions, id)
}

delete(m.closableSessions, id)
delete(m.activeSessions, id)

return nil
}

// RegisterChannel registers a channel for use within the client database. For
// now, all that is stored in the channel summary is the sweep pkscript that
// we'd like any tower sweeps to pay into. In the future, this will be extended
// to contain more info to allow the client efficiently request historical
// states to be backed up under the client's active policy.
func (m *ClientDB) RegisterChannel(chanID lnwire.ChannelID,
sweepPkScript []byte) error {

m.mu.Lock()
defer m.mu.Unlock()

if _, ok := m.channels[chanID]; ok {
return wtdb.ErrChannelAlreadyRegistered
}

m.channels[chanID] = &channel{
summary: &wtdb.ClientChanSummary{
SweepPkScript: cloneBytes(sweepPkScript),
},
sessions: make(map[wtdb.SessionID]bool),
}

return nil
}

func cloneBytes(b []byte) []byte {
if b == nil {
return nil
}

bb := make([]byte, len(b))
copy(bb, b)

return bb
}

func copyTower(tower *wtdb.Tower) *wtdb.Tower {
t := &wtdb.Tower{
ID: tower.ID,
IdentityKey: tower.IdentityKey,
Addresses: make([]net.Addr, len(tower.Addresses)),
}
copy(t.Addresses, tower.Addresses)

return t
}

type mockKVStore struct {
kv map[uint64]uint64

err error
}

func newMockKVStore() *mockKVStore {
return &mockKVStore{
kv: make(map[uint64]uint64),
}
}

func (m *mockKVStore) Put(key, value []byte) error {
if m.err != nil {
return m.err
}

k := byteOrder.Uint64(key)
v := byteOrder.Uint64(value)

m.kv[k] = v

return nil
}

func (m *mockKVStore) Delete(key []byte) error {
if m.err != nil {
return m.err
}

k := byteOrder.Uint64(key)
delete(m.kv, k)

return nil
}