mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-07 19:30:46 +02:00
watchtower: allow removal during session negotiation
In this commit, the bug demonstrated in the previous commit is fixed. The locking capabilities of the AddressIterator are used to lock addresses if they are being used for session negotiation. So now, when a request comes through to remove a tower address then a check is first done to ensure that the address is not currently in use. If it is not, then the request can go through.
This commit is contained in:
@@ -155,6 +155,10 @@ func (t *towerListIterator) RemoveCandidate(candidate wtdb.TowerID,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
if tower.Addresses.HasLocked() {
|
||||||
|
return ErrAddrInUse
|
||||||
|
}
|
||||||
|
|
||||||
delete(t.candidates, candidate)
|
delete(t.candidates, candidate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -2,7 +2,6 @@ package wtclient
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -826,13 +825,10 @@ func (c *TowerClient) backupDispatcher() {
|
|||||||
msg.errChan <- c.handleNewTower(msg)
|
msg.errChan <- c.handleNewTower(msg)
|
||||||
|
|
||||||
// A tower has been requested to be removed. We'll
|
// A tower has been requested to be removed. We'll
|
||||||
// immediately return an error as we want to avoid the
|
// only allow removal of it if the address in question
|
||||||
// possibility of a new session being negotiated with
|
// is not currently being used for session negotiation.
|
||||||
// this request's tower.
|
|
||||||
case msg := <-c.staleTowers:
|
case msg := <-c.staleTowers:
|
||||||
msg.errChan <- errors.New("removing towers " +
|
msg.errChan <- c.handleStaleTower(msg)
|
||||||
"is disallowed while a new session " +
|
|
||||||
"negotiation is in progress")
|
|
||||||
|
|
||||||
case <-c.forceQuit:
|
case <-c.forceQuit:
|
||||||
return
|
return
|
||||||
@@ -1254,18 +1250,31 @@ func (c *TowerClient) RemoveTower(pubKey *btcec.PublicKey, addr net.Addr) error
|
|||||||
func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
||||||
// We'll load the tower before potentially removing it in order to
|
// We'll load the tower before potentially removing it in order to
|
||||||
// retrieve its ID within the database.
|
// retrieve its ID within the database.
|
||||||
tower, err := c.cfg.DB.LoadTower(msg.pubKey)
|
dbTower, err := c.cfg.DB.LoadTower(msg.pubKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// We'll update our persisted state, followed by our in-memory state,
|
// We'll first update our in-memory state followed by our persisted
|
||||||
// with the stale tower.
|
// state, with the stale tower. The removal of the tower address from
|
||||||
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
|
// the in-memory state will fail if the address is currently being used
|
||||||
|
// for a session negotiation.
|
||||||
|
err = c.candidateTowers.RemoveCandidate(dbTower.ID, msg.addr)
|
||||||
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
err = c.candidateTowers.RemoveCandidate(tower.ID, msg.addr)
|
|
||||||
if err != nil {
|
if err := c.cfg.DB.RemoveTower(msg.pubKey, msg.addr); err != nil {
|
||||||
|
// If the persisted state update fails, re-add the address to
|
||||||
|
// our in-memory state.
|
||||||
|
tower, newTowerErr := NewTowerFromDBTower(dbTower)
|
||||||
|
if newTowerErr != nil {
|
||||||
|
log.Errorf("could not create new in-memory tower: %v",
|
||||||
|
newTowerErr)
|
||||||
|
} else {
|
||||||
|
c.candidateTowers.AddCandidate(tower)
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1278,7 +1287,7 @@ func (c *TowerClient) handleStaleTower(msg *staleTowerMsg) error {
|
|||||||
// Otherwise, the tower should no longer be used for future session
|
// Otherwise, the tower should no longer be used for future session
|
||||||
// negotiations and backups.
|
// negotiations and backups.
|
||||||
pubKey := msg.pubKey.SerializeCompressed()
|
pubKey := msg.pubKey.SerializeCompressed()
|
||||||
sessions, err := c.cfg.DB.ListClientSessions(&tower.ID)
|
sessions, err := c.cfg.DB.ListClientSessions(&dbTower.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
return fmt.Errorf("unable to retrieve sessions for tower %x: "+
|
||||||
"%v", pubKey, err)
|
"%v", pubKey, err)
|
||||||
|
@@ -2,6 +2,7 @@ package wtclient_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -394,6 +395,8 @@ type testHarness struct {
|
|||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
channels map[lnwire.ChannelID]*mockChannel
|
channels map[lnwire.ChannelID]*mockChannel
|
||||||
|
|
||||||
|
quit chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
type harnessCfg struct {
|
type harnessCfg struct {
|
||||||
@@ -467,7 +470,11 @@ func newHarness(t *testing.T, cfg harnessCfg) *testHarness {
|
|||||||
serverCfg: serverCfg,
|
serverCfg: serverCfg,
|
||||||
net: mockNet,
|
net: mockNet,
|
||||||
channels: make(map[lnwire.ChannelID]*mockChannel),
|
channels: make(map[lnwire.ChannelID]*mockChannel),
|
||||||
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
t.Cleanup(func() {
|
||||||
|
close(h.quit)
|
||||||
|
})
|
||||||
|
|
||||||
if !cfg.noServerStart {
|
if !cfg.noServerStart {
|
||||||
h.startServer()
|
h.startServer()
|
||||||
@@ -1542,11 +1549,10 @@ var clientTests = []clientTest{
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Assert that an error is returned if a user tries to remove
|
// Assert that a user is able to remove a tower address during
|
||||||
// a tower from the client while a session negotiation is in
|
// session negotiation as long as the address in question is not
|
||||||
// progress. This is a bug that will be fixed in a future
|
// currently being used.
|
||||||
// commit.
|
name: "removing a tower during session negotiation",
|
||||||
name: "cant remove tower while session negotiation in progress",
|
|
||||||
cfg: harnessCfg{
|
cfg: harnessCfg{
|
||||||
localBalance: localBalance,
|
localBalance: localBalance,
|
||||||
remoteBalance: remoteBalance,
|
remoteBalance: remoteBalance,
|
||||||
@@ -1560,18 +1566,93 @@ var clientTests = []clientTest{
|
|||||||
noServerStart: true,
|
noServerStart: true,
|
||||||
},
|
},
|
||||||
fn: func(h *testHarness) {
|
fn: func(h *testHarness) {
|
||||||
var err error
|
// The server has not started yet and so no session
|
||||||
waitErr := wait.Predicate(func() bool {
|
// negotiation with the server will be in progress, so
|
||||||
err = h.client.RemoveTower(
|
// the client should be able to remove the server.
|
||||||
|
err := wait.NoError(func() error {
|
||||||
|
return h.client.RemoveTower(
|
||||||
h.serverAddr.IdentityKey, nil,
|
h.serverAddr.IdentityKey, nil,
|
||||||
)
|
)
|
||||||
return err != nil
|
}, waitTime)
|
||||||
}, time.Second*5)
|
require.NoError(h.t, err)
|
||||||
require.NoError(h.t, waitErr)
|
|
||||||
|
|
||||||
require.ErrorContains(h.t, err, "removing towers is "+
|
// Set the server up so that its Dial function hangs
|
||||||
"disallowed while a new session negotiation "+
|
// when the client calls it. This will force the client
|
||||||
"is in progress")
|
// to remain in the state where it has locked the
|
||||||
|
// address of the server.
|
||||||
|
h.server, err = wtserver.New(h.serverCfg)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
cancel := make(chan struct{})
|
||||||
|
h.net.registerConnCallback(
|
||||||
|
h.serverAddr, func(peer wtserver.Peer) {
|
||||||
|
select {
|
||||||
|
case <-h.quit:
|
||||||
|
case <-cancel:
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Also add a new tower address.
|
||||||
|
towerTCPAddr, err := net.ResolveTCPAddr(
|
||||||
|
"tcp", towerAddr2Str,
|
||||||
|
)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
towerAddr := &lnwire.NetAddress{
|
||||||
|
IdentityKey: h.serverAddr.IdentityKey,
|
||||||
|
Address: towerTCPAddr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register the new address in the mock-net.
|
||||||
|
h.net.registerConnCallback(
|
||||||
|
towerAddr, h.server.InboundPeerConnected,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Now start the server.
|
||||||
|
require.NoError(h.t, h.server.Start())
|
||||||
|
|
||||||
|
// Re-add the server to the client
|
||||||
|
err = h.client.AddTower(h.serverAddr)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
// Also add the new tower address.
|
||||||
|
err = h.client.AddTower(towerAddr)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
// Assert that if the client attempts to remove the
|
||||||
|
// tower's first address, then it will error due to
|
||||||
|
// address currently being locked for session
|
||||||
|
// negotiation.
|
||||||
|
err = wait.Predicate(func() bool {
|
||||||
|
err = h.client.RemoveTower(
|
||||||
|
h.serverAddr.IdentityKey,
|
||||||
|
h.serverAddr.Address,
|
||||||
|
)
|
||||||
|
return errors.Is(err, wtclient.ErrAddrInUse)
|
||||||
|
}, waitTime)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
// Assert that the second address can be removed since
|
||||||
|
// it is not being used for session negotiation.
|
||||||
|
err = wait.NoError(func() error {
|
||||||
|
return h.client.RemoveTower(
|
||||||
|
h.serverAddr.IdentityKey, towerTCPAddr,
|
||||||
|
)
|
||||||
|
}, waitTime)
|
||||||
|
require.NoError(h.t, err)
|
||||||
|
|
||||||
|
// Allow the dial to the first address to stop hanging.
|
||||||
|
close(cancel)
|
||||||
|
|
||||||
|
// Assert that the client can now remove the first
|
||||||
|
// address.
|
||||||
|
err = wait.NoError(func() error {
|
||||||
|
return h.client.RemoveTower(
|
||||||
|
h.serverAddr.IdentityKey, nil,
|
||||||
|
)
|
||||||
|
}, waitTime)
|
||||||
|
require.NoError(h.t, err)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@@ -350,7 +350,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
|
|||||||
sessionKeyDesc, n.cfg.SecretKeyRing,
|
sessionKeyDesc, n.cfg.SecretKeyRing,
|
||||||
)
|
)
|
||||||
|
|
||||||
addr := tower.Addresses.Peek()
|
addr := tower.Addresses.PeekAndLock()
|
||||||
for {
|
for {
|
||||||
lnAddr := &lnwire.NetAddress{
|
lnAddr := &lnwire.NetAddress{
|
||||||
IdentityKey: tower.IdentityKey,
|
IdentityKey: tower.IdentityKey,
|
||||||
@@ -358,6 +358,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
|
err = n.tryAddress(sessionKey, keyIndex, tower, lnAddr)
|
||||||
|
tower.Addresses.ReleaseLock(addr)
|
||||||
switch {
|
switch {
|
||||||
case err == ErrPermanentTowerFailure:
|
case err == ErrPermanentTowerFailure:
|
||||||
// TODO(conner): report to iterator? can then be reset
|
// TODO(conner): report to iterator? can then be reset
|
||||||
@@ -370,7 +371,7 @@ func (n *sessionNegotiator) createSession(tower *Tower, keyIndex uint32) error {
|
|||||||
"%v", lnAddr, err)
|
"%v", lnAddr, err)
|
||||||
|
|
||||||
// Get the next tower address if there is one.
|
// Get the next tower address if there is one.
|
||||||
addr, err = tower.Addresses.Next()
|
addr, err = tower.Addresses.NextAndLock()
|
||||||
if err == ErrAddressesExhausted {
|
if err == ErrAddressesExhausted {
|
||||||
tower.Addresses.Reset()
|
tower.Addresses.Reset()
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user