package routing import ( "io/ioutil" "os" "testing" "time" "github.com/lightningnetwork/lnd/channeldb/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/routing/route" ) const ( sourceNodeID = 1 targetNodeID = 2 ) // integratedRoutingContext defines the context in which integrated routing // tests run. type integratedRoutingContext struct { graph *mockGraph t *testing.T source *mockNode target *mockNode amt lnwire.MilliSatoshi finalExpiry int32 mcCfg MissionControlConfig pathFindingCfg PathFindingConfig } // newIntegratedRoutingContext instantiates a new integrated routing test // context with a source and a target node. func newIntegratedRoutingContext(t *testing.T) *integratedRoutingContext { // Instantiate a mock graph. source := newMockNode(sourceNodeID) target := newMockNode(targetNodeID) graph := newMockGraph(t) graph.addNode(source) graph.addNode(target) graph.source = source // Initiate the test context with a set of default configuration values. // We don't use the lnd defaults here, because otherwise changing the // defaults would break the unit tests. The actual values picked aren't // critical to excite certain behavior, but do need to be aligned with // the test case assertions. ctx := integratedRoutingContext{ t: t, graph: graph, amt: 100000, finalExpiry: 40, mcCfg: MissionControlConfig{ PenaltyHalfLife: 30 * time.Minute, AprioriHopProbability: 0.6, AprioriWeight: 0.5, SelfNode: source.pubkey, }, pathFindingCfg: PathFindingConfig{ PaymentAttemptPenalty: 1000, }, source: source, target: target, } return &ctx } // htlcAttempt records the route and outcome of an attempted htlc. type htlcAttempt struct { route *route.Route success bool } // testPayment launches a test payment and asserts that it is completed after // the expected number of attempts. func (c *integratedRoutingContext) testPayment() []htlcAttempt { var ( nextPid uint64 attempts []htlcAttempt ) // Create temporary database for mission control. file, err := ioutil.TempFile("", "*.db") if err != nil { c.t.Fatal(err) } dbPath := file.Name() defer os.Remove(dbPath) db, err := kvdb.Open(kvdb.BoltBackendName, dbPath, true) if err != nil { c.t.Fatal(err) } defer db.Close() // Instantiate a new mission control with the current configuration // values. mc, err := NewMissionControl(db, &c.mcCfg) if err != nil { c.t.Fatal(err) } getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { bandwidthHints[ch.id] = ch.balance } return bandwidthHints, nil } payment := LightningPayment{ FinalCLTVDelta: uint16(c.finalExpiry), FeeLimit: lnwire.MaxMilliSatoshi, Target: c.target.pubkey, } session := &paymentSession{ getBandwidthHints: getBandwidthHints, payment: &payment, pathFinder: findPath, getRoutingGraph: func() (routingGraph, func(), error) { return c.graph, func() {}, nil }, pathFindingConfig: c.pathFindingCfg, missionControl: mc, } // Now the payment control loop starts. It will keep trying routes until // the payment succeeds. for { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { bandwidthHints[ch.id] = ch.balance } // Find a route. route, err := session.RequestRoute( c.amt, lnwire.MaxMilliSatoshi, 0, 0, ) if err != nil { c.t.Fatal(err) } // Send out the htlc on the mock graph. pid := nextPid nextPid++ htlcResult, err := c.graph.sendHtlc(route) if err != nil { c.t.Fatal(err) } success := htlcResult.failure == nil attempts = append(attempts, htlcAttempt{ route: route, success: success, }) // Process the result. if success { err := mc.ReportPaymentSuccess(pid, route) if err != nil { c.t.Fatal(err) } // If the payment is successful, the control loop can be // broken out of. break } // Failure, update mission control and retry. c.t.Logf("fail: %v @ %v\n", htlcResult.failure, htlcResult.failureSource) finalResult, err := mc.ReportPaymentFail( pid, route, getNodeIndex(route, htlcResult.failureSource), htlcResult.failure, ) if err != nil { c.t.Fatal(err) } if finalResult != nil { c.t.Logf("final result: %v\n", finalResult) break } } c.t.Logf("Payment attempts: %v\n", len(attempts)) return attempts } // getNodeIndex returns the zero-based index of the given node in the route. func getNodeIndex(route *route.Route, failureSource route.Vertex) *int { if failureSource == route.SourcePubKey { idx := 0 return &idx } for i, h := range route.Hops { if h.PubKeyBytes == failureSource { idx := i + 1 return &idx } } return nil }