From a8ed1b342aac6f2393277e93af1634b5bd79d2d9 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 27 Jan 2020 13:19:08 +0100 Subject: [PATCH 1/3] routing: remove pathfinding db tx Pathfinding is never used with an externally supplied bbolt transaction. --- routing/pathfind.go | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/routing/pathfind.go b/routing/pathfind.go index 23bc1f9e9..77686337f 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -271,10 +271,6 @@ func edgeWeight(lockedAmt lnwire.MilliSatoshi, fee lnwire.MilliSatoshi, // graphParams wraps the set of graph parameters passed to findPath. type graphParams struct { - // tx can be set to an existing db transaction. If not set, a new - // transaction will be started. - tx *bbolt.Tx - // graph is the ChannelGraph to be used during path finding. graph *channeldb.ChannelGraph @@ -425,14 +421,12 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } self := selfNode.PubKeyBytes - tx := g.tx - if tx == nil { - tx, err = g.graph.Database().Begin(false) - if err != nil { - return nil, err - } - defer tx.Rollback() + // Get a db transaction to execute the graph queries in. + tx, err := g.graph.Database().Begin(false) + if err != nil { + return nil, err } + defer tx.Rollback() // If no destination features are provided, we will load what features // we have for the target node from our graph. From 06bdeb56e23b6c838efcf7a3eb862d16586f554f Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 27 Jan 2020 12:33:53 +0100 Subject: [PATCH 2/3] routing: add graph interface --- routing/graph.go | 104 +++++++++++++++++++++- routing/pathfind.go | 167 +++++++++++++++++------------------- routing/router.go | 26 ++++-- routing/unified_policies.go | 9 +- 4 files changed, 206 insertions(+), 100 deletions(-) diff --git a/routing/graph.go b/routing/graph.go index f3dfa121d..14eca1786 100644 --- a/routing/graph.go +++ b/routing/graph.go @@ -1,4 +1,104 @@ package routing -// TODO(roasbeef): abstract out graph to interface -// * add in-memory version of graph for tests +import ( + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// routingGraph is an abstract interface that provides information about nodes +// and edges to pathfinding. +type routingGraph interface { + // forEachNodeChannel calls the callback for every channel of the given node. + forEachNodeChannel(nodePub route.Vertex, + cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) error) error + + // sourceNode returns the source node of the graph. + sourceNode() route.Vertex + + // fetchNodeFeatures returns the features of the given node. + fetchNodeFeatures(nodePub route.Vertex) (*lnwire.FeatureVector, error) +} + +// dbRoutingTx is a routingGraph implementation that retrieves from the +// database. +type dbRoutingTx struct { + graph *channeldb.ChannelGraph + tx *bbolt.Tx + source route.Vertex +} + +// newDbRoutingTx instantiates a new db-connected routing graph. It implictly +// instantiates a new read transaction. +func newDbRoutingTx(graph *channeldb.ChannelGraph) (*dbRoutingTx, error) { + sourceNode, err := graph.SourceNode() + if err != nil { + return nil, err + } + + tx, err := graph.Database().Begin(false) + if err != nil { + return nil, err + } + + return &dbRoutingTx{ + graph: graph, + tx: tx, + source: sourceNode.PubKeyBytes, + }, nil +} + +// close closes the underlying db transaction. +func (g *dbRoutingTx) close() error { + return g.tx.Rollback() +} + +// forEachNodeChannel calls the callback for every channel of the given node. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) forEachNodeChannel(nodePub route.Vertex, + cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) error) error { + + txCb := func(_ *bbolt.Tx, info *channeldb.ChannelEdgeInfo, + p1, p2 *channeldb.ChannelEdgePolicy) error { + + return cb(info, p1, p2) + } + + return g.graph.ForEachNodeChannel(g.tx, nodePub[:], txCb) +} + +// sourceNode returns the source node of the graph. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) sourceNode() route.Vertex { + return g.source +} + +// fetchNodeFeatures returns the features of the given node. If the node is +// unknown, assume no additional features are supported. +// +// NOTE: Part of the routingGraph interface. +func (g *dbRoutingTx) fetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + targetNode, err := g.graph.FetchLightningNode(g.tx, nodePub) + switch err { + + // If the node exists and has features, return them directly. + case nil: + return targetNode.Features, nil + + // If we couldn't find a node announcement, populate a blank feature + // vector. + case channeldb.ErrGraphNodeNotFound: + return lnwire.EmptyFeatureVector(), nil + + // Otherwise bubble the error up. + default: + return nil, err + } +} diff --git a/routing/pathfind.go b/routing/pathfind.go index 77686337f..f2ea782d9 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -7,7 +7,6 @@ import ( "math" "time" - "github.com/coreos/bbolt" sphinx "github.com/lightningnetwork/lightning-onion" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/feature" @@ -346,10 +345,11 @@ type PathFindingConfig struct { // getMaxOutgoingAmt returns the maximum available balance in any of the // channels of the given node. func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, - g *graphParams, tx *bbolt.Tx) (lnwire.MilliSatoshi, error) { + bandwidthHints map[uint64]lnwire.MilliSatoshi, + g routingGraph) (lnwire.MilliSatoshi, error) { var max lnwire.MilliSatoshi - cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, outEdge, + cb := func(edgeInfo *channeldb.ChannelEdgeInfo, outEdge, _ *channeldb.ChannelEdgePolicy) error { if outEdge == nil { @@ -363,7 +363,7 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, return nil } - bandwidth, ok := g.bandwidthHints[chanID] + bandwidth, ok := bandwidthHints[chanID] // If the bandwidth is not available for whatever reason, don't // fail the pathfinding early. @@ -380,28 +380,54 @@ func getMaxOutgoingAmt(node route.Vertex, outgoingChan *uint64, } // Iterate over all channels of the to node. - err := g.graph.ForEachNodeChannel(tx, node[:], cb) + err := g.forEachNodeChannel(node, cb) if err != nil { return 0, err } return max, err } -// findPath attempts to find a path from the source node within the -// ChannelGraph to the target node that's capable of supporting a payment of -// `amt` value. The current approach implemented is modified version of -// Dijkstra's algorithm to find a single shortest path between the source node -// and the destination. The distance metric used for edges is related to the -// time-lock+fee costs along a particular edge. If a path is found, this -// function returns a slice of ChannelHop structs which encoded the chosen path -// from the target to the source. The search is performed backwards from -// destination node back to source. This is to properly accumulate fees -// that need to be paid along the path and accurately check the amount -// to forward at every node against the available bandwidth. +// findPath attempts to find a path from the source node within the ChannelGraph +// to the target node that's capable of supporting a payment of `amt` value. The +// current approach implemented is modified version of Dijkstra's algorithm to +// find a single shortest path between the source node and the destination. The +// distance metric used for edges is related to the time-lock+fee costs along a +// particular edge. If a path is found, this function returns a slice of +// ChannelHop structs which encoded the chosen path from the target to the +// source. The search is performed backwards from destination node back to +// source. This is to properly accumulate fees that need to be paid along the +// path and accurately check the amount to forward at every node against the +// available bandwidth. func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( []*channeldb.ChannelEdgePolicy, error) { + routingTx, err := newDbRoutingTx(g.graph) + if err != nil { + return nil, err + } + defer func() { + err := routingTx.close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }() + + return findPathInternal( + g.additionalEdges, g.bandwidthHints, routingTx, r, cfg, source, + target, amt, finalHtlcExpiry, + ) +} + +// findPathInternal is the internal implementation of findPath. +func findPathInternal( + additionalEdges map[route.Vertex][]*channeldb.ChannelEdgePolicy, + bandwidthHints map[uint64]lnwire.MilliSatoshi, + graph routingGraph, + r *RestrictParams, cfg *PathFindingConfig, + source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ( + []*channeldb.ChannelEdgePolicy, error) { + // Pathfinding can be a significant portion of the total payment // latency, especially on low-powered devices. Log several metrics to // aid in the analysis performance problems in this area. @@ -414,45 +440,20 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, "time=%v", nodesVisited, edgesExpanded, timeElapsed) }() - // Get source node outside of the pathfinding tx, to prevent a deadlock. - selfNode, err := g.graph.SourceNode() - if err != nil { - return nil, err - } - self := selfNode.PubKeyBytes - - // Get a db transaction to execute the graph queries in. - tx, err := g.graph.Database().Begin(false) - if err != nil { - return nil, err - } - defer tx.Rollback() - // If no destination features are provided, we will load what features // we have for the target node from our graph. features := r.DestFeatures if features == nil { - targetNode, err := g.graph.FetchLightningNode(tx, target) - switch { - - // If the node exists and has features, use them directly. - case err == nil: - features = targetNode.Features - - // If an error other than the node not existing is hit, abort. - case err != channeldb.ErrGraphNodeNotFound: + var err error + features, err = graph.fetchNodeFeatures(target) + if err != nil { return nil, err - - // Otherwise, we couldn't find a node announcement, populate a - // blank feature vector. - default: - features = lnwire.EmptyFeatureVector() } } // Ensure that the destination's features don't include unknown // required features. - err = feature.ValidateRequired(features) + err := feature.ValidateRequired(features) if err != nil { return nil, err } @@ -485,8 +486,12 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // If we are routing from ourselves, check that we have enough local // balance available. + self := graph.sourceNode() + if source == self { - max, err := getMaxOutgoingAmt(self, r.OutgoingChannelID, g, tx) + max, err := getMaxOutgoingAmt( + self, r.OutgoingChannelID, bandwidthHints, graph, + ) if err != nil { return nil, err } @@ -504,7 +509,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, distance := make(map[route.Vertex]*nodeWithDist, estimatedNodeCount) additionalEdgesWithSrc := make(map[route.Vertex][]*edgePolicyWithSource) - for vertex, outgoingEdgePolicies := range g.additionalEdges { + for vertex, outgoingEdgePolicies := range additionalEdges { // Build reverse lookup to find incoming edges. Needed because // search is taken place from target to source. for _, outgoingEdgePolicy := range outgoingEdgePolicies { @@ -746,45 +751,35 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Check cache for features of the fromNode. fromFeatures, ok := featureCache[node] - if !ok { - targetNode, err := g.graph.FetchLightningNode(tx, node) - switch { - - // If the node exists and has valid features, use them. - case err == nil: - nodeFeatures := targetNode.Features - - // Don't route through nodes that contain - // unknown required features. - err = feature.ValidateRequired(nodeFeatures) - if err != nil { - break - } - - // Don't route through nodes that don't properly - // set all transitive feature dependencies. - err = feature.ValidateDeps(nodeFeatures) - if err != nil { - break - } - - fromFeatures = nodeFeatures - - // If an error other than the node not existing is hit, - // abort. - case err != channeldb.ErrGraphNodeNotFound: - return nil, err - - // Otherwise, we couldn't find a node announcement, - // populate a blank feature vector. - default: - fromFeatures = lnwire.EmptyFeatureVector() - } - - // Update cache. - featureCache[node] = fromFeatures + if ok { + return fromFeatures, nil } + // Fetch node features fresh from the graph. + fromFeatures, err := graph.fetchNodeFeatures(node) + if err != nil { + return nil, err + } + + // Don't route through nodes that contain unknown required + // features and mark as nil in the cache. + err = feature.ValidateRequired(fromFeatures) + if err != nil { + featureCache[node] = nil + return nil, nil + } + + // Don't route through nodes that don't properly set all + // transitive feature dependencies and mark as nil in the cache. + err = feature.ValidateDeps(fromFeatures) + if err != nil { + featureCache[node] = nil + return nil, nil + } + + // Update cache. + featureCache[node] = fromFeatures + return fromFeatures, nil } @@ -797,7 +792,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, // Create unified policies for all incoming connections. u := newUnifiedPolicies(self, pivot, r.OutgoingChannelID) - err := u.addGraphPolicies(g.graph, tx) + err := u.addGraphPolicies(graph) if err != nil { return nil, err } @@ -828,7 +823,7 @@ func findPath(g *graphParams, r *RestrictParams, cfg *PathFindingConfig, } policy := unifiedPolicy.getPolicy( - amtToSend, g.bandwidthHints, + amtToSend, bandwidthHints, ) if policy == nil { diff --git a/routing/router.go b/routing/router.go index a7c24ba4e..e42ce7a0e 100644 --- a/routing/router.go +++ b/routing/router.go @@ -2311,6 +2311,13 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, return nil, err } + // Fetch the current block height outside the routing transaction, to + // prevent the rpc call blocking the database. + _, height, err := r.cfg.Chain.GetBestBlock() + if err != nil { + return nil, err + } + // Allocate a list that will contain the unified policies for this // route. edges := make([]*unifiedPolicy, len(hops)) @@ -2328,6 +2335,18 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, runningAmt = *amt } + // Open a transaction to execute the graph queries in. + routingTx, err := newDbRoutingTx(r.cfg.Graph) + if err != nil { + return nil, err + } + defer func() { + err := routingTx.close() + if err != nil { + log.Errorf("Error closing db tx: %v", err) + } + }() + // Traverse hops backwards to accumulate fees in the running amounts. source := r.selfNode.PubKeyBytes for i := len(hops) - 1; i >= 0; i-- { @@ -2346,7 +2365,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // known in the graph. u := newUnifiedPolicies(source, toNode, outgoingChan) - err := u.addGraphPolicies(r.cfg.Graph, nil) + err := u.addGraphPolicies(routingTx) if err != nil { return nil, err } @@ -2414,11 +2433,6 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, } // Build and return the final route. - _, height, err := r.cfg.Chain.GetBestBlock() - if err != nil { - return nil, err - } - return newRoute( source, pathEdges, uint32(height), finalHopParams{ diff --git a/routing/unified_policies.go b/routing/unified_policies.go index 81e646c29..3759175a6 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -2,7 +2,6 @@ package routing import ( "github.com/btcsuite/btcutil" - "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" @@ -69,10 +68,8 @@ func (u *unifiedPolicies) addPolicy(fromNode route.Vertex, // addGraphPolicies adds all policies that are known for the toNode in the // graph. -func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, - tx *bbolt.Tx) error { - - cb := func(_ *bbolt.Tx, edgeInfo *channeldb.ChannelEdgeInfo, _, +func (u *unifiedPolicies) addGraphPolicies(g routingGraph) error { + cb := func(edgeInfo *channeldb.ChannelEdgeInfo, _, inEdge *channeldb.ChannelEdgePolicy) error { // If there is no edge policy for this candidate node, skip. @@ -95,7 +92,7 @@ func (u *unifiedPolicies) addGraphPolicies(g *channeldb.ChannelGraph, } // Iterate over all channels of the to node. - return g.ForEachNodeChannel(tx, u.toNode[:], cb) + return g.forEachNodeChannel(u.toNode, cb) } // unifiedPolicyEdge is the individual channel data that is kept inside an From 29476ec6a3bdf3512c6c6f0c50eda37ac2825de8 Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Mon, 27 Jan 2020 15:40:33 +0100 Subject: [PATCH 3/3] routing/test: test probability extrapolation Adds an integrated routing test of probability extrapolation for untried channels. The larger part of this commit is mock code to simulate the Lightning Network. The difference between this test and the existing pathfinding tests, is that this test focuses on the feedback loop from result interpretation via mission control updates and probability estimation back to pathfinding. Improvements like probability extrapolation were previously only validated by reasoning, while this setup makes it possible to assert the improvement in a test and guard it for the future. --- routing/integrated_routing_context_test.go | 194 +++++++++++++++ routing/integrated_routing_test.go | 57 +++++ routing/mock_graph_test.go | 265 +++++++++++++++++++++ 3 files changed, 516 insertions(+) create mode 100644 routing/integrated_routing_context_test.go create mode 100644 routing/integrated_routing_test.go create mode 100644 routing/mock_graph_test.go diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go new file mode 100644 index 000000000..eb59473e6 --- /dev/null +++ b/routing/integrated_routing_context_test.go @@ -0,0 +1,194 @@ +package routing + +import ( + "io/ioutil" + "os" + "testing" + "time" + + "github.com/coreos/bbolt" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// integratedRoutingContext defines the context in which integrated routing +// tests run. +type integratedRoutingContext struct { + graph *mockGraph + t *testing.T + + source *mockNode + target *mockNode + + amt lnwire.MilliSatoshi + finalExpiry int32 + + mcCfg MissionControlConfig + pathFindingCfg PathFindingConfig +} + +// newIntegratedRoutingContext instantiates a new integrated routing test +// context with a source and a target node. +func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext { + // Instantiate a mock graph. + source := newMockNode() + target := newMockNode() + + graph := newMockGraph(t) + graph.addNode(source) + graph.addNode(target) + graph.source = source + + // Initiate the test context with a set of default configuration values. + // We don't use the lnd defaults here, because otherwise changing the + // defaults would break the unit tests. The actual values picked aren't + // critical to excite certain behavior, but do need to be aligned with + // the test case assertions. + ctx := integratedRoutingContext{ + t: t, + graph: graph, + amt: 100000, + finalExpiry: 40, + + mcCfg: MissionControlConfig{ + PenaltyHalfLife: 30 * time.Minute, + AprioriHopProbability: 0.6, + AprioriWeight: 0.5, + SelfNode: source.pubkey, + }, + + pathFindingCfg: PathFindingConfig{ + PaymentAttemptPenalty: 1000, + }, + + source: source, + target: target, + } + + return &ctx +} + +// testPayment launches a test payment and asserts that it is completed after +// the expected number of attempts. +func (c *integratedRoutingContext) testPayment(expectedNofAttempts int) { + var nextPid uint64 + + // Create temporary database for mission control. + file, err := ioutil.TempFile("", "*.db") + if err != nil { + c.t.Fatal(err) + } + + dbPath := file.Name() + defer os.Remove(dbPath) + + db, err := bbolt.Open(dbPath, 0600, nil) + if err != nil { + c.t.Fatal(err) + } + defer db.Close() + + // Instantiate a new mission control with the current configuration + // values. + mc, err := NewMissionControl(db, &c.mcCfg) + if err != nil { + c.t.Fatal(err) + } + + // Instruct pathfinding to use mission control as a probabiltiy source. + restrictParams := RestrictParams{ + ProbabilitySource: mc.GetProbability, + FeeLimit: lnwire.MaxMilliSatoshi, + } + + // Now the payment control loop starts. It will keep trying routes until + // the payment succeeds. + for { + // Create bandwidth hints based on local channel balances. + bandwidthHints := map[uint64]lnwire.MilliSatoshi{} + for _, ch := range c.graph.nodes[c.source.pubkey].channels { + bandwidthHints[ch.id] = ch.balance + } + + // Find a route. + path, err := findPathInternal( + nil, bandwidthHints, c.graph, + &restrictParams, + &c.pathFindingCfg, + c.source.pubkey, c.target.pubkey, + c.amt, c.finalExpiry, + ) + if err != nil { + c.t.Fatal(err) + } + + finalHop := finalHopParams{ + amt: c.amt, + cltvDelta: uint16(c.finalExpiry), + } + + route, err := newRoute(c.source.pubkey, path, 0, finalHop) + if err != nil { + c.t.Fatal(err) + } + + // Send out the htlc on the mock graph. + pid := nextPid + nextPid++ + htlcResult, err := c.graph.sendHtlc(route) + if err != nil { + c.t.Fatal(err) + } + + // Process the result. + if htlcResult.failure == nil { + err := mc.ReportPaymentSuccess(pid, route) + if err != nil { + c.t.Fatal(err) + } + + // If the payment is successful, the control loop can be + // broken out of. + break + } + + // Failure, update mission control and retry. + c.t.Logf("fail: %v @ %v\n", htlcResult.failure, htlcResult.failureSource) + + finalResult, err := mc.ReportPaymentFail( + pid, route, + getNodeIndex(route, htlcResult.failureSource), + htlcResult.failure, + ) + if err != nil { + c.t.Fatal(err) + } + + if finalResult != nil { + c.t.Logf("final result: %v\n", finalResult) + break + } + } + + c.t.Logf("Payment attempts: %v\n", nextPid) + if expectedNofAttempts != int(nextPid) { + c.t.Fatalf("expected %v attempts, but needed %v", + expectedNofAttempts, nextPid) + } +} + +// getNodeIndex returns the zero-based index of the given node in the route. +func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { + if failureSource == route.SourcePubKey { + idx := 0 + return &idx + } + + for i, h := range route.Hops { + if h.PubKeyBytes == failureSource { + idx := i + 1 + return &idx + } + } + return nil +} diff --git a/routing/integrated_routing_test.go b/routing/integrated_routing_test.go new file mode 100644 index 000000000..19df17a35 --- /dev/null +++ b/routing/integrated_routing_test.go @@ -0,0 +1,57 @@ +package routing + +import ( + "testing" +) + +// TestProbabilityExtrapolation tests that probabilities for tried channels are +// extrapolated to untried channels. This is a way to improve pathfinding +// success by steering away from bad nodes. +func TestProbabilityExtrapolation(t *testing.T) { + ctx := newIntegratedRoutingContext(t) + + // Create the following network of nodes: + // source -> expensiveNode (charges routing fee) -> target + // source -> intermediate1 (free routing) -> intermediate(1-10) (free routing) -> target + g := ctx.graph + + expensiveNode := newMockNode() + expensiveNode.baseFee = 10000 + g.addNode(expensiveNode) + + g.addChannel(ctx.source, expensiveNode, 100000) + g.addChannel(ctx.target, expensiveNode, 100000) + + intermediate1 := newMockNode() + g.addNode(intermediate1) + g.addChannel(ctx.source, intermediate1, 100000) + + for i := 0; i < 10; i++ { + imNode := newMockNode() + g.addNode(imNode) + g.addChannel(imNode, ctx.target, 100000) + g.addChannel(imNode, intermediate1, 100000) + + // The channels from intermediate1 all have insufficient balance. + g.nodes[intermediate1.pubkey].channels[imNode.pubkey].balance = 0 + } + + // It is expected that pathfinding will try to explore the routes via + // intermediate1 first, because those are free. But as failures happen, + // the node probability of intermediate1 will go down in favor of the + // paid route via expensiveNode. + // + // The exact number of attempts required is dependent on mission control + // config. For this test, it would have been enough to only assert that + // we are not trying all routes via intermediate1. However, we do assert + // a specific number of attempts to safe-guard against accidental + // modifications anywhere in the chain of components that is involved in + // this test. + ctx.testPayment(5) + + // If we use a static value for the node probability (no extrapolation + // of data from other channels), all ten bad channels will be tried + // first before switching to the paid channel. + ctx.mcCfg.AprioriWeight = 1 + ctx.testPayment(11) +} diff --git a/routing/mock_graph_test.go b/routing/mock_graph_test.go new file mode 100644 index 000000000..075a416a1 --- /dev/null +++ b/routing/mock_graph_test.go @@ -0,0 +1,265 @@ +package routing + +import ( + "bytes" + "fmt" + "testing" + + "github.com/btcsuite/btcutil" + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// nextTestPubkey is global variable that is used to deterministically generate +// test keys. +var nextTestPubkey byte + +// createPubkey return a new test pubkey. +func createPubkey() route.Vertex { + pubkey := route.Vertex{nextTestPubkey} + nextTestPubkey++ + return pubkey +} + +// mockChannel holds the channel state of a channel in the mock graph. +type mockChannel struct { + id uint64 + capacity btcutil.Amount + balance lnwire.MilliSatoshi +} + +// mockNode holds a set of mock channels and routing policies for a node in the +// mock graph. +type mockNode struct { + channels map[route.Vertex]*mockChannel + baseFee lnwire.MilliSatoshi + pubkey route.Vertex +} + +// newMockNode instantiates a new mock node with a newly generated pubkey. +func newMockNode() *mockNode { + pubkey := createPubkey() + return &mockNode{ + channels: make(map[route.Vertex]*mockChannel), + pubkey: pubkey, + } +} + +// fwd simulates an htlc forward through this node. If the from parameter is +// nil, this node is considered to be the sender of the payment. The route +// parameter describes the remaining route from this node onwards. If route.next +// is nil, this node is the final hop. +func (m *mockNode) fwd(from *mockNode, route *hop) (htlcResult, error) { + next := route.next + + // Get the incoming channel, if any. + var inChan *mockChannel + if from != nil { + inChan = m.channels[from.pubkey] + } + + // If there is no next node, this is the final node and we can settle the htlc. + if next == nil { + // Update the incoming balance. + inChan.balance += route.amtToFwd + + return htlcResult{}, nil + } + + // Check if the outgoing channel has enough balance. + outChan, ok := m.channels[next.node.pubkey] + if !ok { + return htlcResult{}, + fmt.Errorf("%v: unknown next %v", + m.pubkey, next.node.pubkey) + } + if outChan.balance < route.amtToFwd { + return htlcResult{ + failureSource: m.pubkey, + failure: lnwire.NewTemporaryChannelFailure(nil), + }, nil + } + + // Htlc can be forwarded, update channel balances. + outChan.balance -= route.amtToFwd + if inChan != nil { + inChan.balance += route.amtToFwd + } + + // Recursively forward down the given route. + result, err := next.node.fwd(m, route.next) + if err != nil { + return htlcResult{}, err + } + + // Revert balances when a failure occurs. + if result.failure != nil { + outChan.balance += route.amtToFwd + if inChan != nil { + inChan.balance -= route.amtToFwd + } + } + + return result, nil +} + +// mockGraph contains a set of nodes that together for a mocked graph. +type mockGraph struct { + t *testing.T + nodes map[route.Vertex]*mockNode + nextChanID uint64 + source *mockNode +} + +// newMockGraph instantiates a new mock graph. +func newMockGraph(t *testing.T) *mockGraph { + return &mockGraph{ + nodes: make(map[route.Vertex]*mockNode), + t: t, + } +} + +// addNode adds the given mock node to the network. +func (m *mockGraph) addNode(node *mockNode) { + m.nodes[node.pubkey] = node +} + +// addChannel adds a new channel between two existing nodes on the network. It +// sets the channel balance to 50/50%. +// +// Ignore linter error because addChannel isn't yet called with different +// capacities. +// nolint:unparam +func (m *mockGraph) addChannel(node1, node2 *mockNode, capacity btcutil.Amount) { + id := m.nextChanID + m.nextChanID++ + + m.nodes[node1.pubkey].channels[node2.pubkey] = &mockChannel{ + capacity: capacity, + id: id, + balance: lnwire.NewMSatFromSatoshis(capacity / 2), + } + m.nodes[node2.pubkey].channels[node1.pubkey] = &mockChannel{ + capacity: capacity, + id: id, + balance: lnwire.NewMSatFromSatoshis(capacity / 2), + } +} + +// forEachNodeChannel calls the callback for every channel of the given node. +// +// NOTE: Part of the routingGraph interface. +func (m *mockGraph) forEachNodeChannel(nodePub route.Vertex, + cb func(*channeldb.ChannelEdgeInfo, *channeldb.ChannelEdgePolicy, + *channeldb.ChannelEdgePolicy) error) error { + + // Look up the mock node. + node, ok := m.nodes[nodePub] + if !ok { + return channeldb.ErrGraphNodeNotFound + } + + // Iterate over all of its channels. + for peer, channel := range node.channels { + // Lexicographically sort the pubkeys. + var node1, node2 route.Vertex + if bytes.Compare(nodePub[:], peer[:]) == -1 { + node1, node2 = peer, nodePub + } else { + node1, node2 = nodePub, peer + } + + peerNode := m.nodes[peer] + + // Call the per channel callback. + err := cb( + &channeldb.ChannelEdgeInfo{ + NodeKey1Bytes: node1, + NodeKey2Bytes: node2, + }, + &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: peer, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: node.baseFee, + }, + &channeldb.ChannelEdgePolicy{ + ChannelID: channel.id, + Node: &channeldb.LightningNode{ + PubKeyBytes: nodePub, + Features: lnwire.EmptyFeatureVector(), + }, + FeeBaseMSat: peerNode.baseFee, + }, + ) + if err != nil { + return err + } + } + return nil +} + +// sourceNode returns the source node of the graph. +// +// NOTE: Part of the routingGraph interface. +func (m *mockGraph) sourceNode() route.Vertex { + return m.source.pubkey +} + +// fetchNodeFeatures returns the features of the given node. +// +// NOTE: Part of the routingGraph interface. +func (m *mockGraph) fetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + return lnwire.EmptyFeatureVector(), nil +} + +// htlcResult describes the resolution of an htlc. If failure is nil, the htlc +// was settled. +type htlcResult struct { + failureSource route.Vertex + failure lnwire.FailureMessage +} + +// hop describes one hop of a route. +type hop struct { + node *mockNode + amtToFwd lnwire.MilliSatoshi + next *hop +} + +// sendHtlc sends out an htlc on the mock network and synchronously returns the +// final resolution of the htlc. +func (m *mockGraph) sendHtlc(route *route.Route) (htlcResult, error) { + var next *hop + + // Convert the route into a structure that is suitable for recursive + // processing. + for i := len(route.Hops) - 1; i >= 0; i-- { + routeHop := route.Hops[i] + node := m.nodes[routeHop.PubKeyBytes] + next = &hop{ + node: node, + next: next, + amtToFwd: routeHop.AmtToForward, + } + } + + // Create the starting hop instance. + source := m.nodes[route.SourcePubKey] + next = &hop{ + node: source, + next: next, + amtToFwd: route.TotalAmount, + } + + // Recursively walk the path and obtain the htlc resolution. + return source.fwd(nil, next) +} + +// Compile-time check for the routingGraph interface. +var _ routingGraph = &mockGraph{}