mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-07-12 14:12:27 +02:00
routing+routerrpc: add multiple outgoing channel restriction
This commit is contained in:
@ -284,9 +284,9 @@ type RestrictParams struct {
|
||||
// the source to the target.
|
||||
FeeLimit lnwire.MilliSatoshi
|
||||
|
||||
// OutgoingChannelID is the channel that needs to be taken to the first
|
||||
// hop. If nil, any channel may be used.
|
||||
OutgoingChannelID *uint64
|
||||
// OutgoingChannelIDs is the list of channels that are allowed for the
|
||||
// first hop. If nil, any channel may be used.
|
||||
OutgoingChannelIDs []uint64
|
||||
|
||||
// LastHop is the pubkey of the last node before the final destination
|
||||
// is reached. If nil, any node may be used.
|
||||
@ -329,7 +329,7 @@ type PathFindingConfig struct {
|
||||
// getOutgoingBalance returns the maximum available balance in any of the
|
||||
// channels of the given node. The second return parameters is the total
|
||||
// available balance.
|
||||
func getOutgoingBalance(node route.Vertex, outgoingChan *uint64,
|
||||
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
|
||||
bandwidthHints map[uint64]lnwire.MilliSatoshi,
|
||||
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
|
||||
|
||||
@ -344,8 +344,10 @@ func getOutgoingBalance(node route.Vertex, outgoingChan *uint64,
|
||||
chanID := outEdge.ChannelID
|
||||
|
||||
// Enforce outgoing channel restriction.
|
||||
if outgoingChan != nil && chanID != *outgoingChan {
|
||||
return nil
|
||||
if outgoingChans != nil {
|
||||
if _, ok := outgoingChans[chanID]; !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
bandwidth, ok := bandwidthHints[chanID]
|
||||
@ -447,13 +449,22 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
||||
return nil, errNoPaymentAddr
|
||||
}
|
||||
|
||||
// Set up outgoing channel map for quicker access.
|
||||
var outgoingChanMap map[uint64]struct{}
|
||||
if len(r.OutgoingChannelIDs) > 0 {
|
||||
outgoingChanMap = make(map[uint64]struct{})
|
||||
for _, outChan := range r.OutgoingChannelIDs {
|
||||
outgoingChanMap[outChan] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
// If we are routing from ourselves, check that we have enough local
|
||||
// balance available.
|
||||
self := g.graph.sourceNode()
|
||||
|
||||
if source == self {
|
||||
max, total, err := getOutgoingBalance(
|
||||
self, r.OutgoingChannelID, g.bandwidthHints, g.graph,
|
||||
self, outgoingChanMap, g.bandwidthHints, g.graph,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -763,7 +774,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
|
||||
pivot := partialPath.node
|
||||
|
||||
// Create unified policies for all incoming connections.
|
||||
u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID)
|
||||
u := newUnifiedPolicies(self, pivot, outgoingChanMap)
|
||||
|
||||
err := u.addGraphPolicies(g.graph)
|
||||
if err != nil {
|
||||
|
Reference in New Issue
Block a user