diff --git a/aliasmgr/aliasmgr.go b/aliasmgr/aliasmgr.go index 3048e0e26..67d484099 100644 --- a/aliasmgr/aliasmgr.go +++ b/aliasmgr/aliasmgr.go @@ -87,18 +87,22 @@ type Manager struct { // negotiated option-scid-alias feature bit. aliasToBase map[lnwire.ShortChannelID]lnwire.ShortChannelID + // peerAlias is a cache for the alias SCIDs that our peers send us in + // the funding_locked TLV. The keys are the ChannelID generated from + // the FundingOutpoint and the values are the remote peer's alias SCID. + // The values should match the ones stored in the "invoice-alias-bucket" + // bucket. + peerAlias map[lnwire.ChannelID]lnwire.ShortChannelID + sync.RWMutex } // NewManager initializes an alias Manager from the passed database backend. func NewManager(db kvdb.Backend) (*Manager, error) { m := &Manager{backend: db} - m.baseToSet = make( - map[lnwire.ShortChannelID][]lnwire.ShortChannelID, - ) - m.aliasToBase = make( - map[lnwire.ShortChannelID]lnwire.ShortChannelID, - ) + m.baseToSet = make(map[lnwire.ShortChannelID][]lnwire.ShortChannelID) + m.aliasToBase = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) + m.peerAlias = make(map[lnwire.ChannelID]lnwire.ShortChannelID) err := m.populateMaps() return m, err @@ -115,6 +119,10 @@ func (m *Manager) populateMaps() error { // populate the Manager's actual maps. aliasMap := make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) + // This map caches the ChannelID/alias SCIDs stored in the database and + // is used to populate the Manager's cache. + peerAliasMap := make(map[lnwire.ChannelID]lnwire.ShortChannelID) + err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { baseConfBucket, err := tx.CreateTopLevelBucket(confirmedBucket) if err != nil { @@ -152,12 +160,34 @@ func (m *Manager) populateMaps() error { aliasMap[aliasScid] = baseScid return nil }) + if err != nil { + return err + } + + invAliasBucket, err := tx.CreateTopLevelBucket( + invoiceAliasBucket, + ) + if err != nil { + return err + } + + err = invAliasBucket.ForEach(func(k, v []byte) error { + var chanID lnwire.ChannelID + copy(chanID[:], k) + alias := lnwire.NewShortChanIDFromInt( + byteOrder.Uint64(v), + ) + + peerAliasMap[chanID] = alias + + return nil + }) + return err }, func() { baseConfMap = make(map[lnwire.ShortChannelID]struct{}) - aliasMap = make( - map[lnwire.ShortChannelID]lnwire.ShortChannelID, - ) + aliasMap = make(map[lnwire.ShortChannelID]lnwire.ShortChannelID) + peerAliasMap = make(map[lnwire.ChannelID]lnwire.ShortChannelID) }) if err != nil { return err @@ -176,6 +206,9 @@ func (m *Manager) populateMaps() error { m.aliasToBase[aliasSCID] = baseSCID } + // Populate the peer alias cache. + m.peerAlias = peerAliasMap + return nil } @@ -242,7 +275,9 @@ func (m *Manager) AddLocalAlias(alias, baseScid lnwire.ShortChannelID, // GetAliases fetches the set of aliases stored under a given base SCID from // write-through caches. -func (m *Manager) GetAliases(base lnwire.ShortChannelID) []lnwire.ShortChannelID { +func (m *Manager) GetAliases( + base lnwire.ShortChannelID) []lnwire.ShortChannelID { + m.RLock() defer m.RUnlock() @@ -310,7 +345,10 @@ func (m *Manager) DeleteSixConfs(baseScid lnwire.ShortChannelID) error { func (m *Manager) PutPeerAlias(chanID lnwire.ChannelID, alias lnwire.ShortChannelID) error { - return kvdb.Update(m.backend, func(tx kvdb.RwTx) error { + m.Lock() + defer m.Unlock() + + err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { bucket, err := tx.CreateTopLevelBucket(invoiceAliasBucket) if err != nil { return err @@ -320,36 +358,30 @@ func (m *Manager) PutPeerAlias(chanID lnwire.ChannelID, byteOrder.PutUint64(scratch[:], alias.ToUint64()) return bucket.Put(chanID[:], scratch[:]) }, func() {}) + if err != nil { + return err + } + + // Now that the database state has been updated, we can update it in + // our cache. + m.peerAlias[chanID] = alias + + return nil } // GetPeerAlias retrieves a peer's alias SCID by the channel's ChanID. -func (m *Manager) GetPeerAlias(chanID lnwire.ChannelID) ( - lnwire.ShortChannelID, error) { +func (m *Manager) GetPeerAlias(chanID lnwire.ChannelID) (lnwire.ShortChannelID, + error) { - var alias lnwire.ShortChannelID + m.RLock() + defer m.RUnlock() - err := kvdb.Update(m.backend, func(tx kvdb.RwTx) error { - bucket, err := tx.CreateTopLevelBucket(invoiceAliasBucket) - if err != nil { - return err - } - - aliasBytes := bucket.Get(chanID[:]) - if aliasBytes == nil { - return nil - } - - alias = lnwire.NewShortChanIDFromInt( - byteOrder.Uint64(aliasBytes), - ) - return nil - }, func() {}) - - if alias == hop.Source { - return alias, errNoPeerAlias + alias, ok := m.peerAlias[chanID] + if !ok || alias == hop.Source { + return lnwire.ShortChannelID{}, errNoPeerAlias } - return alias, err + return alias, nil } // RequestAlias returns a new ALIAS ShortChannelID to the caller by allocating diff --git a/docs/release-notes/release-notes-0.15.2.md b/docs/release-notes/release-notes-0.15.2.md index 841d30020..dc295ba4c 100644 --- a/docs/release-notes/release-notes-0.15.2.md +++ b/docs/release-notes/release-notes-0.15.2.md @@ -16,5 +16,16 @@ # Contributors (Alphabetical Order) + +## Performance improvements + +* [Refactor hop hint selection + algorithm](https://github.com/lightningnetwork/lnd/pull/6914) + + +# Contributors (Alphabetical Order) + * Eugene Siegel +* Jordi Montes * Oliver Gugger + diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index b47e3bb8f..c37ed001d 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -7,6 +7,8 @@ import ( "errors" "fmt" "math" + mathRand "math/rand" + "sort" "time" "github.com/btcsuite/btcd/btcec/v2" @@ -35,6 +37,10 @@ const ( // inbound capacity we want our hop hints to represent, allowing us to // have some leeway if peers go offline. hopHintFactor = 2 + + // maxHopHints is the maximum number of hint paths that will be included + // in an invoice. + maxHopHints = 20 ) // AddInvoiceConfig contains dependencies for invoice creation. @@ -126,8 +132,8 @@ type AddInvoiceData struct { // NOTE: Preimage should always be set to nil when this value is true. Amp bool - // RouteHints are optional route hints that can each be individually used - // to assist in reaching the invoice's destination. + // RouteHints are optional route hints that can each be individually + // used to assist in reaching the invoice's destination. RouteHints [][]zpay32.HopHint } @@ -159,7 +165,9 @@ func (d *AddInvoiceData) paymentHashAndPreimage() ( // ampPaymentHashAndPreimage returns the payment hash to use for an AMP invoice. // The preimage will always be nil. -func (d *AddInvoiceData) ampPaymentHashAndPreimage() (*lntypes.Preimage, lntypes.Hash, error) { +func (d *AddInvoiceData) ampPaymentHashAndPreimage() (*lntypes.Preimage, + lntypes.Hash, error) { + switch { // Preimages cannot be set on AMP invoice. case d.Preimage != nil: @@ -184,7 +192,9 @@ func (d *AddInvoiceData) ampPaymentHashAndPreimage() (*lntypes.Preimage, lntypes // mppPaymentHashAndPreimage returns the payment hash and preimage to use for an // MPP invoice. -func (d *AddInvoiceData) mppPaymentHashAndPreimage() (*lntypes.Preimage, lntypes.Hash, error) { +func (d *AddInvoiceData) mppPaymentHashAndPreimage() (*lntypes.Preimage, + lntypes.Hash, error) { + var ( paymentPreimage *lntypes.Preimage paymentHash lntypes.Hash @@ -235,11 +245,14 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // exceed the maximum values for either of the fields. if len(invoice.Memo) > channeldb.MaxMemoSize { return nil, nil, fmt.Errorf("memo too large: %v bytes "+ - "(maxsize=%v)", len(invoice.Memo), channeldb.MaxMemoSize) + "(maxsize=%v)", len(invoice.Memo), + channeldb.MaxMemoSize) } - if len(invoice.DescriptionHash) > 0 && len(invoice.DescriptionHash) != 32 { - return nil, nil, fmt.Errorf("description hash is %v bytes, must be 32", - len(invoice.DescriptionHash)) + if len(invoice.DescriptionHash) > 0 && + len(invoice.DescriptionHash) != 32 { + + return nil, nil, fmt.Errorf("description hash is %v bytes, "+ + "must be 32", len(invoice.DescriptionHash)) } // We set the max invoice amount to 100k BTC, which itself is several @@ -281,8 +294,8 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, addr, err := btcutil.DecodeAddress(invoice.FallbackAddr, cfg.ChainParams) if err != nil { - return nil, nil, fmt.Errorf("invalid fallback address: %v", - err) + return nil, nil, fmt.Errorf("invalid fallback "+ + "address: %v", err) } options = append(options, zpay32.FallbackAddr(addr)) } @@ -314,11 +327,13 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // Otherwise, use the default AMP expiry. default: - options = append(options, zpay32.Expiry(DefaultAMPInvoiceExpiry)) + defaultExpiry := zpay32.Expiry(DefaultAMPInvoiceExpiry) + options = append(options, defaultExpiry) } - // If the description hash is set, then we add it do the list of options. - // If not, use the memo field as the payment request description. + // If the description hash is set, then we add it do the list of + // options. If not, use the memo field as the payment request + // description. if len(invoice.DescriptionHash) > 0 { var descHash [32]byte copy(descHash[:], invoice.DescriptionHash[:]) @@ -333,8 +348,10 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, // an option on the command line when creating an invoice. switch { case invoice.CltvExpiry > math.MaxUint16: - return nil, nil, fmt.Errorf("CLTV delta of %v is too large, max "+ - "accepted is: %v", invoice.CltvExpiry, math.MaxUint16) + return nil, nil, fmt.Errorf("CLTV delta of %v is too large, "+ + "max accepted is: %v", invoice.CltvExpiry, + math.MaxUint16) + case invoice.CltvExpiry != 0: // Disallow user-chosen final CLTV deltas below the required // minimum. @@ -346,99 +363,52 @@ func AddInvoice(ctx context.Context, cfg *AddInvoiceConfig, options = append(options, zpay32.CLTVExpiry(invoice.CltvExpiry)) + default: // TODO(roasbeef): assumes set delta between versions - defaultDelta := cfg.DefaultCLTVExpiry - options = append(options, zpay32.CLTVExpiry(uint64(defaultDelta))) + defaultCLTVExpiry := uint64(cfg.DefaultCLTVExpiry) + options = append(options, zpay32.CLTVExpiry(defaultCLTVExpiry)) } - // We make sure that the given invoice routing hints number is within the - // valid range - if len(invoice.RouteHints) > 20 { - return nil, nil, fmt.Errorf("number of routing hints must not exceed " + - "maximum of 20") + // We make sure that the given invoice routing hints number is within + // the valid range + if len(invoice.RouteHints) > maxHopHints { + return nil, nil, fmt.Errorf("number of routing hints must "+ + "not exceed maximum of %v", maxHopHints) } - // We continue by populating the requested routing hints indexing their - // corresponding channels so we won't duplicate them. - forcedHints := make(map[uint64]struct{}) - for _, h := range invoice.RouteHints { - if len(h) == 0 { - return nil, nil, fmt.Errorf("number of hop hint within a route must " + - "be positive") + // Include route hints if needed. + if len(invoice.RouteHints) > 0 || invoice.Private { + // Validate provided hop hints. + for _, hint := range invoice.RouteHints { + if len(hint) == 0 { + return nil, nil, fmt.Errorf("number of hop " + + "hint within a route must be positive") + } } - options = append(options, zpay32.RouteHint(h)) - // Only this first hop is our direct channel. - forcedHints[h[0].ChannelID] = struct{}{} - } + totalHopHints := len(invoice.RouteHints) + if invoice.Private { + totalHopHints = maxHopHints + } - // If we were requested to include routing hints in the invoice, then - // we'll fetch all of our available private channels and create routing - // hints for them. - if invoice.Private { - openChannels, err := cfg.ChanDB.FetchAllChannels() + hopHintsCfg := newSelectHopHintsCfg(cfg, totalHopHints) + hopHints, err := PopulateHopHints( + hopHintsCfg, amtMSat, invoice.RouteHints, + ) if err != nil { - return nil, nil, fmt.Errorf("could not fetch all channels") + return nil, nil, fmt.Errorf("unable to populate hop "+ + "hints: %v", err) } - if len(openChannels) > 0 { - // We filter the channels by excluding the ones that were specified by - // the caller and were already added. - var filteredChannels []*HopHintInfo - for _, c := range openChannels { - if _, ok := forcedHints[c.ShortChanID().ToUint64()]; ok { - continue - } + // Convert our set of selected hop hints into route + // hints and add to our invoice options. + for _, hopHint := range hopHints { + routeHint := zpay32.RouteHint(hopHint) - // If this is a zero-conf channel, check if the - // confirmed SCID was used in forcedHints. - realScid := c.ZeroConfRealScid().ToUint64() - if c.IsZeroConf() { - if _, ok := forcedHints[realScid]; ok { - continue - } - } - - chanID := lnwire.NewChanIDFromOutPoint( - &c.FundingOutpoint, - ) - - // Check whether the the peer's alias was - // provided in forcedHints. - peerAlias, _ := cfg.GetAlias(chanID) - peerScid := peerAlias.ToUint64() - if _, ok := forcedHints[peerScid]; ok { - continue - } - - isActive := cfg.IsChannelActive(chanID) - - hopHintInfo := newHopHintInfo(c, isActive) - filteredChannels = append( - filteredChannels, hopHintInfo, - ) - } - - // We'll restrict the number of individual route hints - // to 20 to avoid creating overly large invoices. - numMaxHophints := 20 - len(forcedHints) - - hopHintsCfg := newSelectHopHintsCfg(cfg) - hopHints := SelectHopHints( - amtMSat, hopHintsCfg, filteredChannels, - numMaxHophints, + options = append( + options, routeHint, ) - - // Convert our set of selected hop hints into route - // hints and add to our invoice options. - for _, hopHint := range hopHints { - routeHint := zpay32.RouteHint(hopHint) - - options = append( - options, routeHint, - ) - } } } @@ -576,30 +546,6 @@ func chanCanBeHopHint(channel *HopHintInfo, cfg *SelectHopHintsCfg) ( return remotePolicy, true } -// addHopHint creates a hop hint out of the passed channel and channel policy. -// The new hop hint is appended to the passed slice. -func addHopHint(hopHints *[][]zpay32.HopHint, - channel *HopHintInfo, chanPolicy *channeldb.ChannelEdgePolicy, - aliasScid lnwire.ShortChannelID) { - - hopHint := zpay32.HopHint{ - NodeID: channel.RemotePubkey, - ChannelID: channel.ShortChannelID, - FeeBaseMSat: uint32(chanPolicy.FeeBaseMSat), - FeeProportionalMillionths: uint32( - chanPolicy.FeeProportionalMillionths, - ), - CLTVExpiryDelta: chanPolicy.TimeLockDelta, - } - - var defaultScid lnwire.ShortChannelID - if aliasScid != defaultScid { - hopHint.ChannelID = aliasScid.ToUint64() - } - - *hopHints = append(*hopHints, []zpay32.HopHint{hopHint}) -} - // HopHintInfo contains the channel information required to create a hop hint. type HopHintInfo struct { // IsPublic indicates whether a channel is advertised to the network. @@ -647,6 +593,22 @@ func newHopHintInfo(c *channeldb.OpenChannel, isActive bool) *HopHintInfo { } } +// newHopHint returns a new hop hint using the relevant data from a hopHintInfo +// and a ChannelEdgePolicy. +func newHopHint(hopHintInfo *HopHintInfo, + chanPolicy *channeldb.ChannelEdgePolicy) zpay32.HopHint { + + return zpay32.HopHint{ + NodeID: hopHintInfo.RemotePubkey, + ChannelID: hopHintInfo.ShortChannelID, + FeeBaseMSat: uint32(chanPolicy.FeeBaseMSat), + FeeProportionalMillionths: uint32( + chanPolicy.FeeProportionalMillionths, + ), + CLTVExpiryDelta: chanPolicy.TimeLockDelta, + } +} + // SelectHopHintsCfg contains the dependencies required to obtain hop hints // for an invoice. type SelectHopHintsCfg struct { @@ -664,169 +626,208 @@ type SelectHopHintsCfg struct { // GetAlias allows the peer's alias SCID to be retrieved for private // option_scid_alias channels. GetAlias func(lnwire.ChannelID) (lnwire.ShortChannelID, error) + + // FetchAllChannels retrieves all open channels currently stored + // within the database. + FetchAllChannels func() ([]*channeldb.OpenChannel, error) + + // IsChannelActive checks whether the channel identified by the provided + // ChannelID is considered active. + IsChannelActive func(chanID lnwire.ChannelID) bool + + // MaxHopHints is the maximum number of hop hints we are interested in. + MaxHopHints int } -func newSelectHopHintsCfg(invoicesCfg *AddInvoiceConfig) *SelectHopHintsCfg { +func newSelectHopHintsCfg(invoicesCfg *AddInvoiceConfig, + maxHopHints int) *SelectHopHintsCfg { + return &SelectHopHintsCfg{ + FetchAllChannels: invoicesCfg.ChanDB.FetchAllChannels, + IsChannelActive: invoicesCfg.IsChannelActive, IsPublicNode: invoicesCfg.Graph.IsPublicNode, FetchChannelEdgesByID: invoicesCfg.Graph.FetchChannelEdgesByID, GetAlias: invoicesCfg.GetAlias, + MaxHopHints: maxHopHints, } } // sufficientHints checks whether we have sufficient hop hints, based on the -// following criteria: -// - Hop hint count: limit to a set number of hop hints, regardless of whether -// we've reached our invoice amount or not. -// - Total incoming capacity: limit to our invoice amount * scaling factor to -// allow for some of our links going offline. +// any of the following criteria: +// - Hop hint count: the number of hints have reach our max target. +// - Total incoming capacity: the sum of the remote balance amount in the +// hints is bigger of equal than our target (currently twice the invoice +// amount) // // We limit our number of hop hints like this to keep our invoice size down, // and to avoid leaking all our private channels when we don't need to. -func sufficientHints(numHints, maxHints, scalingFactor int, amount, - totalHintAmount lnwire.MilliSatoshi) bool { +func sufficientHints(nHintsLeft int, currentAmount, + targetAmount lnwire.MilliSatoshi) bool { - if numHints >= maxHints { - log.Debug("Reached maximum number of hop hints") + if nHintsLeft <= 0 { + log.Debugf("Reached targeted number of hop hints") return true } - requiredAmount := amount * lnwire.MilliSatoshi(scalingFactor) - if totalHintAmount >= requiredAmount { + if currentAmount >= targetAmount { log.Debugf("Total hint amount: %v has reached target hint "+ - "bandwidth: %v (invoice amount: %v * factor: %v)", - totalHintAmount, requiredAmount, amount, - scalingFactor) - + "bandwidth: %v", currentAmount, targetAmount) return true } return false } -// SelectHopHints will select up to numMaxHophints from the set of passed open +// getPotentialHints returns a slice of open channels that should be considered +// for the hopHint list in an invoice. The slice is sorted in descending order +// based on the remote balance. +func getPotentialHints(cfg *SelectHopHintsCfg) ([]*channeldb.OpenChannel, + error) { + + // TODO(positiveblue): get the channels slice already filtered by + // private == true and sorted by RemoteBalance? + openChannels, err := cfg.FetchAllChannels() + if err != nil { + return nil, err + } + + privateChannels := make([]*channeldb.OpenChannel, 0, len(openChannels)) + for _, oc := range openChannels { + isPublic := oc.ChannelFlags&lnwire.FFAnnounceChannel != 0 + if !isPublic { + privateChannels = append(privateChannels, oc) + } + } + + // Sort the channels in descending remote balance. + compareRemoteBalance := func(i, j int) bool { + iBalance := privateChannels[i].LocalCommitment.RemoteBalance + jBalance := privateChannels[j].LocalCommitment.RemoteBalance + return iBalance > jBalance + } + sort.Slice(privateChannels, compareRemoteBalance) + + return privateChannels, nil +} + +// shouldIncludeChannel returns true if the channel passes all the checks to +// be a hopHint in a given invoice. +func shouldIncludeChannel(cfg *SelectHopHintsCfg, + channel *channeldb.OpenChannel, + alreadyIncluded map[uint64]bool) (zpay32.HopHint, lnwire.MilliSatoshi, + bool) { + + if _, ok := alreadyIncluded[channel.ShortChannelID.ToUint64()]; ok { + return zpay32.HopHint{}, 0, false + } + + chanID := lnwire.NewChanIDFromOutPoint( + &channel.FundingOutpoint, + ) + + hopHintInfo := newHopHintInfo(channel, cfg.IsChannelActive(chanID)) + + // If this channel can't be a hop hint, then skip it. + edgePolicy, canBeHopHint := chanCanBeHopHint(hopHintInfo, cfg) + if edgePolicy == nil || !canBeHopHint { + return zpay32.HopHint{}, 0, false + } + + if hopHintInfo.ScidAliasFeature { + alias, err := cfg.GetAlias(chanID) + if err != nil { + return zpay32.HopHint{}, 0, false + } + + if alias.IsDefault() || alreadyIncluded[alias.ToUint64()] { + return zpay32.HopHint{}, 0, false + } + + hopHintInfo.ShortChannelID = alias.ToUint64() + } + + // Now that we know this channel use usable, add it as a hop hint and + // the indexes we'll use later. + hopHint := newHopHint(hopHintInfo, edgePolicy) + return hopHint, hopHintInfo.RemoteBalance, true +} + +// selectHopHints iterates a list of potential hints selecting the valid hop +// hints until we have enough hints or run out of channels. +// +// NOTE: selectHopHints expects potentialHints to be already sorted in +// descending priority. +func selectHopHints(cfg *SelectHopHintsCfg, nHintsLeft int, + targetBandwidth lnwire.MilliSatoshi, + potentialHints []*channeldb.OpenChannel, + alreadyIncluded map[uint64]bool) [][]zpay32.HopHint { + + currentBandwidth := lnwire.MilliSatoshi(0) + hopHints := make([][]zpay32.HopHint, 0, nHintsLeft) + for _, channel := range potentialHints { + enoughHopHints := sufficientHints( + nHintsLeft, currentBandwidth, targetBandwidth, + ) + if enoughHopHints { + return hopHints + } + + hopHint, remoteBalance, include := shouldIncludeChannel( + cfg, channel, alreadyIncluded, + ) + + if include { + // Now that we now this channel use usable, add it as a hop + // hint and the indexes we'll use later. + hopHints = append(hopHints, []zpay32.HopHint{hopHint}) + currentBandwidth += remoteBalance + nHintsLeft-- + } + } + + // We do not want to leak information about how our remote balance is + // distributed in our private channels. We shuffle the selected ones + // here so they do not appear in order in the invoice. + mathRand.Shuffle( + len(hopHints), func(i, j int) { + hopHints[i], hopHints[j] = hopHints[j], hopHints[i] + }, + ) + return hopHints +} + +// PopulateHopHints will select up to cfg.MaxHophints from the current open // channels. The set of hop hints will be returned as a slice of functional // options that'll append the route hint to the set of all route hints. // // TODO(roasbeef): do proper sub-set sum max hints usually << numChans. -func SelectHopHints(amtMSat lnwire.MilliSatoshi, cfg *SelectHopHintsCfg, - openChannels []*HopHintInfo, - numMaxHophints int) [][]zpay32.HopHint { +func PopulateHopHints(cfg *SelectHopHintsCfg, amtMSat lnwire.MilliSatoshi, + forcedHints [][]zpay32.HopHint) ([][]zpay32.HopHint, error) { - // We'll add our hop hints in two passes, first we'll add all channels - // that are eligible to be hop hints, and also have a local balance - // above the payment amount. - var totalHintBandwidth lnwire.MilliSatoshi - hopHintChans := make(map[wire.OutPoint]struct{}) - hopHints := make([][]zpay32.HopHint, 0, numMaxHophints) - for _, channel := range openChannels { - enoughHopHints := sufficientHints( - len(hopHints), numMaxHophints, hopHintFactor, amtMSat, - totalHintBandwidth, - ) - if enoughHopHints { - log.Debugf("First pass of hop selection has " + - "sufficient hints") + hopHints := forcedHints - return hopHints - } - - // If this channel can't be a hop hint, then skip it. - edgePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg) - if edgePolicy == nil || !canBeHopHint { - continue - } - - // Similarly, in this first pass, we'll ignore all channels in - // isolation can't satisfy this payment. - if channel.RemoteBalance < amtMSat { - continue - } - - // Lookup and see if there is an alias SCID that exists. - chanID := lnwire.NewChanIDFromOutPoint( - &channel.FundingOutpoint, - ) - alias, _ := cfg.GetAlias(chanID) - - // If this is a channel where the option-scid-alias feature bit - // was negotiated and the alias is not yet assigned, we cannot - // issue an invoice. Doing so might expose the confirmed SCID - // of a private channel. - if channel.ScidAliasFeature { - var defaultScid lnwire.ShortChannelID - if alias == defaultScid { - continue - } - } - - // Now that we now this channel use usable, add it as a hop - // hint and the indexes we'll use later. - addHopHint(&hopHints, channel, edgePolicy, alias) - - hopHintChans[channel.FundingOutpoint] = struct{}{} - totalHintBandwidth += channel.RemoteBalance + // If we already have enough hints we don't need to add any more. + nHintsLeft := cfg.MaxHopHints - len(hopHints) + if nHintsLeft <= 0 { + return hopHints, nil } - // In this second pass we'll add channels, and we'll either stop when - // we have 20 hop hints, we've run through all the available channels, - // or if the sum of available bandwidth in the routing hints exceeds 2x - // the payment amount. We do 2x here to account for a margin of error - // if some of the selected channels no longer become operable. - for i := 0; i < len(openChannels); i++ { - enoughHopHints := sufficientHints( - len(hopHints), numMaxHophints, hopHintFactor, amtMSat, - totalHintBandwidth, - ) - if enoughHopHints { - log.Debugf("Second pass of hop selection has " + - "sufficient hints") - - return hopHints - } - - channel := openChannels[i] - - // Skip the channel if we already selected it. - if _, ok := hopHintChans[channel.FundingOutpoint]; ok { - continue - } - - // If the channel can't be a hop hint, then we'll skip it. - // Otherwise, we'll use the policy information to populate the - // hop hint. - remotePolicy, canBeHopHint := chanCanBeHopHint(channel, cfg) - if !canBeHopHint || remotePolicy == nil { - continue - } - - // Lookup and see if there's an alias SCID that exists. - chanID := lnwire.NewChanIDFromOutPoint( - &channel.FundingOutpoint, - ) - alias, _ := cfg.GetAlias(chanID) - - // If this is a channel where the option-scid-alias feature bit - // was negotiated and the alias is not yet assigned, we cannot - // issue an invoice. Doing so might expose the confirmed SCID - // of a private channel. - if channel.ScidAliasFeature { - var defaultScid lnwire.ShortChannelID - if alias == defaultScid { - continue - } - } - - // Include the route hint in our set of options that will be - // used when creating the invoice. - addHopHint(&hopHints, channel, remotePolicy, alias) - - // As we've just added a new hop hint, we'll accumulate it's - // available balance now to update our tally. - // - // TODO(roasbeef): have a cut off based on min bandwidth? - totalHintBandwidth += channel.RemoteBalance + alreadyIncluded := make(map[uint64]bool) + for _, hopHint := range hopHints { + alreadyIncluded[hopHint[0].ChannelID] = true } - return hopHints + potentialHints, err := getPotentialHints(cfg) + if err != nil { + return nil, err + } + + targetBandwidth := amtMSat * hopHintFactor + selectedHints := selectHopHints( + cfg, nHintsLeft, targetBandwidth, potentialHints, + alreadyIncluded, + ) + + hopHints = append(hopHints, selectedHints...) + return hopHints, nil } diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index 4720f7ad0..d7ee620eb 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -2,7 +2,7 @@ package invoicesrpc import ( "encoding/hex" - "errors" + "fmt" "testing" "github.com/btcsuite/btcd/btcec/v2" @@ -24,7 +24,32 @@ func (h *hopHintsConfigMock) IsPublicNode(pubKey [33]byte) (bool, error) { return args.Bool(0), args.Error(1) } -// FetchChannelEdgesByID mocks channel edge lookup. +// IsChannelActive is used to generate valid hop hints. +func (h *hopHintsConfigMock) IsChannelActive(chanID lnwire.ChannelID) bool { + args := h.Mock.Called(chanID) + return args.Bool(0) +} + +// GetAlias allows the peer's alias SCID to be retrieved for private +// option_scid_alias channels. +func (h *hopHintsConfigMock) GetAlias( + chanID lnwire.ChannelID) (lnwire.ShortChannelID, error) { + + args := h.Mock.Called(chanID) + return args.Get(0).(lnwire.ShortChannelID), args.Error(1) +} + +// FetchAllChannels retrieves all open channels currently stored +// within the database. +func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, + error) { + + args := h.Mock.Called() + return args.Get(0).([]*channeldb.OpenChannel), args.Error(1) +} + +// FetchChannelEdgesByID attempts to lookup the two directed edges for +// the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( *channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, *channeldb.ChannelEdgePolicy, error) { @@ -46,584 +71,716 @@ func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( return edgeInfo, policy1, policy2, err } -// TestSelectHopHints tests selection of hop hints for a node with private -// channels. -func TestSelectHopHints(t *testing.T) { - var ( - // We need to serialize our pubkey in SelectHopHints so it - // needs to be valid. - pubkeyBytes, _ = hex.DecodeString( - "598ec453728e0ffe0ae2f5e174243cf58f2" + - "a3f2c83d2457b43036db568b11093", - ) - pubKeyY = new(btcec.FieldVal) - _ = pubKeyY.SetByteSlice(pubkeyBytes) - pubkey = btcec.NewPublicKey( - new(btcec.FieldVal).SetInt(4), - pubKeyY, - ) - compressed = pubkey.SerializeCompressed() - - publicChannel = &HopHintInfo{ - IsPublic: true, - IsActive: true, - FundingOutpoint: wire.OutPoint{ - Index: 0, - }, - RemoteBalance: 10, - ShortChannelID: 0, - } - - inactiveChannel = &HopHintInfo{ - IsPublic: false, - IsActive: false, - } - - // Create a private channel that we'll generate hints from. - private1ShortID uint64 = 1 - privateChannel1 = &HopHintInfo{ - IsPublic: false, - IsActive: true, - FundingOutpoint: wire.OutPoint{ - Index: 1, - }, - RemotePubkey: pubkey, - RemoteBalance: 100, - ShortChannelID: private1ShortID, - } - - // Create a edge policy for private channel 1. - privateChan1Policy = &channeldb.ChannelEdgePolicy{ - FeeBaseMSat: 10, - FeeProportionalMillionths: 100, - TimeLockDelta: 1000, - } - - // Create an edge policy different to ours which we'll use for - // the other direction - otherChanPolicy = &channeldb.ChannelEdgePolicy{ - FeeBaseMSat: 90, - FeeProportionalMillionths: 900, - TimeLockDelta: 9000, - } - - // Create a hop hint based on privateChan1Policy. - privateChannel1Hint = zpay32.HopHint{ - NodeID: privateChannel1.RemotePubkey, - ChannelID: private1ShortID, - FeeBaseMSat: uint32(privateChan1Policy.FeeBaseMSat), - FeeProportionalMillionths: uint32( - privateChan1Policy.FeeProportionalMillionths, - ), - CLTVExpiryDelta: privateChan1Policy.TimeLockDelta, - } - - // Create a second private channel that we'll use for hints. - private2ShortID uint64 = 2 - privateChannel2 = &HopHintInfo{ - IsPublic: false, - IsActive: true, - FundingOutpoint: wire.OutPoint{ - Index: 2, - }, - RemotePubkey: pubkey, - RemoteBalance: 100, - ShortChannelID: private2ShortID, - } - - // Create a edge policy for private channel 1. - privateChan2Policy = &channeldb.ChannelEdgePolicy{ - FeeBaseMSat: 20, - FeeProportionalMillionths: 200, - TimeLockDelta: 2000, - } - - // Create a hop hint based on privateChan2Policy. - privateChannel2Hint = zpay32.HopHint{ - NodeID: privateChannel2.RemotePubkey, - ChannelID: private2ShortID, - FeeBaseMSat: uint32(privateChan2Policy.FeeBaseMSat), - FeeProportionalMillionths: uint32( - privateChan2Policy.FeeProportionalMillionths, - ), - CLTVExpiryDelta: privateChan2Policy.TimeLockDelta, - } - - // Create a third private channel that we'll use for hints. - private3ShortID uint64 = 3 - privateChannel3 = &HopHintInfo{ - IsPublic: false, - IsActive: true, - FundingOutpoint: wire.OutPoint{ - Index: 3, - }, - RemotePubkey: pubkey, - RemoteBalance: 100, - ShortChannelID: private3ShortID, - } - - // Create a edge policy for private channel 1. - privateChan3Policy = &channeldb.ChannelEdgePolicy{ - FeeBaseMSat: 30, - FeeProportionalMillionths: 300, - TimeLockDelta: 3000, - } - - // Create a hop hint based on privateChan2Policy. - privateChannel3Hint = zpay32.HopHint{ - NodeID: privateChannel3.RemotePubkey, - ChannelID: private3ShortID, - FeeBaseMSat: uint32(privateChan3Policy.FeeBaseMSat), - FeeProportionalMillionths: uint32( - privateChan3Policy.FeeProportionalMillionths, - ), - CLTVExpiryDelta: privateChan3Policy.TimeLockDelta, - } +// getTestPubKey returns a valid parsed pub key to be used in our tests. +func getTestPubKey() *btcec.PublicKey { + pubkeyBytes, _ := hex.DecodeString( + "598ec453728e0ffe0ae2f5e174243cf58f2" + + "a3f2c83d2457b43036db568b11093", ) - - // We can't copy in the above var decls, so we copy in our pubkey here. - var peer [33]byte - copy(peer[:], compressed) - - var ( - // We pick our policy based on which node (1 or 2) the remote - // peer is. Here we create two different sets of edge - // information. One where our peer is node 1, the other where - // our peer is edge 2. This ensures that we always pick the - // right edge policy for our hint. - infoNode1 = &channeldb.ChannelEdgeInfo{ - NodeKey1Bytes: peer, - } - - infoNode2 = &channeldb.ChannelEdgeInfo{ - NodeKey1Bytes: [33]byte{9, 9, 9}, - NodeKey2Bytes: peer, - } - - // setMockChannelUsed preps our mock for the case where we - // want our private channel to be used for a hop hint. - setMockChannelUsed = func(h *hopHintsConfigMock, - shortID uint64, - policy *channeldb.ChannelEdgePolicy) { - - // Return public node = true so that we'll consider - // this node for our hop hints. - h.Mock.On( - "IsPublicNode", peer, - ).Once().Return(true, nil) - - // When it gets time to find an edge policy for this - // node, fail it. We won't use it as a hop hint. - h.Mock.On( - "FetchChannelEdgesByID", - shortID, - ).Once().Return( - infoNode1, policy, otherChanPolicy, nil, - ) - } + pubKeyY := new(btcec.FieldVal) + _ = pubKeyY.SetByteSlice(pubkeyBytes) + pubkey := btcec.NewPublicKey( + new(btcec.FieldVal).SetInt(4), + pubKeyY, ) + return pubkey +} - tests := []struct { - name string - setupMock func(*hopHintsConfigMock) - amount lnwire.MilliSatoshi - channels []*HopHintInfo - numHints int - - // expectedHints is the set of hop hints that we expect. We - // initialize this slice with our max hop hints length, so this - // value won't be nil even if its empty. - expectedHints [][]zpay32.HopHint - }{ - { - // We don't need hop hints for public channels. - name: "channel is public", - // When a channel is public, we exit before we make any - // calls. - setupMock: func(h *hopHintsConfigMock) { - }, - amount: 100, - channels: []*HopHintInfo{ - publicChannel, - }, - numHints: 2, - expectedHints: nil, +var shouldIncludeChannelTestCases = []struct { + name string + setupMock func(*hopHintsConfigMock) + channel *channeldb.OpenChannel + alreadyIncluded map[uint64]bool + cfg *SelectHopHintsCfg + hopHint zpay32.HopHint + remoteBalance lnwire.MilliSatoshi + include bool +}{{ + name: "already included channels should not be included " + + "again", + alreadyIncluded: map[uint64]bool{1: true}, + channel: &channeldb.OpenChannel{ + ShortChannelID: lnwire.NewShortChanIDFromInt(1), + }, + include: false, +}, { + name: "public channels should not be included", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - name: "channel is inactive", - setupMock: func(h *hopHintsConfigMock) {}, - amount: 100, - channels: []*HopHintInfo{ - inactiveChannel, - }, - numHints: 2, - expectedHints: nil, + ChannelFlags: lnwire.FFAnnounceChannel, + }, +}, { + name: "not active channels should not be included", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(false) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - // If we can't lookup an edge policy, we skip channels. - name: "no edge policy", - setupMock: func(h *hopHintsConfigMock) { - // Return public node = true so that we'll - // consider this node for our hop hints. - h.Mock.On( - "IsPublicNode", peer, - ).Return(true, nil) + }, + include: false, +}, { + name: "a channel with a not public peer should not be included", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) - // When it gets time to find an edge policy for - // this node, fail it. We won't use it as a - // hop hint. - h.Mock.On( - "FetchChannelEdgesByID", - mock.Anything, - ).Return( - nil, nil, nil, - errors.New("no edge"), - ).Times(4) - }, - amount: 100, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 3, - expectedHints: nil, + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(false, nil) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - // If one of our private channels belongs to a node - // that is otherwise not announced to the network, we're - // polite and don't include them (they can't be routed - // through anyway). - name: "node is private", - setupMock: func(h *hopHintsConfigMock) { - // Return public node = false so that we'll - // give up on this node. - h.Mock.On( - "IsPublicNode", peer, - ).Return(false, nil) - }, - amount: 100, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 1, - expectedHints: nil, + IdentityPub: getTestPubKey(), + }, + include: false, +}, { + name: "if we are unable to fetch the edge policy for the channel it " + + "should not be included", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return(nil, nil, nil, fmt.Errorf("no edge")) + + // TODO(positiveblue): check that the func is called with the + // right scid when we have access to the `confirmedscid` form + // here. + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return(nil, nil, nil, fmt.Errorf("no edge")) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - // This test case asserts that we limit our hop hints - // when we've reached our maximum number of hints. - name: "too many hints", - setupMock: func(h *hopHintsConfigMock) { - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - }, - // Set our amount to less than our channel balance of - // 100. - amount: 30, - channels: []*HopHintInfo{ - privateChannel1, privateChannel2, - }, - numHints: 1, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - }, + IdentityPub: getTestPubKey(), + }, + include: false, +}, { + name: "channels with the option-scid-alias but not assigned alias " + + "yet should not be included", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + + h.Mock.On( + "GetAlias", mock.Anything, + ).Once().Return(lnwire.ShortChannelID{}, nil) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - // If a channel has more balance than the amount we're - // looking for, it'll be added in our first pass. We - // can be sure we're adding it in our first pass because - // we assert that there are no additional calls to our - // mock (which would happen if we ran a second pass). - // - // We set our peer to be node 1 in our policy ordering. - name: "balance > total amount, node 1", - setupMock: func(h *hopHintsConfigMock) { - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - }, - // Our channel has balance of 100 (> 50). - amount: 50, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 2, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - }, + IdentityPub: getTestPubKey(), + ChanType: channeldb.ScidAliasFeatureBit, + }, + include: false, +}, { + name: "channels with the option-scid-alias and an alias that has " + + "already been included should not be included again", + alreadyIncluded: map[uint64]bool{5: true}, + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 0, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + alias := lnwire.ShortChannelID{TxPosition: 5} + h.Mock.On( + "GetAlias", mock.Anything, + ).Once().Return(alias, nil) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 0, }, - { - // As above, but we set our peer to be node 2 in our - // policy ordering. - name: "balance > total amount, node 2", - setupMock: func(h *hopHintsConfigMock) { - // Return public node = true so that we'll - // consider this node for our hop hints. - h.Mock.On( - "IsPublicNode", peer, - ).Return(true, nil) + IdentityPub: getTestPubKey(), + ChanType: channeldb.ScidAliasFeatureBit, + }, + include: false, +}, { + name: "channels that pass all the checks should be " + + "included, using policy 1", + alreadyIncluded: map[uint64]bool{5: true}, + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 1, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) - // When it gets time to find an edge policy for - // this node, fail it. We won't use it as a - // hop hint. - h.Mock.On( - "FetchChannelEdgesByID", - private1ShortID, - ).Return( - infoNode2, otherChanPolicy, - privateChan1Policy, nil, - ) + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + var selectedPolicy [33]byte + copy(selectedPolicy[:], getTestPubKey().SerializeCompressed()) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{ + NodeKey1Bytes: selectedPolicy, }, - // Our channel has balance of 100 (> 50). - amount: 50, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 2, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, + &channeldb.ChannelEdgePolicy{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, }, + &channeldb.ChannelEdgePolicy{}, + nil, + ) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 1, }, - { - // Since our balance is less than the amount we're - // looking to route, we expect this hint to be picked - // up in our second pass on the channel set. - name: "balance < total amount", - setupMock: func(h *hopHintsConfigMock) { - // We expect to call all our checks twice - // because we pick up this channel in the - // second round. - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - }, - // Our channel has balance of 100 (< 150). - amount: 150, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 2, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - }, + IdentityPub: getTestPubKey(), + ShortChannelID: lnwire.NewShortChanIDFromInt(12), + }, + hopHint: zpay32.HopHint{ + NodeID: getTestPubKey(), + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + ChannelID: 12, + CLTVExpiryDelta: 13, + }, + include: true, +}, { + name: "channels that pass all the checks should be " + + "included, using policy 2", + alreadyIncluded: map[uint64]bool{5: true}, + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 1, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, + }, nil, + ) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 1, }, - { - // Test the case where we hit our total amount of - // required liquidity in our first pass. - name: "first pass sufficient balance", - setupMock: func(h *hopHintsConfigMock) { - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - }, - // Divide our balance by hop hint factor so that the - // channel balance will always reach our factored up - // amount, even if we change this value. - amount: privateChannel1.RemoteBalance / hopHintFactor, - channels: []*HopHintInfo{ - privateChannel1, - }, - numHints: 2, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - }, + IdentityPub: getTestPubKey(), + ShortChannelID: lnwire.NewShortChanIDFromInt(12), + }, + hopHint: zpay32.HopHint{ + NodeID: getTestPubKey(), + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + ChannelID: 12, + CLTVExpiryDelta: 13, + }, + include: true, +}, { + name: "channels that pass all the checks and have an alias " + + "should be included with the alias", + alreadyIncluded: map[uint64]bool{5: true}, + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{ + Index: 1, + } + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{ + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + TimeLockDelta: 13, + }, nil, + ) + + aliasSCID := lnwire.NewShortChanIDFromInt(15) + + h.Mock.On( + "GetAlias", mock.Anything, + ).Once().Return(aliasSCID, nil) + }, + channel: &channeldb.OpenChannel{ + FundingOutpoint: wire.OutPoint{ + Index: 1, }, - { - // Setup our amount so that we don't have enough - // inbound total for our amount, but we hit our - // desired hint limit. - name: "second pass sufficient hint count", - setupMock: func(h *hopHintsConfigMock) { - // We expect all of our channels to be passed - // on in the first pass. - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) + IdentityPub: getTestPubKey(), + ShortChannelID: lnwire.NewShortChanIDFromInt(12), + ChanType: channeldb.ScidAliasFeatureBit, + }, + hopHint: zpay32.HopHint{ + NodeID: getTestPubKey(), + FeeBaseMSat: 1000, + FeeProportionalMillionths: 20, + ChannelID: 15, + CLTVExpiryDelta: 13, + }, + include: true, +}} - setMockChannelUsed( - h, private2ShortID, privateChan2Policy, - ) +func TestShouldIncludeChannel(t *testing.T) { + for _, tc := range shouldIncludeChannelTestCases { + tc := tc - // In the second pass, our first two channels - // should be added before we hit our hint count. - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - }, - // Add two channels that we'd want to use, but the - // second one will be cut off due to our hop hint count - // limit. - channels: []*HopHintInfo{ - privateChannel1, privateChannel2, - }, - // Set the amount we need to more than our two channels - // can provide us. - amount: privateChannel1.RemoteBalance + - privateChannel2.RemoteBalance, - numHints: 1, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - }, - }, - { - // Add three channels that are all less than the amount - // we wish to receive, but collectively will reach the - // total amount that we need. - name: "second pass reaches bandwidth requirement", - setupMock: func(h *hopHintsConfigMock) { - // In the first round, all channels should be - // passed on. - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - setMockChannelUsed( - h, private2ShortID, privateChan2Policy, - ) - - setMockChannelUsed( - h, private3ShortID, privateChan3Policy, - ) - - // In the second round, we'll pick up all of - // our hop hints. - setMockChannelUsed( - h, private1ShortID, privateChan1Policy, - ) - - setMockChannelUsed( - h, private2ShortID, privateChan2Policy, - ) - - setMockChannelUsed( - h, private3ShortID, privateChan3Policy, - ) - }, - channels: []*HopHintInfo{ - privateChannel1, privateChannel2, - privateChannel3, - }, - - // All of our channels have 100 inbound, so none will - // be picked up in the first round. - amount: 110, - numHints: 5, - expectedHints: [][]zpay32.HopHint{ - { - privateChannel1Hint, - }, - { - privateChannel2Hint, - }, - { - privateChannel3Hint, - }, - }, - }, - } - - getAlias := func(lnwire.ChannelID) (lnwire.ShortChannelID, error) { - return lnwire.ShortChannelID{}, nil - } - - for _, test := range tests { - test := test - - t.Run(test.name, func(t *testing.T) { // Create mock and prime it for the test case. mock := &hopHintsConfigMock{} - test.setupMock(mock) + if tc.setupMock != nil { + tc.setupMock(mock) + } defer mock.AssertExpectations(t) cfg := &SelectHopHintsCfg{ IsPublicNode: mock.IsPublicNode, + IsChannelActive: mock.IsChannelActive, FetchChannelEdgesByID: mock.FetchChannelEdgesByID, - GetAlias: getAlias, + GetAlias: mock.GetAlias, } - hints := SelectHopHints( - test.amount, cfg, test.channels, test.numHints, + hopHint, remoteBalance, include := shouldIncludeChannel( + cfg, tc.channel, tc.alreadyIncluded, ) - // SelectHopHints preallocates its hop hint slice, so - // we check that it is empty if we don't expect any - // hints, and otherwise assert that the two slices are - // equal. This allows tests to set their expected value - // to nil, rather than providing a preallocated empty - // slice. - if len(test.expectedHints) == 0 { - require.Zero(t, len(hints)) - } else { - require.Equal(t, test.expectedHints, hints) + require.Equal(t, tc.include, include) + if include { + require.Equal(t, tc.hopHint, hopHint) + require.Equal( + t, tc.remoteBalance, remoteBalance, + ) } }) } } -// TestSufficientHopHints tests limiting our hops to a set number of hints or -// scaled amount of capacity. -func TestSufficientHopHints(t *testing.T) { - t.Parallel() +var sufficientHintsTestCases = []struct { + name string + nHintsLeft int + currentAmount lnwire.MilliSatoshi + targetAmount lnwire.MilliSatoshi + done bool +}{{ + name: "not enoguh hints neither bandwidth", + nHintsLeft: 3, + currentAmount: 100, + targetAmount: 200, + done: false, +}, { + name: "enough hints", + nHintsLeft: 0, + done: true, +}, { + name: "enoguh bandwidth", + nHintsLeft: 1, + currentAmount: 200, + targetAmount: 200, + done: true, +}} - tests := []struct { - name string - numHints int - maxHints int - scalingFactor int - amount lnwire.MilliSatoshi - totalHintAmount lnwire.MilliSatoshi - sufficient bool - }{ - { - name: "not enough hints or amount", - numHints: 3, - maxHints: 10, - // We want to have at least 200, and we currently have - // 10. - scalingFactor: 2, - amount: 100, - totalHintAmount: 10, - sufficient: false, - }, - { - name: "enough hints", - numHints: 3, - maxHints: 3, - sufficient: true, - }, - { - name: "not enough hints, insufficient bandwidth", - numHints: 1, - maxHints: 3, - // We want at least 200, and we have enough. - scalingFactor: 2, - amount: 100, - totalHintAmount: 700, - sufficient: true, - }, - } +func TestSufficientHints(t *testing.T) { + for _, tc := range sufficientHintsTestCases { + tc := tc - for _, testCase := range tests { - sufficient := sufficientHints( - testCase.numHints, testCase.maxHints, - testCase.scalingFactor, testCase.amount, - testCase.totalHintAmount, - ) + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - require.Equal(t, testCase.sufficient, sufficient) + enoughHints := sufficientHints( + tc.nHintsLeft, tc.currentAmount, + tc.targetAmount, + ) + require.Equal(t, tc.done, enoughHints) + }) + } +} + +var populateHopHintsTestCases = []struct { + name string + setupMock func(*hopHintsConfigMock) + amount lnwire.MilliSatoshi + maxHopHints int + forcedHints [][]zpay32.HopHint + expectedHopHints [][]zpay32.HopHint +}{{ + name: "populate hop hints with forced hints", + maxHopHints: 1, + forcedHints: [][]zpay32.HopHint{ + { + {ChannelID: 12}, + }, + }, + expectedHopHints: [][]zpay32.HopHint{ + { + {ChannelID: 12}, + }, + }, +}, { + name: "populate hop hints stops when we reached the max number of " + + "hop hints allowed", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{Index: 9} + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + allChannels := []*channeldb.OpenChannel{ + { + FundingOutpoint: fundingOutpoint, + ShortChannelID: lnwire.NewShortChanIDFromInt(9), + IdentityPub: getTestPubKey(), + }, + // Have one empty channel that we should not process + // because we have already finished. + {}, + } + + h.Mock.On( + "FetchAllChannels", + ).Once().Return(allChannels, nil) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + }, + maxHopHints: 1, + amount: 1_000_000, + expectedHopHints: [][]zpay32.HopHint{ + { + { + NodeID: getTestPubKey(), + ChannelID: 9, + }, + }, + }, +}, { + name: "populate hop hints stops when we reached the targeted bandwidth", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{Index: 9} + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + remoteBalance := lnwire.MilliSatoshi(10_000_000) + allChannels := []*channeldb.OpenChannel{ + { + LocalCommitment: channeldb.ChannelCommitment{ + RemoteBalance: remoteBalance, + }, + FundingOutpoint: fundingOutpoint, + ShortChannelID: lnwire.NewShortChanIDFromInt(9), + IdentityPub: getTestPubKey(), + }, + // Have one empty channel that we should not process + // because we have already finished. + {}, + } + + h.Mock.On( + "FetchAllChannels", + ).Once().Return(allChannels, nil) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + }, + maxHopHints: 10, + amount: 1_000_000, + expectedHopHints: [][]zpay32.HopHint{ + { + { + NodeID: getTestPubKey(), + ChannelID: 9, + }, + }, + }, +}, { + name: "populate hop hints tries to use the channels with higher " + + "remote balance frist", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint := wire.OutPoint{Index: 9} + chanID := lnwire.NewChanIDFromOutPoint(&fundingOutpoint) + remoteBalance := lnwire.MilliSatoshi(10_000_000) + allChannels := []*channeldb.OpenChannel{ + // Because the channels with higher remote balance have + // enough bandwidth we should never use this one. + {}, + { + LocalCommitment: channeldb.ChannelCommitment{ + RemoteBalance: remoteBalance, + }, + FundingOutpoint: fundingOutpoint, + ShortChannelID: lnwire.NewShortChanIDFromInt(9), + IdentityPub: getTestPubKey(), + }, + } + + h.Mock.On( + "FetchAllChannels", + ).Once().Return(allChannels, nil) + + h.Mock.On( + "IsChannelActive", chanID, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + }, + maxHopHints: 1, + amount: 1_000_000, + expectedHopHints: [][]zpay32.HopHint{ + { + { + NodeID: getTestPubKey(), + ChannelID: 9, + }, + }, + }, +}, { + name: "populate hop hints stops after having considered all the open " + + "channels", + setupMock: func(h *hopHintsConfigMock) { + fundingOutpoint1 := wire.OutPoint{Index: 9} + chanID1 := lnwire.NewChanIDFromOutPoint(&fundingOutpoint1) + remoteBalance1 := lnwire.MilliSatoshi(10_000_000) + + fundingOutpoint2 := wire.OutPoint{Index: 2} + chanID2 := lnwire.NewChanIDFromOutPoint(&fundingOutpoint2) + remoteBalance2 := lnwire.MilliSatoshi(1_000_000) + + allChannels := []*channeldb.OpenChannel{ + // After sorting we will first process chanID1 and then + // chanID2. + { + LocalCommitment: channeldb.ChannelCommitment{ + RemoteBalance: remoteBalance2, + }, + FundingOutpoint: fundingOutpoint2, + ShortChannelID: lnwire.NewShortChanIDFromInt(2), + IdentityPub: getTestPubKey(), + }, + { + LocalCommitment: channeldb.ChannelCommitment{ + RemoteBalance: remoteBalance1, + }, + FundingOutpoint: fundingOutpoint1, + ShortChannelID: lnwire.NewShortChanIDFromInt(9), + IdentityPub: getTestPubKey(), + }, + } + + h.Mock.On( + "FetchAllChannels", + ).Once().Return(allChannels, nil) + + // Prepare the mock for the first channel. + h.Mock.On( + "IsChannelActive", chanID1, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + + // Prepare the mock for the second channel. + h.Mock.On( + "IsChannelActive", chanID2, + ).Once().Return(true) + + h.Mock.On( + "IsPublicNode", mock.Anything, + ).Once().Return(true, nil) + + h.Mock.On( + "FetchChannelEdgesByID", mock.Anything, + ).Once().Return( + &channeldb.ChannelEdgeInfo{}, + &channeldb.ChannelEdgePolicy{}, + &channeldb.ChannelEdgePolicy{}, nil, + ) + }, + maxHopHints: 10, + amount: 100_000_000, + expectedHopHints: [][]zpay32.HopHint{ + { + { + NodeID: getTestPubKey(), + ChannelID: 9, + }, + }, { + { + NodeID: getTestPubKey(), + ChannelID: 2, + }, + }, + }, +}} + +func TestPopulateHopHints(t *testing.T) { + for _, tc := range populateHopHintsTestCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + // Create mock and prime it for the test case. + mock := &hopHintsConfigMock{} + if tc.setupMock != nil { + tc.setupMock(mock) + } + defer mock.AssertExpectations(t) + + cfg := &SelectHopHintsCfg{ + IsPublicNode: mock.IsPublicNode, + IsChannelActive: mock.IsChannelActive, + FetchChannelEdgesByID: mock.FetchChannelEdgesByID, + GetAlias: mock.GetAlias, + FetchAllChannels: mock.FetchAllChannels, + MaxHopHints: tc.maxHopHints, + } + hopHints, err := PopulateHopHints( + cfg, tc.amount, tc.forcedHints, + ) + require.NoError(t, err) + // We shuffle the elements in the hop hint list so we + // need to compare the elements here. + require.ElementsMatch(t, tc.expectedHopHints, hopHints) + }) } } diff --git a/lnwire/short_channel_id.go b/lnwire/short_channel_id.go index f07de709f..d4da518b7 100644 --- a/lnwire/short_channel_id.go +++ b/lnwire/short_channel_id.go @@ -64,6 +64,12 @@ func (c *ShortChannelID) Record() tlv.Record { ) } +// IsDefault returns true if the ShortChannelID represents the zero value for +// its type. +func (c ShortChannelID) IsDefault() bool { + return c == ShortChannelID{} +} + // EShortChannelID is an encoder for ShortChannelID. It is exported so other // packages can use the encoding scheme. func EShortChannelID(w io.Writer, val interface{}, buf *[8]byte) error {