diff --git a/aliasmgr/aliasmgr.go b/aliasmgr/aliasmgr.go index f33eb391e..cae637f33 100644 --- a/aliasmgr/aliasmgr.go +++ b/aliasmgr/aliasmgr.go @@ -5,6 +5,7 @@ import ( "fmt" "sync" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" @@ -68,6 +69,10 @@ var ( // errNoPeerAlias is returned when the peer's alias for a given // channel is not found. errNoPeerAlias = fmt.Errorf("no peer alias found") + + // ErrAliasNotFound is returned when the alias is not found and can't + // be mapped to a base SCID. + ErrAliasNotFound = fmt.Errorf("alias not found") ) // Manager is a struct that handles aliases for LND. It has an underlying @@ -340,6 +345,61 @@ func (m *Manager) DeleteSixConfs(baseScid lnwire.ShortChannelID) error { return nil } +// DeleteLocalAlias removes a mapping from the database and the Manager's maps. +func (m *Manager) DeleteLocalAlias(alias, + baseScid lnwire.ShortChannelID) error { + + m.Lock() + defer m.Unlock() + + err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { + aliasToBaseBucket, err := tx.CreateTopLevelBucket(aliasBucket) + if err != nil { + return err + } + + var aliasBytes [8]byte + byteOrder.PutUint64(aliasBytes[:], alias.ToUint64()) + + // If the user attempts to delete an alias that doesn't exist, + // we'll want to inform them about it and not just do nothing. + if aliasToBaseBucket.Get(aliasBytes[:]) == nil { + return ErrAliasNotFound + } + + return aliasToBaseBucket.Delete(aliasBytes[:]) + }, func() {}) + if err != nil { + return err + } + + // Now that the database state has been updated, we'll delete the + // mapping from the Manager's maps. + aliasSet, ok := m.baseToSet[baseScid] + if !ok { + return ErrAliasNotFound + } + + // We'll filter the alias set and remove the alias from it. + aliasSet = fn.Filter(func(a lnwire.ShortChannelID) bool { + return a.ToUint64() != alias.ToUint64() + }, aliasSet) + + // If the alias set is empty, we'll delete the base SCID from the + // baseToSet map. + if len(aliasSet) == 0 { + delete(m.baseToSet, baseScid) + } else { + m.baseToSet[baseScid] = aliasSet + } + + // Finally, we'll delete the aliasToBase mapping from the Manager's + // cache (but this is only set if we gossip the alias). + delete(m.aliasToBase, alias) + + return nil +} + // PutPeerAlias stores the peer's alias SCID once we learn of it in the // channel_ready message. func (m *Manager) PutPeerAlias(chanID lnwire.ChannelID, diff --git a/aliasmgr/aliasmgr_test.go b/aliasmgr/aliasmgr_test.go index 17159ed87..e0e78ba70 100644 --- a/aliasmgr/aliasmgr_test.go +++ b/aliasmgr/aliasmgr_test.go @@ -68,6 +68,74 @@ func TestAliasStoreRequest(t *testing.T) { require.Equal(t, nextAlias, alias2) } +// TestAliasLifecycle tests that the aliases can be created and deleted. +func TestAliasLifecycle(t *testing.T) { + t.Parallel() + + // Create the backend database and use this to create the aliasStore. + dbPath := filepath.Join(t.TempDir(), "testdb") + db, err := kvdb.Create( + kvdb.BoltBackendName, dbPath, true, kvdb.DefaultDBTimeout, + ) + require.NoError(t, err) + defer db.Close() + + aliasStore, err := NewManager(db) + require.NoError(t, err) + + const ( + base = uint64(123123123) + alias = uint64(456456456) + ) + + // Parse the aliases and base to short channel ID format. + baseScid := lnwire.NewShortChanIDFromInt(base) + aliasScid := lnwire.NewShortChanIDFromInt(alias) + aliasScid2 := lnwire.NewShortChanIDFromInt(alias + 1) + + // Add the first alias. + err = aliasStore.AddLocalAlias(aliasScid, baseScid, false) + require.NoError(t, err) + + // Query the aliases and verify the results. + aliasList := aliasStore.GetAliases(baseScid) + require.Len(t, aliasList, 1) + require.Contains(t, aliasList, aliasScid) + + // Add the second alias. + err = aliasStore.AddLocalAlias(aliasScid2, baseScid, false) + require.NoError(t, err) + + // Query the aliases and verify the results. + aliasList = aliasStore.GetAliases(baseScid) + require.Len(t, aliasList, 2) + require.Contains(t, aliasList, aliasScid) + require.Contains(t, aliasList, aliasScid2) + + // Delete the first alias. + err = aliasStore.DeleteLocalAlias(aliasScid, baseScid) + require.NoError(t, err) + + // We expect to get an error if we attempt to delete the same alias + // again. + err = aliasStore.DeleteLocalAlias(aliasScid, baseScid) + require.ErrorIs(t, err, ErrAliasNotFound) + + // Query the aliases and verify that first one doesn't exist anymore. + aliasList = aliasStore.GetAliases(baseScid) + require.Len(t, aliasList, 1) + require.Contains(t, aliasList, aliasScid2) + require.NotContains(t, aliasList, aliasScid) + + // Delete the second alias. + err = aliasStore.DeleteLocalAlias(aliasScid2, baseScid) + require.NoError(t, err) + + // Query the aliases and verify that none exists. + aliasList = aliasStore.GetAliases(baseScid) + require.Len(t, aliasList, 0) +} + // TestGetNextScid tests that given a current lnwire.ShortChannelID, // getNextScid returns the expected alias to use next. func TestGetNextScid(t *testing.T) {