From 8a7329b988811b354b5344e38a3bd19b2feec6dc Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 12 Oct 2022 09:47:38 +0200 Subject: [PATCH] watchtower: make use of the new AddressIterator This commit upgrades the wtclient package to make use of the new `AddressIterator`. It does so by first creating new `Tower` and `ClientSession` types. The new `Tower` type has an `AddressIterator` instead of a list of addresses. The `ClientSession` type contains a `Tower`. --- watchtower/wtclient/candidate_iterator.go | 32 ++++--- .../wtclient/candidate_iterator_test.go | 56 +++++++----- watchtower/wtclient/client.go | 89 +++++++++++++------ watchtower/wtclient/client_test.go | 17 +--- watchtower/wtclient/errors.go | 4 - watchtower/wtclient/interface.go | 47 ++++++++++ watchtower/wtclient/session_negotiator.go | 64 +++++++------ watchtower/wtclient/session_queue.go | 84 +++++++++++------ watchtower/wtdb/client_db.go | 25 ++---- watchtower/wtdb/client_db_test.go | 7 +- watchtower/wtdb/client_session.go | 14 --- watchtower/wtdb/tower.go | 18 ---- watchtower/wtmock/client_db.go | 1 - 13 files changed, 274 insertions(+), 184 deletions(-) diff --git a/watchtower/wtclient/candidate_iterator.go b/watchtower/wtclient/candidate_iterator.go index 5b48a68ef..f11ad2e35 100644 --- a/watchtower/wtclient/candidate_iterator.go +++ b/watchtower/wtclient/candidate_iterator.go @@ -13,7 +13,7 @@ import ( type TowerCandidateIterator interface { // AddCandidate adds a new candidate tower to the iterator. If the // candidate already exists, then any new addresses are added to it. - AddCandidate(*wtdb.Tower) + AddCandidate(*Tower) // RemoveCandidate removes an existing candidate tower from the // iterator. An optional address can be provided to indicate a stale @@ -32,7 +32,7 @@ type TowerCandidateIterator interface { // Next returns the next candidate tower. The iterator is not required // to return results in any particular order. If no more candidates are // available, ErrTowerCandidatesExhausted is returned. - Next() (*wtdb.Tower, error) + Next() (*Tower, error) } // towerListIterator is a linked-list backed TowerCandidateIterator. @@ -40,7 +40,7 @@ type towerListIterator struct { mu sync.Mutex queue *list.List nextCandidate *list.Element - candidates map[wtdb.TowerID]*wtdb.Tower + candidates map[wtdb.TowerID]*Tower } // Compile-time constraint to ensure *towerListIterator implements the @@ -49,10 +49,10 @@ var _ TowerCandidateIterator = (*towerListIterator)(nil) // newTowerListIterator initializes a new towerListIterator from a variadic list // of lnwire.NetAddresses. -func newTowerListIterator(candidates ...*wtdb.Tower) *towerListIterator { +func newTowerListIterator(candidates ...*Tower) *towerListIterator { iter := &towerListIterator{ queue: list.New(), - candidates: make(map[wtdb.TowerID]*wtdb.Tower), + candidates: make(map[wtdb.TowerID]*Tower), } for _, candidate := range candidates { @@ -79,7 +79,7 @@ func (t *towerListIterator) Reset() error { // Next returns the next candidate tower. This iterator will always return // candidates in the order given when the iterator was instantiated. If no more // candidates are available, ErrTowerCandidatesExhausted is returned. -func (t *towerListIterator) Next() (*wtdb.Tower, error) { +func (t *towerListIterator) Next() (*Tower, error) { t.mu.Lock() defer t.mu.Unlock() @@ -107,7 +107,7 @@ func (t *towerListIterator) Next() (*wtdb.Tower, error) { // AddCandidate adds a new candidate tower to the iterator. If the candidate // already exists, then any new addresses are added to it. -func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { +func (t *towerListIterator) AddCandidate(candidate *Tower) { t.mu.Lock() defer t.mu.Unlock() @@ -121,8 +121,16 @@ func (t *towerListIterator) AddCandidate(candidate *wtdb.Tower) { t.nextCandidate = t.queue.Back() } } else { - for _, addr := range candidate.Addresses { - tower.AddAddress(addr) + candidate.Addresses.Reset() + firstAddr := candidate.Addresses.Peek() + tower.Addresses.Add(firstAddr) + for { + next, err := candidate.Addresses.Next() + if err != nil { + break + } + + tower.Addresses.Add(next) } } } @@ -142,9 +150,9 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID, return nil } if addr != nil { - tower.RemoveAddress(addr) - if len(tower.Addresses) == 0 { - return wtdb.ErrLastTowerAddr + err := tower.Addresses.Remove(addr) + if err != nil { + return err } } else { delete(t.candidates, candidate) diff --git a/watchtower/wtclient/candidate_iterator_test.go b/watchtower/wtclient/candidate_iterator_test.go index 9a919e103..7fe6ba723 100644 --- a/watchtower/wtclient/candidate_iterator_test.go +++ b/watchtower/wtclient/candidate_iterator_test.go @@ -33,31 +33,38 @@ func randAddr(t *testing.T) net.Addr { } } -func randTower(t *testing.T) *wtdb.Tower { +func randTower(t *testing.T) *Tower { t.Helper() priv, err := btcec.NewPrivateKey() require.NoError(t, err, "unable to create private key") pubKey := priv.PubKey() - return &wtdb.Tower{ + addrs, err := newAddressIterator(randAddr(t)) + require.NoError(t, err) + + return &Tower{ ID: wtdb.TowerID(rand.Uint64()), IdentityKey: pubKey, - Addresses: []net.Addr{randAddr(t)}, + Addresses: addrs, } } -func copyTower(tower *wtdb.Tower) *wtdb.Tower { - t := &wtdb.Tower{ +func copyTower(t *testing.T, tower *Tower) *Tower { + t.Helper() + + addrs := tower.Addresses.GetAll() + addrIterator, err := newAddressIterator(addrs...) + require.NoError(t, err) + + return &Tower{ ID: tower.ID, IdentityKey: tower.IdentityKey, - Addresses: make([]net.Addr, len(tower.Addresses)), + Addresses: addrIterator, } - copy(t.Addresses, tower.Addresses) - return t } -func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, - c *wtdb.Tower, active bool) { +func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c *Tower, + active bool) { t.Helper() @@ -71,12 +78,14 @@ func assertActiveCandidate(t *testing.T, i TowerCandidateIterator, c.ID) } -func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *wtdb.Tower) { +func assertNextCandidate(t *testing.T, i TowerCandidateIterator, c *Tower) { t.Helper() tower, err := i.Next() require.NoError(t, err) - require.Equal(t, c, tower) + require.True(t, tower.IdentityKey.IsEqual(c.IdentityKey)) + require.Equal(t, tower.ID, c.ID) + require.Equal(t, tower.Addresses.GetAll(), c.Addresses.GetAll()) } // TestTowerCandidateIterator asserts the internal state of a @@ -88,13 +97,13 @@ func TestTowerCandidateIterator(t *testing.T) { // towers. We'll use copies of these towers within the iterator to // ensure the iterator properly updates the state of its candidates. const numTowers = 4 - towers := make([]*wtdb.Tower, 0, numTowers) + towers := make([]*Tower, 0, numTowers) for i := 0; i < numTowers; i++ { towers = append(towers, randTower(t)) } - towerCopies := make([]*wtdb.Tower, 0, numTowers) + towerCopies := make([]*Tower, 0, numTowers) for _, tower := range towers { - towerCopies = append(towerCopies, copyTower(tower)) + towerCopies = append(towerCopies, copyTower(t, tower)) } towerIterator := newTowerListIterator(towerCopies...) @@ -112,13 +121,13 @@ func TestTowerCandidateIterator(t *testing.T) { towerIterator.Reset() // We'll then attempt to test the RemoveCandidate behavior of the - // iterator. We'll remove the address of the first tower, which should - // result in it not having any addresses left, but still being an active - // candidate. + // iterator. We'll attempt to remove the address of the first tower, + // which should result in an error due to it being the last address of + // the tower. firstTower := towers[0] - firstTowerAddr := firstTower.Addresses[0] - firstTower.RemoveAddress(firstTowerAddr) - towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + firstTowerAddr := firstTower.Addresses.Peek() + err = towerIterator.RemoveCandidate(firstTower.ID, firstTowerAddr) + require.ErrorIs(t, err, wtdb.ErrLastTowerAddr) assertActiveCandidate(t, towerIterator, firstTower, true) assertNextCandidate(t, towerIterator, firstTower) @@ -126,7 +135,8 @@ func TestTowerCandidateIterator(t *testing.T) { // not providing the optional address. Since it's been removed, we // should expect to see the third tower next. secondTower, thirdTower := towers[1], towers[2] - towerIterator.RemoveCandidate(secondTower.ID, nil) + err = towerIterator.RemoveCandidate(secondTower.ID, nil) + require.NoError(t, err) assertActiveCandidate(t, towerIterator, secondTower, false) assertNextCandidate(t, towerIterator, thirdTower) @@ -135,7 +145,7 @@ func TestTowerCandidateIterator(t *testing.T) { // iterator, but the new address should be. fourthTower := towers[3] assertActiveCandidate(t, towerIterator, fourthTower, true) - fourthTower.AddAddress(randAddr(t)) + fourthTower.Addresses.Add(randAddr(t)) towerIterator.AddCandidate(fourthTower) assertNextCandidate(t, towerIterator, fourthTower) diff --git a/watchtower/wtclient/client.go b/watchtower/wtclient/client.go index f9514f8f1..1f6641e28 100644 --- a/watchtower/wtclient/client.go +++ b/watchtower/wtclient/client.go @@ -45,8 +45,8 @@ const ( // genActiveSessionFilter generates a filter that selects active sessions that // also match the desired channel type, either legacy or anchor. -func genActiveSessionFilter(anchor bool) func(*wtdb.ClientSession) bool { - return func(s *wtdb.ClientSession) bool { +func genActiveSessionFilter(anchor bool) func(*ClientSession) bool { + return func(s *ClientSession) bool { return s.Status == wtdb.CSessionActive && anchor == s.Policy.IsAnchorChannel() } @@ -241,7 +241,7 @@ type TowerClient struct { negotiator SessionNegotiator candidateTowers TowerCandidateIterator - candidateSessions map[wtdb.SessionID]*wtdb.ClientSession + candidateSessions map[wtdb.SessionID]*ClientSession activeSessions sessionQueueSet sessionQueue *sessionQueue @@ -351,7 +351,7 @@ func New(config *Config) (*TowerClient, error) { activeSessionFilter := genActiveSessionFilter(isAnchorClient) candidateTowers := newTowerListIterator() - perActiveTower := func(tower *wtdb.Tower) { + perActiveTower := func(tower *Tower) { // If the tower has already been marked as active, then there is // no need to add it to the iterator again. if candidateTowers.IsActive(tower.ID) { @@ -400,18 +400,23 @@ func New(config *Config) (*TowerClient, error) { // sessionFilter check then the perActiveTower call-back will be called on that // tower. func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, - sessionFilter func(*wtdb.ClientSession) bool, - perActiveTower func(tower *wtdb.Tower), + sessionFilter func(*ClientSession) bool, + perActiveTower func(tower *Tower), opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { towers, err := db.ListTowers() if err != nil { return nil, err } - candidateSessions := make(map[wtdb.SessionID]*wtdb.ClientSession) - for _, tower := range towers { + candidateSessions := make(map[wtdb.SessionID]*ClientSession) + for _, dbTower := range towers { + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + sessions, err := db.ListClientSessions(&tower.ID, opts...) if err != nil { return nil, err @@ -427,16 +432,24 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH( + + sessionKeyECDH := keychain.NewPubKeyECDH( towerKeyDesc, keyRing, ) - if !sessionFilter(s) { + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } + + if !sessionFilter(cs) { continue } // Add the session to the set of candidate sessions. - candidateSessions[s.ID] = s + candidateSessions[s.ID] = cs perActiveTower(tower) } } @@ -452,11 +465,11 @@ func getTowerAndSessionCandidates(db DB, keyRing ECDHKeyRing, // ClientSession's SessionPrivKey field is desired, otherwise, the existing // ListClientSessions method should be used. func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, - passesFilter func(*wtdb.ClientSession) bool, + passesFilter func(*ClientSession) bool, opts ...wtdb.ClientSessionListOption) ( - map[wtdb.SessionID]*wtdb.ClientSession, error) { + map[wtdb.SessionID]*ClientSession, error) { - sessions, err := db.ListClientSessions(forTower, opts...) + dbSessions, err := db.ListClientSessions(forTower, opts...) if err != nil { return nil, err } @@ -466,7 +479,13 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, // be able to communicate with the towers and authenticate session // requests. This prevents us from having to store the private keys on // disk. - for _, s := range sessions { + sessions := make(map[wtdb.SessionID]*ClientSession) + for _, s := range dbSessions { + dbTower, err := db.LoadTowerByID(s.TowerID) + if err != nil { + return nil, err + } + towerKeyDesc, err := keyRing.DeriveKey(keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, Index: s.KeyIndex, @@ -474,13 +493,27 @@ func getClientSessions(db DB, keyRing ECDHKeyRing, forTower *wtdb.TowerID, if err != nil { return nil, err } - s.SessionKeyECDH = keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + sessionKeyECDH := keychain.NewPubKeyECDH(towerKeyDesc, keyRing) + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return nil, err + } + + cs := &ClientSession{ + ID: s.ID, + ClientSessionBody: s.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKeyECDH, + } // If an optional filter was provided, use it to filter out any // undesired sessions. - if passesFilter != nil && !passesFilter(s) { - delete(sessions, s.ID) + if passesFilter != nil && !passesFilter(cs) { + continue } + + sessions[s.ID] = cs } return sessions, nil @@ -710,7 +743,7 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID, func (c *TowerClient) nextSessionQueue() (*sessionQueue, error) { // Select any candidate session at random, and remove it from the set of // candidate sessions. - var candidateSession *wtdb.ClientSession + var candidateSession *ClientSession for id, sessionInfo := range c.candidateSessions { delete(c.candidateSessions, id) @@ -1069,7 +1102,7 @@ func (c *TowerClient) sendMessage(peer wtserver.Peer, msg wtwire.Message) error // newSessionQueue creates a sessionQueue from a ClientSession loaded from the // database and supplying it with the resources needed by the client. -func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, +func (c *TowerClient) newSessionQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { return newSessionQueue(&sessionQueueConfig{ @@ -1089,7 +1122,7 @@ func (c *TowerClient) newSessionQueue(s *wtdb.ClientSession, // getOrInitActiveQueue checks the activeSessions set for a sessionQueue for the // passed ClientSession. If it exists, the active sessionQueue is returned. // Otherwise, a new sessionQueue is initialized and added to the set. -func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) getOrInitActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { if sq, ok := c.activeSessions[s.ID]; ok { @@ -1103,7 +1136,7 @@ func (c *TowerClient) getOrInitActiveQueue(s *wtdb.ClientSession, // adds the sessionQueue to the activeSessions set, and starts the sessionQueue // so that it can deliver any committed updates or begin accepting newly // assigned tasks. -func (c *TowerClient) initActiveQueue(s *wtdb.ClientSession, +func (c *TowerClient) initActiveQueue(s *ClientSession, updates []wtdb.CommittedUpdate) *sessionQueue { // Initialize the session queue, providing it with all the resources it @@ -1156,10 +1189,16 @@ func (c *TowerClient) handleNewTower(msg *newTowerMsg) error { // We'll start by updating our persisted state, followed by our // in-memory state, with the new tower. This might not actually be a new // tower, but it might include a new address at which it can be reached. - tower, err := c.cfg.DB.CreateTower(msg.addr) + dbTower, err := c.cfg.DB.CreateTower(msg.addr) if err != nil { return err } + + tower, err := NewTowerFromDBTower(dbTower) + if err != nil { + return err + } + c.candidateTowers.AddCandidate(tower) // Include all of its corresponding sessions to our set of candidates. @@ -1251,7 +1290,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error { // If our active session queue corresponds to the stale tower, we'll // proceed to negotiate a new one. if c.sessionQueue != nil { - activeTower := c.sessionQueue.towerAddr.IdentityKey.SerializeCompressed() + activeTower := c.sessionQueue.tower.IdentityKey.SerializeCompressed() if bytes.Equal(pubKey, activeTower) { c.sessionQueue = nil } diff --git a/watchtower/wtclient/client_test.go b/watchtower/wtclient/client_test.go index 6312bbbab..738c8cf02 100644 --- a/watchtower/wtclient/client_test.go +++ b/watchtower/wtclient/client_test.go @@ -1471,10 +1471,8 @@ var clientTests = []clientTest{ }, { // Assert that if a client changes the address for a server and - // then tries to back up updates then the client will not switch - // to the new address. The client will only use the server's new - // address after a restart. This is a bug that will be fixed in - // a future commit. + // then tries to back up updates then the client will switch to + // the new address. name: "change address of existing session", cfg: harnessCfg{ localBalance: localBalance, @@ -1535,16 +1533,7 @@ var clientTests = []clientTest{ // Now attempt to back up the rest of the updates. h.backupStates(chanID, numUpdates/2, maxUpdates, nil) - // Assert that the server does not receive the updates. - h.waitServerUpdates(nil, waitTime) - - // Restart the client and attempt to back up the updates - // again. - h.client.Stop() - h.startClient() - h.backupStates(chanID, numUpdates/2, maxUpdates, nil) - - // The server should now receive the updates. + // Assert that the server does receive the updates. h.waitServerUpdates(hints[:maxUpdates], waitTime) }, }, diff --git a/watchtower/wtclient/errors.go b/watchtower/wtclient/errors.go index 857af3087..f496074bf 100644 --- a/watchtower/wtclient/errors.go +++ b/watchtower/wtclient/errors.go @@ -20,10 +20,6 @@ var ( // down. ErrNegotiatorExiting = errors.New("negotiator exiting") - // ErrNoTowerAddrs signals that the client could not be created because - // we have no addresses with which we can reach a tower. - ErrNoTowerAddrs = errors.New("no tower addresses") - // ErrFailedNegotiation signals that the session negotiator could not // acquire a new session as requested. ErrFailedNegotiation = errors.New("session negotiation unsuccessful") diff --git a/watchtower/wtclient/interface.go b/watchtower/wtclient/interface.go index 5f2357950..ba6546328 100644 --- a/watchtower/wtclient/interface.go +++ b/watchtower/wtclient/interface.go @@ -118,3 +118,50 @@ type ECDHKeyRing interface { // key. DeriveKey(keyLoc keychain.KeyLocator) (keychain.KeyDescriptor, error) } + +// Tower represents the info about a watchtower server that a watchtower client +// needs in order to connect to it. +type Tower struct { + // ID is the unique, db-assigned, identifier for this tower. + ID wtdb.TowerID + + // IdentityKey is the public key of the remote node, used to + // authenticate the brontide transport. + IdentityKey *btcec.PublicKey + + // Addresses is an AddressIterator that can be used to manage the + // addresses for this tower. + Addresses AddressIterator +} + +// NewTowerFromDBTower converts a wtdb.Tower, which uses a static address list, +// into a Tower which uses an address iterator. +func NewTowerFromDBTower(t *wtdb.Tower) (*Tower, error) { + addrs, err := newAddressIterator(t.Addresses...) + if err != nil { + return nil, err + } + + return &Tower{ + ID: t.ID, + IdentityKey: t.IdentityKey, + Addresses: addrs, + }, nil +} + +// ClientSession represents the session that a tower client has with a server. +type ClientSession struct { + // ID is the client's public key used when authenticating with the + // tower. + ID wtdb.SessionID + + wtdb.ClientSessionBody + + // Tower represents the tower that the client session has been made + // with. + Tower *Tower + + // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret + // key used to connect to the watchtower. + SessionKeyECDH keychain.SingleKeyECDH +} diff --git a/watchtower/wtclient/session_negotiator.go b/watchtower/wtclient/session_negotiator.go index 9ccaf5b79..91b568158 100644 --- a/watchtower/wtclient/session_negotiator.go +++ b/watchtower/wtclient/session_negotiator.go @@ -25,7 +25,7 @@ type SessionNegotiator interface { // NewSessions is a read-only channel where newly negotiated sessions // will be delivered. - NewSessions() <-chan *wtdb.ClientSession + NewSessions() <-chan *ClientSession // Start safely initializes the session negotiator. Start() error @@ -105,8 +105,8 @@ type sessionNegotiator struct { log btclog.Logger dispatcher chan struct{} - newSessions chan *wtdb.ClientSession - successfulNegotiations chan *wtdb.ClientSession + newSessions chan *ClientSession + successfulNegotiations chan *ClientSession wg sync.WaitGroup quit chan struct{} @@ -139,8 +139,8 @@ func newSessionNegotiator(cfg *NegotiatorConfig) *sessionNegotiator { log: cfg.Log, localInit: localInit, dispatcher: make(chan struct{}, 1), - newSessions: make(chan *wtdb.ClientSession), - successfulNegotiations: make(chan *wtdb.ClientSession), + newSessions: make(chan *ClientSession), + successfulNegotiations: make(chan *ClientSession), quit: make(chan struct{}), } } @@ -171,7 +171,7 @@ func (n *sessionNegotiator) Stop() error { // NewSessions returns a receive-only channel from which newly negotiated // sessions will be returned. -func (n *sessionNegotiator) NewSessions() <-chan *wtdb.ClientSession { +func (n *sessionNegotiator) NewSessions() <-chan *ClientSession { return n.newSessions } @@ -333,18 +333,10 @@ retryWithBackoff: } } -// createSession takes a tower an attempts to negotiate a session using any of +// createSession takes a tower and attempts to negotiate a session using any of // its stored addresses. This method returns after the first successful -// negotiation, or after all addresses have failed with ErrFailedNegotiation. If -// the tower has no addresses, ErrNoTowerAddrs is returned. -func (n *sessionNegotiator) createSession(tower *wtdb.Tower, - keyIndex uint32) error { - - // If the tower has no addresses, there's nothing we can do. - if len(tower.Addresses) == 0 { - return ErrNoTowerAddrs - } - +// negotiation, or after all addresses have failed with ErrFailedNegotiation. +func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error { sessionKeyDesc, err := n.cfg.SecretKeyRing.DeriveKey( keychain.KeyLocator{ Family: keychain.KeyFamilyTowerSession, @@ -358,8 +350,14 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, sessionKeyDesc, n.cfg.SecretKeyRing, ) - for _, lnAddr := range tower.LNAddrs() { - err := n.tryAddress(sessionKey, keyIndex, tower, lnAddr) + addr := tower.Addresses.Peek() + for { + lnAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: addr, + } + + err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr) switch { case err == ErrPermanentTowerFailure: // TODO(conner): report to iterator? can then be reset @@ -370,6 +368,15 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, n.log.Debugf("Request for session negotiation with "+ "tower=%s failed, trying again -- reason: "+ "%v", lnAddr, err) + + // Get the next tower address if there is one. + addr, err = tower.Addresses.Next() + if err == ErrAddressesExhausted { + tower.Addresses.Reset() + + return ErrFailedNegotiation + } + continue default: @@ -385,7 +392,7 @@ func (n *sessionNegotiator) createSession(tower *wtdb.Tower, // returns true if all steps succeed and the new session has been persisted, and // fails otherwise. func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, - keyIndex uint32, tower *wtdb.Tower, lnAddr *lnwire.NetAddress) error { + keyIndex uint32, tower *Tower, lnAddr *lnwire.NetAddress) error { // Connect to the tower address using our generated session key. conn, err := n.cfg.Dial(sessionKey, lnAddr) @@ -456,26 +463,31 @@ func (n *sessionNegotiator) tryAddress(sessionKey keychain.SingleKeyECDH, rewardPkScript := createSessionReply.Data sessionID := wtdb.NewSessionIDFromPubKey(sessionKey.PubKey()) - clientSession := &wtdb.ClientSession{ + dbClientSession := &wtdb.ClientSession{ ClientSessionBody: wtdb.ClientSessionBody{ TowerID: tower.ID, KeyIndex: keyIndex, Policy: n.cfg.Policy, RewardPkScript: rewardPkScript, }, - Tower: tower, - SessionKeyECDH: sessionKey, - ID: sessionID, + ID: sessionID, } - err = n.cfg.DB.CreateClientSession(clientSession) + err = n.cfg.DB.CreateClientSession(dbClientSession) if err != nil { return fmt.Errorf("unable to persist ClientSession: %v", err) } n.log.Debugf("New session negotiated with %s, policy: %s", - lnAddr, clientSession.Policy) + lnAddr, dbClientSession.Policy) + + clientSession := &ClientSession{ + ID: sessionID, + ClientSessionBody: dbClientSession.ClientSessionBody, + Tower: tower, + SessionKeyECDH: sessionKey, + } // We have a newly negotiated session, return it to the // dispatcher so that it can update how many outstanding diff --git a/watchtower/wtclient/session_queue.go b/watchtower/wtclient/session_queue.go index d149d09b6..7d98ec86f 100644 --- a/watchtower/wtclient/session_queue.go +++ b/watchtower/wtclient/session_queue.go @@ -34,7 +34,7 @@ const ( type sessionQueueConfig struct { // ClientSession provides access to the negotiated session parameters // and updating its persistent storage. - ClientSession *wtdb.ClientSession + ClientSession *ClientSession // ChainHash identifies the chain for which the session's justice // transactions are targeted. @@ -97,7 +97,7 @@ type sessionQueue struct { queueCond *sync.Cond localInit *wtwire.Init - towerAddr *lnwire.NetAddress + tower *Tower seqNum uint16 @@ -117,18 +117,13 @@ func newSessionQueue(cfg *sessionQueueConfig, cfg.ChainHash, ) - towerAddr := &lnwire.NetAddress{ - IdentityKey: cfg.ClientSession.Tower.IdentityKey, - Address: cfg.ClientSession.Tower.Addresses[0], - } - sq := &sessionQueue{ cfg: cfg, log: cfg.Log, commitQueue: list.New(), pendingQueue: list.New(), localInit: localInit, - towerAddr: towerAddr, + tower: cfg.ClientSession.Tower, seqNum: cfg.ClientSession.SeqNum, retryBackoff: cfg.MinBackoff, quit: make(chan struct{}), @@ -293,18 +288,48 @@ func (q *sessionQueue) sessionManager() { // drainBackups attempts to send all pending updates in the queue to the tower. func (q *sessionQueue) drainBackups() { - // First, check that we are able to dial this session's tower. - conn, err := q.cfg.Dial(q.cfg.ClientSession.SessionKeyECDH, q.towerAddr) - if err != nil { - q.log.Errorf("SessionQueue(%s) unable to dial tower at %v: %v", - q.ID(), q.towerAddr, err) + var ( + conn wtserver.Peer + err error + towerAddr = q.tower.Addresses.Peek() + ) - q.increaseBackoff() - select { - case <-time.After(q.retryBackoff): - case <-q.forceQuit: + for { + q.log.Infof("SessionQueue(%s) attempting to dial tower at %v", + q.ID(), towerAddr) + + // First, check that we are able to dial this session's tower. + conn, err = q.cfg.Dial( + q.cfg.ClientSession.SessionKeyECDH, &lnwire.NetAddress{ + IdentityKey: q.tower.IdentityKey, + Address: towerAddr, + }, + ) + if err != nil { + // If there are more addrs available, immediately try + // those. + nextAddr, iteratorErr := q.tower.Addresses.Next() + if iteratorErr == nil { + towerAddr = nextAddr + continue + } + + // Otherwise, if we have exhausted the address list, + // back off and try again later. + q.tower.Addresses.Reset() + + q.log.Errorf("SessionQueue(%s) unable to dial tower "+ + "at any available Addresses: %v", q.ID(), err) + + q.increaseBackoff() + select { + case <-time.After(q.retryBackoff): + case <-q.forceQuit: + } + return } - return + + break } defer conn.Close() @@ -324,9 +349,7 @@ func (q *sessionQueue) drainBackups() { } // Now, send the state update to the tower and wait for a reply. - err = q.sendStateUpdate( - conn, stateUpdate, q.localInit, sendInit, isPending, - ) + err = q.sendStateUpdate(conn, stateUpdate, sendInit, isPending) if err != nil { q.log.Errorf("SessionQueue(%s) unable to send state "+ "update: %v", q.ID(), err) @@ -483,8 +506,12 @@ func (q *sessionQueue) nextStateUpdate() (*wtwire.StateUpdate, bool, // variable indicates whether we should back off before attempting to send the // next state update. func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, - stateUpdate *wtwire.StateUpdate, localInit *wtwire.Init, - sendInit, isPending bool) error { + stateUpdate *wtwire.StateUpdate, sendInit, isPending bool) error { + + towerAddr := &lnwire.NetAddress{ + IdentityKey: conn.RemotePub(), + Address: conn.RemoteAddr(), + } // If this is the first message being sent to the tower, we must send an // Init message to establish that server supports the features we @@ -505,7 +532,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, remoteInit, ok := remoteMsg.(*wtwire.Init) if !ok { return fmt.Errorf("watchtower %s responded with %T "+ - "to Init", q.towerAddr, remoteMsg) + "to Init", towerAddr, remoteMsg) } // Validate Init. @@ -532,7 +559,7 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, stateUpdateReply, ok := remoteMsg.(*wtwire.StateUpdateReply) if !ok { return fmt.Errorf("watchtower %s responded with %T to "+ - "StateUpdate", q.towerAddr, remoteMsg) + "StateUpdate", towerAddr, remoteMsg) } // Process the reply from the tower. @@ -547,8 +574,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, err := fmt.Errorf("received error code %v in "+ "StateUpdateReply for seqnum=%d", stateUpdateReply.Code, stateUpdate.SeqNum) - q.log.Warnf("SessionQueue(%s) unable to upload state update to "+ - "tower=%s: %v", q.ID(), q.towerAddr, err) + q.log.Warnf("SessionQueue(%s) unable to upload state update "+ + "to tower=%s: %v", q.ID(), towerAddr, err) return err } @@ -559,7 +586,8 @@ func (q *sessionQueue) sendStateUpdate(conn wtserver.Peer, // TODO(conner): borked watchtower err = fmt.Errorf("unable to ack seqnum=%d: %v", stateUpdate.SeqNum, err) - q.log.Errorf("SessionQueue(%v) failed to ack update: %v", q.ID(), err) + q.log.Errorf("SessionQueue(%v) failed to ack update: %v", + q.ID(), err) return err case err == wtdb.ErrLastAppliedReversion: diff --git a/watchtower/wtdb/client_db.go b/watchtower/wtdb/client_db.go index 94a9c2c74..26d4704d4 100644 --- a/watchtower/wtdb/client_db.go +++ b/watchtower/wtdb/client_db.go @@ -429,7 +429,7 @@ func (c *ClientDB) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error { } towerSessions, err := listTowerSessions( - towerID, sessions, towers, towersToSessionsIndex, + towerID, sessions, towersToSessionsIndex, WithPerCommittedUpdate(perCommittedUpdate), ) if err != nil { @@ -766,7 +766,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, // known to the db. if id == nil { clientSessions, err = listClientAllSessions( - sessions, towers, opts..., + sessions, opts..., ) return err } @@ -778,7 +778,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } clientSessions, err = listTowerSessions( - *id, sessions, towers, towerToSessionIndex, opts..., + *id, sessions, towerToSessionIndex, opts..., ) return err }, func() { @@ -792,7 +792,7 @@ func (c *ClientDB) ListClientSessions(id *TowerID, } // listClientAllSessions returns the set of all client sessions known to the db. -func listClientAllSessions(sessions, towers kvdb.RBucket, +func listClientAllSessions(sessions kvdb.RBucket, opts ...ClientSessionListOption) (map[SessionID]*ClientSession, error) { clientSessions := make(map[SessionID]*ClientSession) @@ -801,7 +801,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession(sessions, towers, k, opts...) + session, err := getClientSession(sessions, k, opts...) if err != nil { return err } @@ -819,7 +819,7 @@ func listClientAllSessions(sessions, towers kvdb.RBucket, // listTowerSessions returns the set of all client sessions known to the db // that are associated with the given tower id. -func listTowerSessions(id TowerID, sessionsBkt, towersBkt, +func listTowerSessions(id TowerID, sessionsBkt, towerToSessionIndex kvdb.RBucket, opts ...ClientSessionListOption) ( map[SessionID]*ClientSession, error) { @@ -834,9 +834,7 @@ func listTowerSessions(id TowerID, sessionsBkt, towersBkt, // the CommittedUpdates and AckedUpdates on startup to resume // committed updates and compute the highest known commit height // for each channel. - session, err := getClientSession( - sessionsBkt, towersBkt, k, opts..., - ) + session, err := getClientSession(sessionsBkt, k, opts...) if err != nil { return err } @@ -1248,7 +1246,7 @@ func WithPerCommittedUpdate(cb PerCommittedUpdateCB) ClientSessionListOption { // getClientSession loads the full ClientSession associated with the serialized // session id. This method populates the CommittedUpdates, AckUpdates and Tower // in addition to the ClientSession's body. -func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, +func getClientSession(sessions kvdb.RBucket, idBytes []byte, opts ...ClientSessionListOption) (*ClientSession, error) { cfg := NewClientSessionCfg() @@ -1261,13 +1259,6 @@ func getClientSession(sessions, towers kvdb.RBucket, idBytes []byte, return nil, err } - // Fetch the tower associated with this session. - tower, err := getTower(towers, session.TowerID.Bytes()) - if err != nil { - return nil, err - } - session.Tower = tower - // Can't fail because client session body has already been read. sessionBkt := sessions.NestedReadBucket(idBytes) diff --git a/watchtower/wtdb/client_db_test.go b/watchtower/wtdb/client_db_test.go index aa30cc713..f75a0c2bc 100644 --- a/watchtower/wtdb/client_db_test.go +++ b/watchtower/wtdb/client_db_test.go @@ -343,8 +343,11 @@ func testCreateTower(h *clientDBHarness) { h.loadTowerByID(20, wtdb.ErrTowerNotFound) tower := h.newTower() - require.Len(h.t, tower.LNAddrs(), 1) - towerAddr := tower.LNAddrs()[0] + require.Len(h.t, tower.Addresses, 1) + towerAddr := &lnwire.NetAddress{ + IdentityKey: tower.IdentityKey, + Address: tower.Addresses[0], + } // Load the tower from the database and assert that it matches the tower // we created. diff --git a/watchtower/wtdb/client_session.go b/watchtower/wtdb/client_session.go index a4d5c5ecc..e44331094 100644 --- a/watchtower/wtdb/client_session.go +++ b/watchtower/wtdb/client_session.go @@ -4,7 +4,6 @@ import ( "fmt" "io" - "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/watchtower/blob" "github.com/lightningnetwork/lnd/watchtower/wtpolicy" @@ -36,19 +35,6 @@ type ClientSession struct { ID SessionID ClientSessionBody - - // Tower holds the pubkey and address of the watchtower. - // - // NOTE: This value is not serialized. It is recovered by looking up the - // tower with TowerID. - Tower *Tower - - // SessionKeyECDH is the ECDH capable wrapper of the ephemeral secret - // key used to connect to the watchtower. - // - // NOTE: This value is not serialized. It is derived using the KeyIndex - // on startup to avoid storing private keys on disk. - SessionKeyECDH keychain.SingleKeyECDH } // ClientSessionBody represents the primary components of a ClientSession that diff --git a/watchtower/wtdb/tower.go b/watchtower/wtdb/tower.go index 77f452fb5..ca9dbeb28 100644 --- a/watchtower/wtdb/tower.go +++ b/watchtower/wtdb/tower.go @@ -7,7 +7,6 @@ import ( "net" "github.com/btcsuite/btcd/btcec/v2" - "github.com/lightningnetwork/lnd/lnwire" ) // TowerID is a unique 64-bit identifier allocated to each unique watchtower. @@ -77,23 +76,6 @@ func (t *Tower) RemoveAddress(addr net.Addr) { } } -// LNAddrs generates a list of lnwire.NetAddress from a Tower instance's -// addresses. This can be used to have a client try multiple addresses for the -// same Tower. -// -// NOTE: This method is NOT safe for concurrent use. -func (t *Tower) LNAddrs() []*lnwire.NetAddress { - addrs := make([]*lnwire.NetAddress, 0, len(t.Addresses)) - for _, addr := range t.Addresses { - addrs = append(addrs, &lnwire.NetAddress{ - IdentityKey: t.IdentityKey, - Address: addr, - }) - } - - return addrs -} - // String returns a user-friendly identifier of the tower. func (t *Tower) String() string { pubKey := hex.EncodeToString(t.IdentityKey.SerializeCompressed()) diff --git a/watchtower/wtmock/client_db.go b/watchtower/wtmock/client_db.go index 8a47bdf7f..b12fe2780 100644 --- a/watchtower/wtmock/client_db.go +++ b/watchtower/wtmock/client_db.go @@ -231,7 +231,6 @@ func (m *ClientDB) listClientSessions(tower *wtdb.TowerID, if tower != nil && *tower != session.TowerID { continue } - session.Tower = m.towers[session.TowerID] sessions[session.ID] = &session if cfg.PerAckedUpdate != nil {