routing+routerrpc: add multiple outgoing channel restriction

This commit is contained in:
Joost Jager
2020-05-07 11:48:39 +02:00
parent 53e4876a1d
commit c33d94ff27
11 changed files with 287 additions and 214 deletions

View File

@ -284,9 +284,9 @@ type RestrictParams struct {
// the source to the target.
FeeLimit lnwire.MilliSatoshi
// OutgoingChannelID is the channel that needs to be taken to the first
// hop. If nil, any channel may be used.
OutgoingChannelID *uint64
// OutgoingChannelIDs is the list of channels that are allowed for the
// first hop. If nil, any channel may be used.
OutgoingChannelIDs []uint64
// LastHop is the pubkey of the last node before the final destination
// is reached. If nil, any node may be used.
@ -329,7 +329,7 @@ type PathFindingConfig struct {
// getOutgoingBalance returns the maximum available balance in any of the
// channels of the given node. The second return parameters is the total
// available balance.
func getOutgoingBalance(node route.Vertex, outgoingChan *uint64,
func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{},
bandwidthHints map[uint64]lnwire.MilliSatoshi,
g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) {
@ -344,8 +344,10 @@ func getOutgoingBalance(node route.Vertex, outgoingChan *uint64,
chanID := outEdge.ChannelID
// Enforce outgoing channel restriction.
if outgoingChan != nil && chanID != *outgoingChan {
return nil
if outgoingChans != nil {
if _, ok := outgoingChans[chanID]; !ok {
return nil
}
}
bandwidth, ok := bandwidthHints[chanID]
@ -447,13 +449,22 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
return nil, errNoPaymentAddr
}
// Set up outgoing channel map for quicker access.
var outgoingChanMap map[uint64]struct{}
if len(r.OutgoingChannelIDs) > 0 {
outgoingChanMap = make(map[uint64]struct{})
for _, outChan := range r.OutgoingChannelIDs {
outgoingChanMap[outChan] = struct{}{}
}
}
// If we are routing from ourselves, check that we have enough local
// balance available.
self := g.graph.sourceNode()
if source == self {
max, total, err := getOutgoingBalance(
self, r.OutgoingChannelID, g.bandwidthHints, g.graph,
self, outgoingChanMap, g.bandwidthHints, g.graph,
)
if err != nil {
return nil, err
@ -763,7 +774,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig,
pivot := partialPath.node
// Create unified policies for all incoming connections.
u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID)
u := newUnifiedPolicies(self, pivot, outgoingChanMap)
err := u.addGraphPolicies(g.graph)
if err != nil {