watchtower: make use of the new AddressIterator

This commit upgrades the wtclient package to make use of the new
`AddressIterator`. It does so by first creating new `Tower` and
`ClientSession` types. The new `Tower` type has an `AddressIterator`
instead of a list of addresses. The `ClientSession` type contains a
`Tower`.
This commit is contained in:
Elle Mouton
2022-10-12 09:47:38 +02:00
parent 7924542500
commit 8a7329b988
13 changed files with 274 additions and 184 deletions

View File

@ -13,7 +13,7 @@ import (
type TowerCandidateIterator interface { type TowerCandidateIterator interface {
// AddCandidate adds a new candidate tower to the iterator. If the // AddCandidate adds a new candidate tower to the iterator. If the
// candidate already exists, then any new addresses are added to it. // candidate already exists, then any new addresses are added to it.
AddCandidate(*wtdb.Tower) AddCandidate(*Tower)
// RemoveCandidate removes an existing candidate tower from the // RemoveCandidate removes an existing candidate tower from the
// iterator. An optional address can be provided to indicate a stale // iterator. An optional address can be provided to indicate a stale
@ -32,7 +32,7 @@ type TowerCandidateIterator interface {
// Next returns the next candidate tower. The iterator is not required // Next returns the next candidate tower. The iterator is not required
// to return results in any particular order. If no more candidates are // to return results in any particular order. If no more candidates are
// available, ErrTowerCandidatesExhausted is returned. // available, ErrTowerCandidatesExhausted is returned.
Next() (*wtdb.Tower, error) Next() (*Tower, error)
} }
// towerListIterator is a linked-list backed TowerCandidateIterator. // towerListIterator is a linked-list backed TowerCandidateIterator.
@ -40,7 +40,7 @@ type towerListIterator struct {
mu sync.Mutex mu sync.Mutex
queue *list.List queue *list.List
nextCandidate *list.Element nextCandidate *list.Element
candidates map[wtdb.TowerID]*wtdb.Tower candidates map[wtdb.TowerID]*Tower
} }
// Compile-time constraint to ensure *towerListIterator implements the // Compile-time constraint to ensure *towerListIterator implements the
@ -49,10 +49,10 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil)
// newTowerListIterator initializes a new towerListIterator from a variadic list // newTowerListIterator initializes a new towerListIterator from a variadic list
// of lnwire.NetAddresses. // of lnwire.NetAddresses.
func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { func newTowerListIterator(candidates ...*Tower) *towerListIterator {
iter := &towerListIterator{ iter := &towerListIterator{
queue: list.New(), queue: list.New(),
candidates: make(map[wtdb.TowerID]*wtdb.Tower), candidates: make(map[wtdb.TowerID]*Tower),
} }
for _, candidate := range candidates { for _, candidate := range candidates {
@ -79,7 +79,7 @@ func (t *towerListIterator) Reset() error {
// Next returns the next candidate tower. This iterator will always return // Next returns the next candidate tower. This iterator will always return
// candidates in the order given when the iterator was instantiated. If no more // candidates in the order given when the iterator was instantiated. If no more
// candidates are available, ErrTowerCandidatesExhausted is returned. // candidates are available, ErrTowerCandidatesExhausted is returned.
func (t *towerListIterator) Next() (*wtdb.Tower, error) { func (t *towerListIterator) Next() (*Tower, error) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -107,7 +107,7 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) {
// AddCandidate adds a new candidate tower to the iterator. If the candidate // AddCandidate adds a new candidate tower to the iterator. If the candidate
// already exists, then any new addresses are added to it. // already exists, then any new addresses are added to it.
func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { func (t *towerListIterator) AddCandidate(candidate *Tower) {
t.mu.Lock() t.mu.Lock()
defer t.mu.Unlock() defer t.mu.Unlock()
@ -121,8 +121,16 @@ func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) {
t.nextCandidate = t.queue.Back() t.nextCandidate = t.queue.Back()
} }
} else { } else {
for _, addr := range candidate.Addresses { candidate.Addresses.Reset()
tower.AddAddress(addr) firstAddr := candidate.Addresses.Peek()
tower.Addresses.Add(firstAddr)
for {
next, err := candidate.Addresses.Next()
if err != nil {
break
}
tower.Addresses.Add(next)
} }
} }
} }
@ -142,9 +150,9 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID,
return nil return nil
} }
if addr != nil { if addr != nil {
tower.RemoveAddress(addr) err := tower.Addresses.Remove(addr)
if len(tower.Addresses) == 0 { if err != nil {
return wtdb.ErrLastTowerAddr return err
} }
} else { } else {
delete(t.candidates, candidate) delete(t.candidates, candidate)

View File

@ -33,31 +33,38 @@ func randAddr(t *testing.T) net.Addr {
} }
} }
func randTower(t *testing.T) *wtdb.Tower { func randTower(t *testing.T) *Tower {
t.Helper() t.Helper()
priv, err := btcec.NewPrivateKey() priv, err := btcec.NewPrivateKey()
require.NoError(t, err, "unable to create private key") require.NoError(t, err, "unable to create private key")
pubKey := priv.PubKey() pubKey := priv.PubKey()
return &wtdb.Tower{ addrs, err := newAddressIterator(randAddr(t))
require.NoError(t, err)
return &Tower{
ID: wtdb.TowerID(rand.Uint64()), ID: wtdb.TowerID(rand.Uint64()),
IdentityKey: pubKey, IdentityKey: pubKey,
Addresses: []net.Addr{randAddr(t)}, Addresses: addrs,
} }
} }
func copyTower(tower *wtdb.Tower) *wtdb.Tower { func copyTower(t *testing.T, tower *Tower) *Tower {
t := &wtdb.Tower{ t.Helper()
addrs := tower.Addresses.GetAll()
addrIterator, err := newAddressIterator(addrs...)
require.NoError(t, err)
return &Tower{
ID: tower.ID, ID: tower.ID,
IdentityKey: tower.IdentityKey, IdentityKey: tower.IdentityKey,
Addresses: make([]net.Addr, len(tower.Addresses)), Addresses: addrIterator,
} }
copy(t.Addresses, tower.Addresses)
return t
} }
func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *Tower,
c *wtdb.Tower, active bool) { active bool) {
t.Helper() t.Helper()
@ -71,12 +78,14 @@ func assertActiveCandidate(t *testing.T, i TowerCandidateIterator,
c.ID) c.ID)
} }
func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) {
t.Helper() t.Helper()
tower, err := i.Next() tower, err := i.Next()
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, c, tower) require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey))
require.Equal(t, tower.ID, c.ID)
require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll())
} }
// TestTowerCandidateIterator asserts the internal state of a // TestTowerCandidateIterator asserts the internal state of a
@ -88,13 +97,13 @@ func TestTowerCandidateIterator(t *testing.T) {
// towers. We'll use copies of these towers within the iterator to // towers. We'll use copies of these towers within the iterator to
// ensure the iterator properly updates the state of its candidates. // ensure the iterator properly updates the state of its candidates.
const numTowers = 4 const numTowers = 4
towers := make([]*wtdb.Tower, 0, numTowers) towers := make([]*Tower, 0, numTowers)
for i := 0; i < numTowers; i++ { for i := 0; i < numTowers; i++ {
towers = append(towers, randTower(t)) towers = append(towers, randTower(t))
} }
towerCopies := make([]*wtdb.Tower, 0, numTowers) towerCopies := make([]*Tower, 0, numTowers)
for _, tower := range towers { for _, tower := range towers {
towerCopies = append(towerCopies, copyTower(tower)) towerCopies = append(towerCopies, copyTower(t, tower))
} }
towerIterator := newTowerListIterator(towerCopies...) towerIterator := newTowerListIterator(towerCopies...)
@ -112,13 +121,13 @@ func TestTowerCandidateIterator(t *testing.T) {
towerIterator.Reset() towerIterator.Reset()
// We'll then attempt to test the RemoveCandidate behavior of the // We'll then attempt to test the RemoveCandidate behavior of the
// iterator. We'll remove the address of the first tower, which should // iterator. We'll attempt to remove the address of the first tower,
// result in it not having any addresses left, but still being an active // which should result in an error due to it being the last address of
// candidate. // the tower.
firstTower := towers[0] firstTower := towers[0]
firstTowerAddr := firstTower.Addresses[0] firstTowerAddr := firstTower.Addresses.Peek()
firstTower.RemoveAddress(firstTowerAddr) err = towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr)
towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) require.ErrorIs(t, err, wtdb.ErrLastTowerAddr)
assertActiveCandidate(t, towerIterator, firstTower, true) assertActiveCandidate(t, towerIterator, firstTower, true)
assertNextCandidate(t, towerIterator, firstTower) assertNextCandidate(t, towerIterator, firstTower)
@ -126,7 +135,8 @@ func TestTowerCandidateIterator(t *testing.T) {
// not providing the optional address. Since it's been removed, we // not providing the optional address. Since it's been removed, we
// should expect to see the third tower next. // should expect to see the third tower next.
secondTower, thirdTower := towers[1], towers[2] secondTower, thirdTower := towers[1], towers[2]
towerIterator.RemoveCandidate(secondTower.ID, nil) err = towerIterator.RemoveCandidate(secondTower.ID, nil)
require.NoError(t, err)
assertActiveCandidate(t, towerIterator, secondTower, false) assertActiveCandidate(t, towerIterator, secondTower, false)
assertNextCandidate(t, towerIterator, thirdTower) assertNextCandidate(t, towerIterator, thirdTower)
@ -135,7 +145,7 @@ func TestTowerCandidateIterator(t *testing.T) {
// iterator, but the new address should be. // iterator, but the new address should be.
fourthTower := towers[3] fourthTower := towers[3]
assertActiveCandidate(t, towerIterator, fourthTower, true) assertActiveCandidate(t, towerIterator, fourthTower, true)
fourthTower.AddAddress(randAddr(t)) fourthTower.Addresses.Add(randAddr(t))
towerIterator.AddCandidate(fourthTower) towerIterator.AddCandidate(fourthTower)
assertNextCandidate(t, towerIterator, fourthTower) assertNextCandidate(t, towerIterator, fourthTower)

View File

@ -45,8 +45,8 @@ const (
// genActiveSessionFilter generates a filter that selects active sessions that // genActiveSessionFilter generates a filter that selects active sessions that
// also match the desired channel type, either legacy or anchor. // also match the desired channel type, either legacy or anchor.
func genActiveSessionFilter(anchor bool) func(*wtdb.ClientSession) bool { func genActiveSessionFilter(anchor bool) func(*ClientSession) bool {
return func(s *wtdb.ClientSession) bool { return func(s *ClientSession) bool {
return s.Status == wtdb.CSessionActive && return s.Status == wtdb.CSessionActive &&
anchor == s.Policy.IsAnchorChannel() anchor == s.Policy.IsAnchorChannel()
} }
@ -241,7 +241,7 @@ type TowerClient struct {
negotiator SessionNegotiator negotiator SessionNegotiator
candidateTowers TowerCandidateIterator candidateTowers TowerCandidateIterator
candidateSessions map[wtdb.SessionID]*wtdb.ClientSession candidateSessions map[wtdb.SessionID]*ClientSession
activeSessions sessionQueueSet activeSessions sessionQueueSet
sessionQueue *sessionQueue sessionQueue *sessionQueue
@ -351,7 +351,7 @@ func New(config *Config) (*TowerClient, error) {
activeSessionFilter := genActiveSessionFilter(isAnchorClient) activeSessionFilter := genActiveSessionFilter(isAnchorClient)
candidateTowers := newTowerListIterator() candidateTowers := newTowerListIterator()
perActiveTower := func(tower *wtdb.Tower) { perActiveTower := func(tower *Tower) {
// If the tower has already been marked as active, then there is // If the tower has already been marked as active, then there is
// no need to add it to the iterator again. // no need to add it to the iterator again.
if candidateTowers.IsActive(tower.ID) { if candidateTowers.IsActive(tower.ID) {
@ -400,18 +400,23 @@ func New(config *Config) (*TowerClient, error) {
// sessionFilter check then the perActiveTower call-back will be called on that // sessionFilter check then the perActiveTower call-back will be called on that
// tower. // tower.
func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
sessionFilter func(*wtdb.ClientSession) bool, sessionFilter func(*ClientSession) bool,
perActiveTower func(tower *wtdb.Tower), perActiveTower func(tower *Tower),
opts ...wtdb.ClientSessionListOption) ( opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*ClientSession, error) {
towers, err := db.ListTowers() towers, err := db.ListTowers()
if err != nil { if err != nil {
return nil, err return nil, err
} }
candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) candidateSessions := make(map[wtdb.SessionID]*ClientSession)
for _, tower := range towers { for _, dbTower := range towers {
tower, err := NewTowerFromDBTower(dbTower)
if err != nil {
return nil, err
}
sessions, err := db.ListClientSessions(&tower.ID, opts...) sessions, err := db.ListClientSessions(&tower.ID, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
@ -427,16 +432,24 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.SessionKeyECDH = keychain.NewPubKeyECDH(
sessionKeyECDH := keychain.NewPubKeyECDH(
towerKeyDesc, keyRing, towerKeyDesc, keyRing,
) )
if !sessionFilter(s) { cs := &ClientSession{
ID: s.ID,
ClientSessionBody: s.ClientSessionBody,
Tower: tower,
SessionKeyECDH: sessionKeyECDH,
}
if !sessionFilter(cs) {
continue continue
} }
// Add the session to the set of candidate sessions. // Add the session to the set of candidate sessions.
candidateSessions[s.ID] = s candidateSessions[s.ID] = cs
perActiveTower(tower) perActiveTower(tower)
} }
} }
@ -452,11 +465,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing,
// ClientSession's SessionPrivKey field is desired, otherwise, the existing // ClientSession's SessionPrivKey field is desired, otherwise, the existing
// ListClientSessions method should be used. // ListClientSessions method should be used.
func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
passesFilter func(*wtdb.ClientSession) bool, passesFilter func(*ClientSession) bool,
opts ...wtdb.ClientSessionListOption) ( opts ...wtdb.ClientSessionListOption) (
map[wtdb.SessionID]*wtdb.ClientSession, error) { map[wtdb.SessionID]*ClientSession, error) {
sessions, err := db.ListClientSessions(forTower, opts...) dbSessions, err := db.ListClientSessions(forTower, opts...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -466,7 +479,13 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
// be able to communicate with the towers and authenticate session // be able to communicate with the towers and authenticate session
// requests. This prevents us from having to store the private keys on // requests. This prevents us from having to store the private keys on
// disk. // disk.
for _, s := range sessions { sessions := make(map[wtdb.SessionID]*ClientSession)
for _, s := range dbSessions {
dbTower, err := db.LoadTowerByID(s.TowerID)
if err != nil {
return nil, err
}
towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession, Family: keychain.KeyFamilyTowerSession,
Index: s.KeyIndex, Index: s.KeyIndex,
@ -474,13 +493,27 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID,
if err != nil { if err != nil {
return nil, err return nil, err
} }
s.SessionKeyECDH = keychain.NewPubKeyECDH(towerKeyDesc, keyRing) sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing)
tower, err := NewTowerFromDBTower(dbTower)
if err != nil {
return nil, err
}
cs := &ClientSession{
ID: s.ID,
ClientSessionBody: s.ClientSessionBody,
Tower: tower,
SessionKeyECDH: sessionKeyECDH,
}
// If an optional filter was provided, use it to filter out any // If an optional filter was provided, use it to filter out any
// undesired sessions. // undesired sessions.
if passesFilter != nil && !passesFilter(s) { if passesFilter != nil && !passesFilter(cs) {
delete(sessions, s.ID) continue
} }
sessions[s.ID] = cs
} }
return sessions, nil return sessions, nil
@ -710,7 +743,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) {
// Select any candidate session at random, and remove it from the set of // Select any candidate session at random, and remove it from the set of
// candidate sessions. // candidate sessions.
var candidateSession *wtdb.ClientSession var candidateSession *ClientSession
for id, sessionInfo := range c.candidateSessions { for id, sessionInfo := range c.candidateSessions {
delete(c.candidateSessions, id) delete(c.candidateSessions, id)
@ -1069,7 +1102,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error
// newSessionQueue creates a sessionQueue from a ClientSession loaded from the // newSessionQueue creates a sessionQueue from a ClientSession loaded from the
// database and supplying it with the resources needed by the client. // database and supplying it with the resources needed by the client.
func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, func (c *TowerClient) newSessionQueue(s *ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue { updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{ return newSessionQueue(&sessionQueueConfig{
@ -1089,7 +1122,7 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession,
// getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the
// passed ClientSession. If it exists, the active sessionQueue is returned. // passed ClientSession. If it exists, the active sessionQueue is returned.
// Otherwise, a new sessionQueue is initialized and added to the set. // Otherwise, a new sessionQueue is initialized and added to the set.
func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, func (c *TowerClient) getOrInitActiveQueue(s *ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue { updates []wtdb.CommittedUpdate) *sessionQueue {
if sq, ok := c.activeSessions[s.ID]; ok { if sq, ok := c.activeSessions[s.ID]; ok {
@ -1103,7 +1136,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession,
// adds the sessionQueue to the activeSessions set, and starts the sessionQueue // adds the sessionQueue to the activeSessions set, and starts the sessionQueue
// so that it can deliver any committed updates or begin accepting newly // so that it can deliver any committed updates or begin accepting newly
// assigned tasks. // assigned tasks.
func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, func (c *TowerClient) initActiveQueue(s *ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue { updates []wtdb.CommittedUpdate) *sessionQueue {
// Initialize the session queue, providing it with all the resources it // Initialize the session queue, providing it with all the resources it
@ -1156,10 +1189,16 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error {
// We'll start by updating our persisted state, followed by our // We'll start by updating our persisted state, followed by our
// in-memory state, with the new tower. This might not actually be a new // in-memory state, with the new tower. This might not actually be a new
// tower, but it might include a new address at which it can be reached. // tower, but it might include a new address at which it can be reached.
tower, err := c.cfg.DB.CreateTower(msg.addr) dbTower, err := c.cfg.DB.CreateTower(msg.addr)
if err != nil { if err != nil {
return err return err
} }
tower, err := NewTowerFromDBTower(dbTower)
if err != nil {
return err
}
c.candidateTowers.AddCandidate(tower) c.candidateTowers.AddCandidate(tower)
// Include all of its corresponding sessions to our set of candidates. // Include all of its corresponding sessions to our set of candidates.
@ -1251,7 +1290,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
// If our active session queue corresponds to the stale tower, we'll // If our active session queue corresponds to the stale tower, we'll
// proceed to negotiate a new one. // proceed to negotiate a new one.
if c.sessionQueue != nil { if c.sessionQueue != nil {
activeTower := c.sessionQueue.towerAddr.IdentityKey.SerializeCompressed() activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed()
if bytes.Equal(pubKey, activeTower) { if bytes.Equal(pubKey, activeTower) {
c.sessionQueue = nil c.sessionQueue = nil
} }

View File

@ -1471,10 +1471,8 @@ var clientTests = []clientTest{
}, },
{ {
// Assert that if a client changes the address for a server and // Assert that if a client changes the address for a server and
// then tries to back up updates then the client will not switch // then tries to back up updates then the client will switch to
// to the new address. The client will only use the server's new // the new address.
// address after a restart. This is a bug that will be fixed in
// a future commit.
name: "change address of existing session", name: "change address of existing session",
cfg: harnessCfg{ cfg: harnessCfg{
localBalance: localBalance, localBalance: localBalance,
@ -1535,16 +1533,7 @@ var clientTests = []clientTest{
// Now attempt to back up the rest of the updates. // Now attempt to back up the rest of the updates.
h.backupStates(chanID, numUpdates/2, maxUpdates, nil) h.backupStates(chanID, numUpdates/2, maxUpdates, nil)
// Assert that the server does not receive the updates. // Assert that the server does receive the updates.
h.waitServerUpdates(nil, waitTime)
// Restart the client and attempt to back up the updates
// again.
h.client.Stop()
h.startClient()
h.backupStates(chanID, numUpdates/2, maxUpdates, nil)
// The server should now receive the updates.
h.waitServerUpdates(hints[:maxUpdates], waitTime) h.waitServerUpdates(hints[:maxUpdates], waitTime)
}, },
}, },

View File

@ -20,10 +20,6 @@ var (
// down. // down.
ErrNegotiatorExiting = errors.New("negotiator exiting") ErrNegotiatorExiting = errors.New("negotiator exiting")
// ErrNoTowerAddrs signals that the client could not be created because
// we have no addresses with which we can reach a tower.
ErrNoTowerAddrs = errors.New("no tower addresses")
// ErrFailedNegotiation signals that the session negotiator could not // ErrFailedNegotiation signals that the session negotiator could not
// acquire a new session as requested. // acquire a new session as requested.
ErrFailedNegotiation = errors.New("session negotiation unsuccessful") ErrFailedNegotiation = errors.New("session negotiation unsuccessful")

View File

@ -118,3 +118,50 @@ type ECDHKeyRing interface {
// key. // key.
DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error)
} }
// Tower represents the info about a watchtower server that a watchtower client
// needs in order to connect to it.
type Tower struct {
// ID is the unique, db-assigned, identifier for this tower.
ID wtdb.TowerID
// IdentityKey is the public key of the remote node, used to
// authenticate the brontide transport.
IdentityKey *btcec.PublicKey
// Addresses is an AddressIterator that can be used to manage the
// addresses for this tower.
Addresses AddressIterator
}
// NewTowerFromDBTower converts a wtdb.Tower, which uses a static address list,
// into a Tower which uses an address iterator.
func NewTowerFromDBTower(t *wtdb.Tower) (*Tower, error) {
addrs, err := newAddressIterator(t.Addresses...)
if err != nil {
return nil, err
}
return &Tower{
ID: t.ID,
IdentityKey: t.IdentityKey,
Addresses: addrs,
}, nil
}
// ClientSession represents the session that a tower client has with a server.
type ClientSession struct {
// ID is the client's public key used when authenticating with the
// tower.
ID wtdb.SessionID
wtdb.ClientSessionBody
// Tower represents the tower that the client session has been made
// with.
Tower *Tower
// SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret
// key used to connect to the watchtower.
SessionKeyECDH keychain.SingleKeyECDH
}

View File

@ -25,7 +25,7 @@ type SessionNegotiator interface {
// NewSessions is a read-only channel where newly negotiated sessions // NewSessions is a read-only channel where newly negotiated sessions
// will be delivered. // will be delivered.
NewSessions() <-chan *wtdb.ClientSession NewSessions() <-chan *ClientSession
// Start safely initializes the session negotiator. // Start safely initializes the session negotiator.
Start() error Start() error
@ -105,8 +105,8 @@ type sessionNegotiator struct {
log btclog.Logger log btclog.Logger
dispatcher chan struct{} dispatcher chan struct{}
newSessions chan *wtdb.ClientSession newSessions chan *ClientSession
successfulNegotiations chan *wtdb.ClientSession successfulNegotiations chan *ClientSession
wg sync.WaitGroup wg sync.WaitGroup
quit chan struct{} quit chan struct{}
@ -139,8 +139,8 @@ func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator {
log: cfg.Log, log: cfg.Log,
localInit: localInit, localInit: localInit,
dispatcher: make(chan struct{}, 1), dispatcher: make(chan struct{}, 1),
newSessions: make(chan *wtdb.ClientSession), newSessions: make(chan *ClientSession),
successfulNegotiations: make(chan *wtdb.ClientSession), successfulNegotiations: make(chan *ClientSession),
quit: make(chan struct{}), quit: make(chan struct{}),
} }
} }
@ -171,7 +171,7 @@ func (n *sessionNegotiator) Stop() error {
// NewSessions returns a receive-only channel from which newly negotiated // NewSessions returns a receive-only channel from which newly negotiated
// sessions will be returned. // sessions will be returned.
func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession { func (n *sessionNegotiator) NewSessions() <-chan *ClientSession {
return n.newSessions return n.newSessions
} }
@ -333,18 +333,10 @@ retryWithBackoff:
} }
} }
// createSession takes a tower an attempts to negotiate a session using any of // createSession takes a tower and attempts to negotiate a session using any of
// its stored addresses. This method returns after the first successful // its stored addresses. This method returns after the first successful
// negotiation, or after all addresses have failed with ErrFailedNegotiation. If // negotiation, or after all addresses have failed with ErrFailedNegotiation.
// the tower has no addresses, ErrNoTowerAddrs is returned. func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
keyIndex uint32) error {
// If the tower has no addresses, there's nothing we can do.
if len(tower.Addresses) == 0 {
return ErrNoTowerAddrs
}
sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey( sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey(
keychain.KeyLocator{ keychain.KeyLocator{
Family: keychain.KeyFamilyTowerSession, Family: keychain.KeyFamilyTowerSession,
@ -358,8 +350,14 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
sessionKeyDesc, n.cfg.SecretKeyRing, sessionKeyDesc, n.cfg.SecretKeyRing,
) )
for _, lnAddr := range tower.LNAddrs() { addr := tower.Addresses.Peek()
err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr) for {
lnAddr := &lnwire.NetAddress{
IdentityKey: tower.IdentityKey,
Address: addr,
}
err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
switch { switch {
case err == ErrPermanentTowerFailure: case err == ErrPermanentTowerFailure:
// TODO(conner): report to iterator? can then be reset // TODO(conner): report to iterator? can then be reset
@ -370,6 +368,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
n.log.Debugf("Request for session negotiation with "+ n.log.Debugf("Request for session negotiation with "+
"tower=%s failed, trying again -- reason: "+ "tower=%s failed, trying again -- reason: "+
"%v", lnAddr, err) "%v", lnAddr, err)
// Get the next tower address if there is one.
addr, err = tower.Addresses.Next()
if err == ErrAddressesExhausted {
tower.Addresses.Reset()
return ErrFailedNegotiation
}
continue continue
default: default:
@ -385,7 +392,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower,
// returns true if all steps succeed and the new session has been persisted, and // returns true if all steps succeed and the new session has been persisted, and
// fails otherwise. // fails otherwise.
func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { keyIndex uint32, tower *Tower, lnAddr *lnwire.NetAddress) error {
// Connect to the tower address using our generated session key. // Connect to the tower address using our generated session key.
conn, err := n.cfg.Dial(sessionKey, lnAddr) conn, err := n.cfg.Dial(sessionKey, lnAddr)
@ -456,26 +463,31 @@ func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH,
rewardPkScript := createSessionReply.Data rewardPkScript := createSessionReply.Data
sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey()) sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey())
clientSession := &wtdb.ClientSession{ dbClientSession := &wtdb.ClientSession{
ClientSessionBody: wtdb.ClientSessionBody{ ClientSessionBody: wtdb.ClientSessionBody{
TowerID: tower.ID, TowerID: tower.ID,
KeyIndex: keyIndex, KeyIndex: keyIndex,
Policy: n.cfg.Policy, Policy: n.cfg.Policy,
RewardPkScript: rewardPkScript, RewardPkScript: rewardPkScript,
}, },
Tower: tower, ID: sessionID,
SessionKeyECDH: sessionKey,
ID: sessionID,
} }
err = n.cfg.DB.CreateClientSession(clientSession) err = n.cfg.DB.CreateClientSession(dbClientSession)
if err != nil { if err != nil {
return fmt.Errorf("unable to persist ClientSession: %v", return fmt.Errorf("unable to persist ClientSession: %v",
err) err)
} }
n.log.Debugf("New session negotiated with %s, policy: %s", n.log.Debugf("New session negotiated with %s, policy: %s",
lnAddr, clientSession.Policy) lnAddr, dbClientSession.Policy)
clientSession := &ClientSession{
ID: sessionID,
ClientSessionBody: dbClientSession.ClientSessionBody,
Tower: tower,
SessionKeyECDH: sessionKey,
}
// We have a newly negotiated session, return it to the // We have a newly negotiated session, return it to the
// dispatcher so that it can update how many outstanding // dispatcher so that it can update how many outstanding

View File

@ -34,7 +34,7 @@ const (
type sessionQueueConfig struct { type sessionQueueConfig struct {
// ClientSession provides access to the negotiated session parameters // ClientSession provides access to the negotiated session parameters
// and updating its persistent storage. // and updating its persistent storage.
ClientSession *wtdb.ClientSession ClientSession *ClientSession
// ChainHash identifies the chain for which the session's justice // ChainHash identifies the chain for which the session's justice
// transactions are targeted. // transactions are targeted.
@ -97,7 +97,7 @@ type sessionQueue struct {
queueCond *sync.Cond queueCond *sync.Cond
localInit *wtwire.Init localInit *wtwire.Init
towerAddr *lnwire.NetAddress tower *Tower
seqNum uint16 seqNum uint16
@ -117,18 +117,13 @@ func newSessionQueue(cfg *sessionQueueConfig,
cfg.ChainHash, cfg.ChainHash,
) )
towerAddr := &lnwire.NetAddress{
IdentityKey: cfg.ClientSession.Tower.IdentityKey,
Address: cfg.ClientSession.Tower.Addresses[0],
}
sq := &sessionQueue{ sq := &sessionQueue{
cfg: cfg, cfg: cfg,
log: cfg.Log, log: cfg.Log,
commitQueue: list.New(), commitQueue: list.New(),
pendingQueue: list.New(), pendingQueue: list.New(),
localInit: localInit, localInit: localInit,
towerAddr: towerAddr, tower: cfg.ClientSession.Tower,
seqNum: cfg.ClientSession.SeqNum, seqNum: cfg.ClientSession.SeqNum,
retryBackoff: cfg.MinBackoff, retryBackoff: cfg.MinBackoff,
quit: make(chan struct{}), quit: make(chan struct{}),
@ -293,18 +288,48 @@ func (q *sessionQueue) sessionManager() {
// drainBackups attempts to send all pending updates in the queue to the tower. // drainBackups attempts to send all pending updates in the queue to the tower.
func (q *sessionQueue) drainBackups() { func (q *sessionQueue) drainBackups() {
// First, check that we are able to dial this session's tower. var (
conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionKeyECDH, q.towerAddr) conn wtserver.Peer
if err != nil { err error
q.log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v", towerAddr = q.tower.Addresses.Peek()
q.ID(), q.towerAddr, err) )
q.increaseBackoff() for {
select { q.log.Infof("SessionQueue(%s) attempting to dial tower at %v",
case <-time.After(q.retryBackoff): q.ID(), towerAddr)
case <-q.forceQuit:
// First, check that we are able to dial this session's tower.
conn, err = q.cfg.Dial(
q.cfg.ClientSession.SessionKeyECDH, &lnwire.NetAddress{
IdentityKey: q.tower.IdentityKey,
Address: towerAddr,
},
)
if err != nil {
// If there are more addrs available, immediately try
// those.
nextAddr, iteratorErr := q.tower.Addresses.Next()
if iteratorErr == nil {
towerAddr = nextAddr
continue
}
// Otherwise, if we have exhausted the address list,
// back off and try again later.
q.tower.Addresses.Reset()
q.log.Errorf("SessionQueue(%s) unable to dial tower "+
"at any available Addresses: %v", q.ID(), err)
q.increaseBackoff()
select {
case <-time.After(q.retryBackoff):
case <-q.forceQuit:
}
return
} }
return
break
} }
defer conn.Close() defer conn.Close()
@ -324,9 +349,7 @@ func (q *sessionQueue) drainBackups() {
} }
// Now, send the state update to the tower and wait for a reply. // Now, send the state update to the tower and wait for a reply.
err = q.sendStateUpdate( err = q.sendStateUpdate(conn, stateUpdate, sendInit, isPending)
conn, stateUpdate, q.localInit, sendInit, isPending,
)
if err != nil { if err != nil {
q.log.Errorf("SessionQueue(%s) unable to send state "+ q.log.Errorf("SessionQueue(%s) unable to send state "+
"update: %v", q.ID(), err) "update: %v", q.ID(), err)
@ -483,8 +506,12 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool,
// variable indicates whether we should back off before attempting to send the // variable indicates whether we should back off before attempting to send the
// next state update. // next state update.
func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init, stateUpdate *wtwire.StateUpdate, sendInit, isPending bool) error {
sendInit, isPending bool) error {
towerAddr := &lnwire.NetAddress{
IdentityKey: conn.RemotePub(),
Address: conn.RemoteAddr(),
}
// If this is the first message being sent to the tower, we must send an // If this is the first message being sent to the tower, we must send an
// Init message to establish that server supports the features we // Init message to establish that server supports the features we
@ -505,7 +532,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
remoteInit, ok := remoteMsg.(*wtwire.Init) remoteInit, ok := remoteMsg.(*wtwire.Init)
if !ok { if !ok {
return fmt.Errorf("watchtower %s responded with %T "+ return fmt.Errorf("watchtower %s responded with %T "+
"to Init", q.towerAddr, remoteMsg) "to Init", towerAddr, remoteMsg)
} }
// Validate Init. // Validate Init.
@ -532,7 +559,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply)
if !ok { if !ok {
return fmt.Errorf("watchtower %s responded with %T to "+ return fmt.Errorf("watchtower %s responded with %T to "+
"StateUpdate", q.towerAddr, remoteMsg) "StateUpdate", towerAddr, remoteMsg)
} }
// Process the reply from the tower. // Process the reply from the tower.
@ -547,8 +574,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
err := fmt.Errorf("received error code %v in "+ err := fmt.Errorf("received error code %v in "+
"StateUpdateReply for seqnum=%d", "StateUpdateReply for seqnum=%d",
stateUpdateReply.Code, stateUpdate.SeqNum) stateUpdateReply.Code, stateUpdate.SeqNum)
q.log.Warnf("SessionQueue(%s) unable to upload state update to "+ q.log.Warnf("SessionQueue(%s) unable to upload state update "+
"tower=%s: %v", q.ID(), q.towerAddr, err) "to tower=%s: %v", q.ID(), towerAddr, err)
return err return err
} }
@ -559,7 +586,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer,
// TODO(conner): borked watchtower // TODO(conner): borked watchtower
err = fmt.Errorf("unable to ack seqnum=%d: %v", err = fmt.Errorf("unable to ack seqnum=%d: %v",
stateUpdate.SeqNum, err) stateUpdate.SeqNum, err)
q.log.Errorf("SessionQueue(%v) failed to ack update: %v", q.ID(), err) q.log.Errorf("SessionQueue(%v) failed to ack update: %v",
q.ID(), err)
return err return err
case err == wtdb.ErrLastAppliedReversion: case err == wtdb.ErrLastAppliedReversion:

View File

@ -429,7 +429,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error {
} }
towerSessions, err := listTowerSessions( towerSessions, err := listTowerSessions(
towerID, sessions, towers, towersToSessionsIndex, towerID, sessions, towersToSessionsIndex,
WithPerCommittedUpdate(perCommittedUpdate), WithPerCommittedUpdate(perCommittedUpdate),
) )
if err != nil { if err != nil {
@ -766,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
// known to the db. // known to the db.
if id == nil { if id == nil {
clientSessions, err = listClientAllSessions( clientSessions, err = listClientAllSessions(
sessions, towers, opts..., sessions, opts...,
) )
return err return err
} }
@ -778,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
} }
clientSessions, err = listTowerSessions( clientSessions, err = listTowerSessions(
*id, sessions, towers, towerToSessionIndex, opts..., *id, sessions, towerToSessionIndex, opts...,
) )
return err return err
}, func() { }, func() {
@ -792,7 +792,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID,
} }
// listClientAllSessions returns the set of all client sessions known to the db. // listClientAllSessions returns the set of all client sessions known to the db.
func listClientAllSessions(sessions, towers kvdb.RBucket, func listClientAllSessions(sessions kvdb.RBucket,
opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) {
clientSessions := make(map[SessionID]*ClientSession) clientSessions := make(map[SessionID]*ClientSession)
@ -801,7 +801,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession(sessions, towers, k, opts...) session, err := getClientSession(sessions, k, opts...)
if err != nil { if err != nil {
return err return err
} }
@ -819,7 +819,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket,
// listTowerSessions returns the set of all client sessions known to the db // listTowerSessions returns the set of all client sessions known to the db
// that are associated with the given tower id. // that are associated with the given tower id.
func listTowerSessions(id TowerID, sessionsBkt, towersBkt, func listTowerSessions(id TowerID, sessionsBkt,
towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) (
map[SessionID]*ClientSession, error) { map[SessionID]*ClientSession, error) {
@ -834,9 +834,7 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt,
// the CommittedUpdates and AckedUpdates on startup to resume // the CommittedUpdates and AckedUpdates on startup to resume
// committed updates and compute the highest known commit height // committed updates and compute the highest known commit height
// for each channel. // for each channel.
session, err := getClientSession( session, err := getClientSession(sessionsBkt, k, opts...)
sessionsBkt, towersBkt, k, opts...,
)
if err != nil { if err != nil {
return err return err
} }
@ -1248,7 +1246,7 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption {
// getClientSession loads the full ClientSession associated with the serialized // getClientSession loads the full ClientSession associated with the serialized
// session id. This method populates the CommittedUpdates, AckUpdates and Tower // session id. This method populates the CommittedUpdates, AckUpdates and Tower
// in addition to the ClientSession's body. // in addition to the ClientSession's body.
func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, func getClientSession(sessions kvdb.RBucket, idBytes []byte,
opts ...ClientSessionListOption) (*ClientSession, error) { opts ...ClientSessionListOption) (*ClientSession, error) {
cfg := NewClientSessionCfg() cfg := NewClientSessionCfg()
@ -1261,13 +1259,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte,
return nil, err return nil, err
} }
// Fetch the tower associated with this session.
tower, err := getTower(towers, session.TowerID.Bytes())
if err != nil {
return nil, err
}
session.Tower = tower
// Can't fail because client session body has already been read. // Can't fail because client session body has already been read.
sessionBkt := sessions.NestedReadBucket(idBytes) sessionBkt := sessions.NestedReadBucket(idBytes)

View File

@ -343,8 +343,11 @@ func testCreateTower(h *clientDBHarness) {
h.loadTowerByID(20, wtdb.ErrTowerNotFound) h.loadTowerByID(20, wtdb.ErrTowerNotFound)
tower := h.newTower() tower := h.newTower()
require.Len(h.t, tower.LNAddrs(), 1) require.Len(h.t, tower.Addresses, 1)
towerAddr := tower.LNAddrs()[0] towerAddr := &lnwire.NetAddress{
IdentityKey: tower.IdentityKey,
Address: tower.Addresses[0],
}
// Load the tower from the database and assert that it matches the tower // Load the tower from the database and assert that it matches the tower
// we created. // we created.

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/blob"
"github.com/lightningnetwork/lnd/watchtower/wtpolicy" "github.com/lightningnetwork/lnd/watchtower/wtpolicy"
@ -36,19 +35,6 @@ type ClientSession struct {
ID SessionID ID SessionID
ClientSessionBody ClientSessionBody
// Tower holds the pubkey and address of the watchtower.
//
// NOTE: This value is not serialized. It is recovered by looking up the
// tower with TowerID.
Tower *Tower
// SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret
// key used to connect to the watchtower.
//
// NOTE: This value is not serialized. It is derived using the KeyIndex
// on startup to avoid storing private keys on disk.
SessionKeyECDH keychain.SingleKeyECDH
} }
// ClientSessionBody represents the primary components of a ClientSession that // ClientSessionBody represents the primary components of a ClientSession that

View File

@ -7,7 +7,6 @@ import (
"net" "net"
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/lnwire"
) )
// TowerID is a unique 64-bit identifier allocated to each unique watchtower. // TowerID is a unique 64-bit identifier allocated to each unique watchtower.
@ -77,23 +76,6 @@ func (t *Tower) RemoveAddress(addr net.Addr) {
} }
} }
// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's
// addresses. This can be used to have a client try multiple addresses for the
// same Tower.
//
// NOTE: This method is NOT safe for concurrent use.
func (t *Tower) LNAddrs() []*lnwire.NetAddress {
addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses))
for _, addr := range t.Addresses {
addrs = append(addrs, &lnwire.NetAddress{
IdentityKey: t.IdentityKey,
Address: addr,
})
}
return addrs
}
// String returns a user-friendly identifier of the tower. // String returns a user-friendly identifier of the tower.
func (t *Tower) String() string { func (t *Tower) String() string {
pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed()) pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed())

View File

@ -231,7 +231,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID,
if tower != nil && *tower != session.TowerID { if tower != nil && *tower != session.TowerID {
continue continue
} }
session.Tower = m.towers[session.TowerID]
sessions[session.ID] = &session sessions[session.ID] = &session
if cfg.PerAckedUpdate != nil { if cfg.PerAckedUpdate != nil {