multi: refresh htlcswitch aliases on aliasmgr update

This commit is contained in:
George Tsagkarelis 2024-03-12 18:15:14 +01:00 committed by Oliver Gugger
parent 0c95dc2118
commit 4ef68512a9
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
8 changed files with 152 additions and 20 deletions

View File

@ -11,6 +11,11 @@ import (
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
) )
// UpdateLinkAliases is a function type for a function that locates the active
// link that matches the given shortID and triggers an update based on the
// latest values of the alias manager.
type UpdateLinkAliases func(shortID lnwire.ShortChannelID) error
var ( var (
// aliasBucket stores aliases as keys and their base SCIDs as values. // aliasBucket stores aliases as keys and their base SCIDs as values.
// This is used to populate the maps that the Manager uses. The keys // This is used to populate the maps that the Manager uses. The keys
@ -82,6 +87,10 @@ var (
type Manager struct { type Manager struct {
backend kvdb.Backend backend kvdb.Backend
// linkAliasUpdater is a function used by the alias manager to
// facilitate live update of aliases in other subsystems.
linkAliasUpdater UpdateLinkAliases
// baseToSet is a mapping from the "base" SCID to the set of aliases // baseToSet is a mapping from the "base" SCID to the set of aliases
// for this channel. This mapping includes all channels that // for this channel. This mapping includes all channels that
// negotiated the option-scid-alias feature bit. // negotiated the option-scid-alias feature bit.
@ -103,8 +112,14 @@ type Manager struct {
} }
// NewManager initializes an alias Manager from the passed database backend. // NewManager initializes an alias Manager from the passed database backend.
func NewManager(db kvdb.Backend) (*Manager, error) { func NewManager(db kvdb.Backend, linkAliasUpdater UpdateLinkAliases) (*Manager,
m := &Manager{backend: db} error) {
m := &Manager{
backend: db,
linkAliasUpdater: linkAliasUpdater,
}
m.baseToSet = make(map[lnwire.ShortChannelID][]lnwire.ShortChannelID) m.baseToSet = make(map[lnwire.ShortChannelID][]lnwire.ShortChannelID)
m.aliasToBase = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) m.aliasToBase = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID)
m.peerAlias = make(map[lnwire.ChannelID]lnwire.ShortChannelID) m.peerAlias = make(map[lnwire.ChannelID]lnwire.ShortChannelID)
@ -220,12 +235,22 @@ func (m *Manager) populateMaps() error {
// AddLocalAlias adds a database mapping from the passed alias to the passed // AddLocalAlias adds a database mapping from the passed alias to the passed
// base SCID. The gossip boolean marks whether or not to create a mapping // base SCID. The gossip boolean marks whether or not to create a mapping
// that the gossiper will use. It is set to false for the upgrade path where // that the gossiper will use. It is set to false for the upgrade path where
// the feature-bit is toggled on and there are existing channels. // the feature-bit is toggled on and there are existing channels. The linkUpdate
// flag is used to signal whether this function should also trigger an update
// on the htlcswitch scid alias maps.
func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID, func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID,
gossip bool) error { gossip, linkUpdate bool) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock() m.Lock()
defer m.Unlock() unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
// If the caller does not want to allow the alias to be used // If the caller does not want to allow the alias to be used
@ -275,6 +300,18 @@ func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID,
m.aliasToBase[alias] = baseScid m.aliasToBase[alias] = baseScid
} }
// We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
// Finally, we trigger a htlcswitch update if the flag is set, in order
// for any future htlc that references the added alias to be properly
// routed.
if linkUpdate {
return m.linkAliasUpdater(baseScid)
}
return nil return nil
} }
@ -349,8 +386,16 @@ func (m *Manager) DeleteSixConfs(baseScid lnwire.ShortChannelID) error {
func (m *Manager) DeleteLocalAlias(alias, func (m *Manager) DeleteLocalAlias(alias,
baseScid lnwire.ShortChannelID) error { baseScid lnwire.ShortChannelID) error {
// We need to lock the manager for the whole duration of this method,
// except for the very last part where we call the link updater. In
// order for us to safely use a defer _and_ still be able to manually
// unlock, we use a sync.Once.
m.Lock() m.Lock()
defer m.Unlock() unlockOnce := sync.Once{}
unlock := func() {
unlockOnce.Do(m.Unlock)
}
defer unlock()
err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error {
aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket) aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket)
@ -397,7 +442,12 @@ func (m *Manager) DeleteLocalAlias(alias,
// cache (but this is only set if we gossip the alias). // cache (but this is only set if we gossip the alias).
delete(m.aliasToBase, alias) delete(m.aliasToBase, alias)
return nil // We definitely need to unlock the Manager before calling the link
// updater. If we don't, we'll deadlock. We use a sync.Once to ensure
// that we only unlock once.
unlock()
return m.linkAliasUpdater(baseScid)
} }
// PutPeerAlias stores the peer's alias SCID once we learn of it in the // PutPeerAlias stores the peer's alias SCID once we learn of it in the

View File

@ -23,7 +23,11 @@ func TestAliasStorePeerAlias(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer db.Close()
aliasStore, err := NewManager(db) linkUpdater := func(shortID lnwire.ShortChannelID) error {
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err) require.NoError(t, err)
var chanID1 [32]byte var chanID1 [32]byte
@ -52,7 +56,11 @@ func TestAliasStoreRequest(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer db.Close()
aliasStore, err := NewManager(db) linkUpdater := func(shortID lnwire.ShortChannelID) error {
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err) require.NoError(t, err)
// We'll assert that the very first alias we receive is StartingAlias. // We'll assert that the very first alias we receive is StartingAlias.
@ -80,7 +88,14 @@ func TestAliasLifecycle(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
defer db.Close() defer db.Close()
aliasStore, err := NewManager(db) updateChan := make(chan struct{}, 1)
linkUpdater := func(shortID lnwire.ShortChannelID) error {
updateChan <- struct{}{}
return nil
}
aliasStore, err := NewManager(db, linkUpdater)
require.NoError(t, err) require.NoError(t, err)
const ( const (
@ -94,18 +109,24 @@ func TestAliasLifecycle(t *testing.T) {
aliasScid2 := lnwire.NewShortChanIDFromInt(alias + 1) aliasScid2 := lnwire.NewShortChanIDFromInt(alias + 1)
// Add the first alias. // Add the first alias.
err = aliasStore.AddLocalAlias(aliasScid, baseScid, false) err = aliasStore.AddLocalAlias(aliasScid, baseScid, false, true)
require.NoError(t, err) require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify the results. // Query the aliases and verify the results.
aliasList := aliasStore.GetAliases(baseScid) aliasList := aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 1) require.Len(t, aliasList, 1)
require.Contains(t, aliasList, aliasScid) require.Contains(t, aliasList, aliasScid)
// Add the second alias. // Add the second alias.
err = aliasStore.AddLocalAlias(aliasScid2, baseScid, false) err = aliasStore.AddLocalAlias(aliasScid2, baseScid, false, true)
require.NoError(t, err) require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify the results. // Query the aliases and verify the results.
aliasList = aliasStore.GetAliases(baseScid) aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 2) require.Len(t, aliasList, 2)
@ -116,11 +137,21 @@ func TestAliasLifecycle(t *testing.T) {
err = aliasStore.DeleteLocalAlias(aliasScid, baseScid) err = aliasStore.DeleteLocalAlias(aliasScid, baseScid)
require.NoError(t, err) require.NoError(t, err)
// The link updater should be called.
<-updateChan
// We expect to get an error if we attempt to delete the same alias // We expect to get an error if we attempt to delete the same alias
// again. // again.
err = aliasStore.DeleteLocalAlias(aliasScid, baseScid) err = aliasStore.DeleteLocalAlias(aliasScid, baseScid)
require.ErrorIs(t, err, ErrAliasNotFound) require.ErrorIs(t, err, ErrAliasNotFound)
// The link updater should _not_ be called.
select {
case <-updateChan:
t.Fatal("link alias updater should not have been called")
default:
}
// Query the aliases and verify that first one doesn't exist anymore. // Query the aliases and verify that first one doesn't exist anymore.
aliasList = aliasStore.GetAliases(baseScid) aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 1) require.Len(t, aliasList, 1)
@ -131,6 +162,9 @@ func TestAliasLifecycle(t *testing.T) {
err = aliasStore.DeleteLocalAlias(aliasScid2, baseScid) err = aliasStore.DeleteLocalAlias(aliasScid2, baseScid)
require.NoError(t, err) require.NoError(t, err)
// The link updater should be called.
<-updateChan
// Query the aliases and verify that none exists. // Query the aliases and verify that none exists.
aliasList = aliasStore.GetAliases(baseScid) aliasList = aliasStore.GetAliases(baseScid)
require.Len(t, aliasList, 0) require.Len(t, aliasList, 0)

View File

@ -36,7 +36,8 @@ type aliasHandler interface {
GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID, error) GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID, error)
// AddLocalAlias persists an alias to an underlying alias store. // AddLocalAlias persists an alias to an underlying alias store.
AddLocalAlias(lnwire.ShortChannelID, lnwire.ShortChannelID, bool) error AddLocalAlias(lnwire.ShortChannelID, lnwire.ShortChannelID, bool,
bool) error
// GetAliases returns the set of aliases given the main SCID of a // GetAliases returns the set of aliases given the main SCID of a
// channel. This SCID will be an alias for zero-conf channels and will // channel. This SCID will be an alias for zero-conf channels and will

View File

@ -1263,7 +1263,7 @@ func (f *Manager) advancePendingChannelState(
// Persist the alias to the alias database. // Persist the alias to the alias database.
baseScid := channel.ShortChannelID baseScid := channel.ShortChannelID
err := f.cfg.AliasManager.AddLocalAlias( err := f.cfg.AliasManager.AddLocalAlias(
baseScid, baseScid, true, baseScid, baseScid, true, false,
) )
if err != nil { if err != nil {
return fmt.Errorf("error adding local alias to "+ return fmt.Errorf("error adding local alias to "+
@ -3149,7 +3149,7 @@ func (f *Manager) handleFundingConfirmation(
} }
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
aliasScid, confChannel.shortChanID, true, aliasScid, confChannel.shortChanID, true, false,
) )
if err != nil { if err != nil {
return fmt.Errorf("unable to request alias: %w", err) return fmt.Errorf("unable to request alias: %w", err)
@ -3315,7 +3315,7 @@ func (f *Manager) sendChannelReady(completeChan *channeldb.OpenChannel,
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
alias, completeChan.ShortChannelID, alias, completeChan.ShortChannelID,
false, false, false,
) )
if err != nil { if err != nil {
return err return err
@ -3892,7 +3892,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
} }
err = f.cfg.AliasManager.AddLocalAlias( err = f.cfg.AliasManager.AddLocalAlias(
alias, channel.ShortChannelID, false, alias, channel.ShortChannelID, false, false,
) )
if err != nil { if err != nil {
log.Errorf("unable to add local alias: %v", log.Errorf("unable to add local alias: %v",

View File

@ -162,7 +162,7 @@ func (m *mockAliasMgr) GetPeerAlias(lnwire.ChannelID) (lnwire.ShortChannelID,
} }
func (m *mockAliasMgr) AddLocalAlias(lnwire.ShortChannelID, func (m *mockAliasMgr) AddLocalAlias(lnwire.ShortChannelID,
lnwire.ShortChannelID, bool) error { lnwire.ShortChannelID, bool, bool) error {
return nil return nil
} }

View File

@ -2090,6 +2090,26 @@ func (s *Switch) addLiveLink(link ChannelLink) {
} }
s.interfaceIndex[peerPub][link.ChanID()] = link s.interfaceIndex[peerPub][link.ChanID()] = link
s.updateLinkAliases(link)
}
// UpdateLinkAliases is the externally exposed wrapper for updating link
// aliases. It acquires the indexMtx and calls the internal method.
func (s *Switch) UpdateLinkAliases(link ChannelLink) {
s.indexMtx.Lock()
defer s.indexMtx.Unlock()
s.updateLinkAliases(link)
}
// updateLinkAliases updates the aliases for a given link. This will cause the
// htlcswitch to consult the alias manager on the up to date values of its
// alias maps.
//
// NOTE: this MUST be called with the indexMtx held.
func (s *Switch) updateLinkAliases(link ChannelLink) {
linkScid := link.ShortChanID()
aliases := link.getAliases() aliases := link.getAliases()
if link.isZeroConf() { if link.isZeroConf() {
if link.zeroConfConfirmed() { if link.zeroConfConfirmed() {
@ -2114,6 +2134,21 @@ func (s *Switch) addLiveLink(link ChannelLink) {
s.baseIndex[alias] = linkScid s.baseIndex[alias] = linkScid
} }
} else if link.negotiatedAliasFeature() { } else if link.negotiatedAliasFeature() {
// First, we flush any alias mappings for this link's scid
// before we populate the map again, in order to get rid of old
// values that no longer exist.
for alias, real := range s.aliasToReal {
if real == linkScid {
delete(s.aliasToReal, alias)
}
}
for alias, real := range s.baseIndex {
if real == linkScid {
delete(s.baseIndex, alias)
}
}
// The link's SCID is the confirmed SCID for non-zero-conf // The link's SCID is the confirmed SCID for non-zero-conf
// option-scid-alias feature bit channels. // option-scid-alias feature bit channels.
for _, alias := range aliases { for _, alias := range aliases {

View File

@ -370,7 +370,7 @@ type Config struct {
// AddLocalAlias persists an alias to an underlying alias store. // AddLocalAlias persists an alias to an underlying alias store.
AddLocalAlias func(alias, base lnwire.ShortChannelID, AddLocalAlias func(alias, base lnwire.ShortChannelID,
gossip bool) error gossip, liveUpdate bool) error
// AuxLeafStore is an optional store that can be used to store auxiliary // AuxLeafStore is an optional store that can be used to store auxiliary
// leaves for certain custom channel types. // leaves for certain custom channel types.
@ -912,6 +912,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
err = p.cfg.AddLocalAlias( err = p.cfg.AddLocalAlias(
aliasScid, dbChan.ShortChanID(), false, aliasScid, dbChan.ShortChanID(), false,
false,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -644,7 +644,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
thresholdSats := btcutil.Amount(cfg.MaxFeeExposure) thresholdSats := btcutil.Amount(cfg.MaxFeeExposure)
thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats) thresholdMSats := lnwire.NewMSatFromSatoshis(thresholdSats)
s.aliasMgr, err = aliasmgr.NewManager(dbs.ChanStateDB) linkUpdater := func(shortID lnwire.ShortChannelID) error {
link, err := s.htlcSwitch.GetLinkByShortID(shortID)
if err != nil {
return err
}
s.htlcSwitch.UpdateLinkAliases(link)
return nil
}
s.aliasMgr, err = aliasmgr.NewManager(dbs.ChanStateDB, linkUpdater)
if err != nil { if err != nil {
return nil, err return nil, err
} }