From 45de686d35a231a63839d53ca06446797f6400b8 Mon Sep 17 00:00:00 2001 From: carla Date: Tue, 19 Oct 2021 09:37:44 +0200 Subject: [PATCH] multi: move bandwidth hints behind interface --- routing/bandwidth.go | 107 +++++++++++++++++ routing/bandwidth_test.go | 128 +++++++++++++++++++++ routing/integrated_routing_context_test.go | 21 +++- routing/mock_test.go | 22 ++++ routing/pathfind.go | 20 ++-- routing/pathfind_test.go | 72 +++++++++--- routing/payment_lifecycle_test.go | 3 - routing/payment_session.go | 4 +- routing/payment_session_source.go | 12 +- routing/payment_session_test.go | 12 +- routing/router.go | 45 +------- routing/router_test.go | 16 +-- routing/unified_policies.go | 8 +- routing/unified_policies_test.go | 2 +- server.go | 31 +---- 15 files changed, 369 insertions(+), 134 deletions(-) create mode 100644 routing/bandwidth.go create mode 100644 routing/bandwidth_test.go diff --git a/routing/bandwidth.go b/routing/bandwidth.go new file mode 100644 index 000000000..4aad0f080 --- /dev/null +++ b/routing/bandwidth.go @@ -0,0 +1,107 @@ +package routing + +import ( + "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" +) + +// bandwidthHints provides hints about the currently available balance in our +// channels. +type bandwidthHints interface { + // availableChanBandwidth returns the total available bandwidth for a + // channel and a bool indicating whether the channel hint was found. + // If the channel is unavailable, a zero amount is returned. + availableChanBandwidth(channelID uint64) (lnwire.MilliSatoshi, bool) +} + +// getLinkQuery is the function signature used to lookup a link. +type getLinkQuery func(lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) + +// bandwidthManager is an implementation of the bandwidthHints interface which +// uses the link lookup provided to query the link for our latest local channel +// balances. +type bandwidthManager struct { + getLink getLinkQuery + localChans map[lnwire.ShortChannelID]struct{} +} + +// newBandwidthManager creates a bandwidth manager for the source node provided +// which is used to obtain hints from the lower layer w.r.t the available +// bandwidth of edges on the network. Currently, we'll only obtain bandwidth +// hints for the edges we directly have open ourselves. Obtaining these hints +// allows us to reduce the number of extraneous attempts as we can skip channels +// that are inactive, or just don't have enough bandwidth to carry the payment. +func newBandwidthManager(graph routingGraph, sourceNode route.Vertex, + linkQuery getLinkQuery) (*bandwidthManager, error) { + + manager := &bandwidthManager{ + getLink: linkQuery, + localChans: make(map[lnwire.ShortChannelID]struct{}), + } + + // First, we'll collect the set of outbound edges from the target + // source node and add them to our bandwidth manager's map of channels. + err := graph.forEachNodeChannel(sourceNode, + func(channel *channeldb.DirectedChannel) error { + shortID := lnwire.NewShortChanIDFromInt( + channel.ChannelID, + ) + manager.localChans[shortID] = struct{}{} + + return nil + }) + + if err != nil { + return nil, err + } + + return manager, nil +} + +// getBandwidth queries the current state of a link and gets its currently +// available bandwidth. Note that this function assumes that the channel being +// queried is one of our local channels, so any failure to retrieve the link +// is interpreted as the link being offline. +func (b *bandwidthManager) getBandwidth(cid lnwire.ShortChannelID) lnwire.MilliSatoshi { + link, err := b.getLink(cid) + if err != nil { + // If the link isn't online, then we'll report that it has + // zero bandwidth. + return 0 + } + + // If the link is found within the switch, but it isn't yet eligible + // to forward any HTLCs, then we'll treat it as if it isn't online in + // the first place. + if !link.EligibleToForward() { + return 0 + } + + // If our link isn't currently in a state where it can add another + // outgoing htlc, treat the link as unusable. + if err := link.MayAddOutgoingHtlc(); err != nil { + return 0 + } + + // Otherwise, we'll return the current best estimate for the available + // bandwidth for the link. + return link.Bandwidth() +} + +// availableChanBandwidth returns the total available bandwidth for a channel +// and a bool indicating whether the channel hint was found. If the channel is +// unavailable, a zero amount is returned. +func (b *bandwidthManager) availableChanBandwidth(channelID uint64) ( + lnwire.MilliSatoshi, bool) { + + shortID := lnwire.NewShortChanIDFromInt(channelID) + _, ok := b.localChans[shortID] + if !ok { + return 0, false + } + + return b.getBandwidth(shortID), true +} diff --git a/routing/bandwidth_test.go b/routing/bandwidth_test.go new file mode 100644 index 000000000..b362f2985 --- /dev/null +++ b/routing/bandwidth_test.go @@ -0,0 +1,128 @@ +package routing + +import ( + "testing" + + "github.com/btcsuite/btcutil" + "github.com/go-errors/errors" + "github.com/lightningnetwork/lnd/htlcswitch" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/stretchr/testify/require" +) + +// TestBandwidthManager tests getting of bandwidth hints from a bandwidth +// manager. +func TestBandwidthManager(t *testing.T) { + var ( + chan1ID uint64 = 101 + chan2ID uint64 = 102 + chanCapacity btcutil.Amount = 100000 + ) + + testCases := []struct { + name string + channelID uint64 + linkQuery getLinkQuery + expectedBandwidth lnwire.MilliSatoshi + expectFound bool + }{ + { + name: "channel not ours", + channelID: chan2ID, + // Set a link query that will fail our test since we + // don't expect to query the switch for a channel that + // is not ours. + linkQuery: func(id lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + require.Fail(t, "link query unexpected for: "+ + "%v", id) + + return nil, nil + }, + expectedBandwidth: 0, + expectFound: false, + }, + { + name: "channel ours, link not online", + channelID: chan1ID, + linkQuery: func(lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + return nil, htlcswitch.ErrChannelLinkNotFound + }, + expectedBandwidth: 0, + expectFound: true, + }, + { + name: "channel ours, link not eligible", + channelID: chan1ID, + linkQuery: func(lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + return &mockLink{ + ineligible: true, + }, nil + }, + expectedBandwidth: 0, + expectFound: true, + }, + { + name: "channel ours, link can't add htlc", + channelID: chan1ID, + linkQuery: func(lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + return &mockLink{ + mayAddOutgoingErr: errors.New( + "can't add htlc", + ), + }, nil + }, + expectedBandwidth: 0, + expectFound: true, + }, + { + name: "channel ours, bandwidth available", + channelID: chan1ID, + linkQuery: func(lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + return &mockLink{ + bandwidth: 321, + }, nil + }, + expectedBandwidth: 321, + expectFound: true, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + g := newMockGraph(t) + + sourceNode := newMockNode(sourceNodeID) + targetNode := newMockNode(targetNodeID) + + g.addNode(sourceNode) + g.addNode(targetNode) + g.addChannel( + chan1ID, sourceNodeID, targetNodeID, + chanCapacity, + ) + + m, err := newBandwidthManager( + g, sourceNode.pubkey, testCase.linkQuery, + ) + require.NoError(t, err) + + bandwidth, found := m.availableChanBandwidth( + testCase.channelID, + ) + require.Equal(t, testCase.expectedBandwidth, bandwidth) + require.Equal(t, testCase.expectFound, found) + }) + } +} diff --git a/routing/integrated_routing_context_test.go b/routing/integrated_routing_context_test.go index d13b1c432..008470b7b 100644 --- a/routing/integrated_routing_context_test.go +++ b/routing/integrated_routing_context_test.go @@ -18,6 +18,21 @@ const ( targetNodeID = 2 ) +type mockBandwidthHints struct { + hints map[uint64]lnwire.MilliSatoshi +} + +func (m *mockBandwidthHints) availableChanBandwidth(channelID uint64) ( + lnwire.MilliSatoshi, bool) { + + if m.hints == nil { + return 0, false + } + + balance, ok := m.hints[channelID] + return balance, ok +} + // integratedRoutingContext defines the context in which integrated routing // tests run. type integratedRoutingContext struct { @@ -130,14 +145,16 @@ func (c *integratedRoutingContext) testPayment(maxParts uint32, c.t.Fatal(err) } - getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { + getBandwidthHints := func() (bandwidthHints, error) { // Create bandwidth hints based on local channel balances. bandwidthHints := map[uint64]lnwire.MilliSatoshi{} for _, ch := range c.graph.nodes[c.source.pubkey].channels { bandwidthHints[ch.id] = ch.balance } - return bandwidthHints, nil + return &mockBandwidthHints{ + hints: bandwidthHints, + }, nil } var paymentAddr [32]byte diff --git a/routing/mock_test.go b/routing/mock_test.go index a59ae2aa4..d5928f19f 100644 --- a/routing/mock_test.go +++ b/routing/mock_test.go @@ -741,3 +741,25 @@ func (m *mockControlTower) SubscribePayment(paymentHash lntypes.Hash) ( args := m.Called(paymentHash) return args.Get(0).(*ControlTowerSubscriber), args.Error(1) } + +type mockLink struct { + htlcswitch.ChannelLink + bandwidth lnwire.MilliSatoshi + mayAddOutgoingErr error + ineligible bool +} + +// Bandwidth returns the bandwidth the mock was configured with. +func (m *mockLink) Bandwidth() lnwire.MilliSatoshi { + return m.bandwidth +} + +// EligibleToForward returns the mock's configured eligibility. +func (m *mockLink) EligibleToForward() bool { + return !m.ineligible +} + +// MayAddOutgoingHtlc returns the error configured in our mock. +func (m *mockLink) MayAddOutgoingHtlc() error { + return m.mayAddOutgoingErr +} diff --git a/routing/pathfind.go b/routing/pathfind.go index 27a67ea7a..01aea32b6 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -282,14 +282,14 @@ type graphParams struct { // channel graph. additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy - // bandwidthHints is an optional map from channels to bandwidths that - // can be populated if the caller has a better estimate of the current - // channel bandwidth than what is found in the graph. If set, it will - // override the capacities and disabled flags found in the graph for - // local channels when doing path finding. In particular, it should be - // set to the current available sending bandwidth for active local - // channels, and 0 for inactive channels. - bandwidthHints map[uint64]lnwire.MilliSatoshi + // bandwidthHints is an interface that provides bandwidth hints that + // can provide a better estimate of the current channel bandwidth than + // what is found in the graph. It will override the capacities and + // disabled flags found in the graph for local channels when doing + // path finding if it has updated values for that channel. In + // particular, it should be set to the current available sending + // bandwidth for active local channels, and 0 for inactive channels. + bandwidthHints bandwidthHints } // RestrictParams wraps the set of restrictions passed to findPath that the @@ -355,7 +355,7 @@ type PathFindingConfig struct { // channels of the given node. The second return parameters is the total // available balance. func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, - bandwidthHints map[uint64]lnwire.MilliSatoshi, + bandwidthHints bandwidthHints, g routingGraph) (lnwire.MilliSatoshi, lnwire.MilliSatoshi, error) { var max, total lnwire.MilliSatoshi @@ -373,7 +373,7 @@ func getOutgoingBalance(node route.Vertex, outgoingChans map[uint64]struct{}, } } - bandwidth, ok := bandwidthHints[chanID] + bandwidth, ok := bandwidthHints.availableChanBandwidth(chanID) // If the bandwidth is not available, use the channel capacity. // This can happen when a channel is added to the graph after diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 426faa099..dcc3b0a19 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -23,6 +23,7 @@ import ( "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -216,6 +217,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) { aliasMap := make(map[string]route.Vertex) privKeyMap := make(map[string]*btcec.PrivateKey) channelIDs := make(map[route.Vertex]map[route.Vertex]uint64) + links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink) var source *channeldb.LightningNode // First we insert all the nodes within the graph as vertexes. @@ -358,6 +360,13 @@ func parseTestGraph(path string) (*testGraphInstance, error) { copy(edgeInfo.BitcoinKey1Bytes[:], node1Bytes) copy(edgeInfo.BitcoinKey2Bytes[:], node2Bytes) + shortID := lnwire.NewShortChanIDFromInt(edge.ChannelID) + links[shortID] = &mockLink{ + bandwidth: lnwire.MilliSatoshi( + edgeInfo.Capacity * 1000, + ), + } + err = graph.AddChannelEdge(&edgeInfo) if err != nil && err != channeldb.ErrEdgeAlreadyExist { return nil, err @@ -419,6 +428,7 @@ func parseTestGraph(path string) (*testGraphInstance, error) { aliasMap: aliasMap, privKeyMap: privKeyMap, channelIDs: channelIDs, + links: links, }, nil } @@ -495,6 +505,22 @@ type testGraphInstance struct { // channelIDs stores the channel ID for each node. channelIDs map[route.Vertex]map[route.Vertex]uint64 + + // links maps channel ids to a mock channel update handler. + links map[lnwire.ShortChannelID]htlcswitch.ChannelLink +} + +// getLink is a mocked link lookup function which looks up links in our test +// graph. +func (g *testGraphInstance) getLink(chanID lnwire.ShortChannelID) ( + htlcswitch.ChannelLink, error) { + + link, ok := g.links[chanID] + if !ok { + return nil, fmt.Errorf("link not found in mock: %v", chanID) + } + + return link, nil } // createTestGraphFromChannels returns a fully populated ChannelGraph based on a set of @@ -581,6 +607,8 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( // if none is specified. nextUnassignedChannelID := uint64(100000) + links := make(map[lnwire.ShortChannelID]htlcswitch.ChannelLink) + for _, testChannel := range testChannels { for _, node := range []*testChannelEnd{ testChannel.Node1, testChannel.Node2} { @@ -617,6 +645,12 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( Index: 0, } + capacity := lnwire.MilliSatoshi(testChannel.Capacity * 1000) + shortID := lnwire.NewShortChanIDFromInt(channelID) + links[shortID] = &mockLink{ + bandwidth: capacity, + } + // Sort nodes node1 := testChannel.Node1 node2 := testChannel.Node2 @@ -730,6 +764,7 @@ func createTestGraphFromChannels(testChannels []*testChannel, source string) ( cleanUp: cleanUp, aliasMap: aliasMap, privKeyMap: privKeyMap, + links: links, }, nil } @@ -923,7 +958,7 @@ func testBasicGraphPathFindingCase(t *testing.T, graphInstance *testGraphInstanc paymentAmt := lnwire.NewMSatFromSatoshis(test.paymentAmt) target := graphInstance.aliasMap[test.target] path, err := dbFindPath( - graphInstance.graph, nil, nil, + graphInstance.graph, nil, &mockBandwidthHints{}, &RestrictParams{ FeeLimit: test.feeLimit, ProbabilitySource: noProbabilitySource, @@ -1118,7 +1153,7 @@ func TestPathFindingWithAdditionalEdges(t *testing.T) { []*channeldb.CachedEdgePolicy, error) { return dbFindPath( - graph.graph, additionalEdges, nil, + graph.graph, additionalEdges, &mockBandwidthHints{}, r, testPathFindingConfig, sourceNode.PubKeyBytes, doge.PubKeyBytes, paymentAmt, 0, @@ -1543,7 +1578,7 @@ func TestPathNotAvailable(t *testing.T) { copy(unknownNode[:], unknownNodeBytes) _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, unknownNode, 100, 0, ) @@ -1599,7 +1634,7 @@ func TestDestTLVGraphFallback(t *testing.T) { target route.Vertex) ([]*channeldb.CachedEdgePolicy, error) { return dbFindPath( - ctx.graph, nil, nil, + ctx.graph, nil, &mockBandwidthHints{}, r, testPathFindingConfig, sourceNode.PubKeyBytes, target, 100, 0, ) @@ -1868,7 +1903,7 @@ func TestPathInsufficientCapacity(t *testing.T) { payAmt := lnwire.NewMSatFromSatoshis(btcutil.SatoshiPerBitcoin) _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -1899,7 +1934,7 @@ func TestRouteFailMinHTLC(t *testing.T) { target := graph.aliasMap["songoku"] payAmt := lnwire.MilliSatoshi(10) _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -1996,7 +2031,7 @@ func TestRouteFailDisabledEdge(t *testing.T) { target := graph.aliasMap["sophon"] payAmt := lnwire.NewMSatFromSatoshis(105000) _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -2022,7 +2057,7 @@ func TestRouteFailDisabledEdge(t *testing.T) { } _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -2045,7 +2080,7 @@ func TestRouteFailDisabledEdge(t *testing.T) { // If we attempt to route through that edge, we should get a failure as // it is no longer eligible. _, err = dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -2077,7 +2112,7 @@ func TestPathSourceEdgesBandwidth(t *testing.T) { target := graph.aliasMap["sophon"] payAmt := lnwire.NewMSatFromSatoshis(50000) path, err := dbFindPath( - graph.graph, nil, nil, + graph.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, payAmt, 0, ) @@ -2090,9 +2125,11 @@ func TestPathSourceEdgesBandwidth(t *testing.T) { // roasbeef->phamnuwen to 0. roasToSongoku := uint64(12345) roasToPham := uint64(999991) - bandwidths := map[uint64]lnwire.MilliSatoshi{ - roasToSongoku: 0, - roasToPham: 0, + bandwidths := &mockBandwidthHints{ + hints: map[uint64]lnwire.MilliSatoshi{ + roasToSongoku: 0, + roasToPham: 0, + }, } // Since both these edges has a bandwidth of zero, no path should be @@ -2108,7 +2145,7 @@ func TestPathSourceEdgesBandwidth(t *testing.T) { // Set the bandwidth of roasbeef->phamnuwen high enough to carry the // payment. - bandwidths[roasToPham] = 2 * payAmt + bandwidths.hints[roasToPham] = 2 * payAmt // Now, if we attempt to route again, we should find the path via // phamnuven, as the other source edge won't be considered. @@ -2124,7 +2161,7 @@ func TestPathSourceEdgesBandwidth(t *testing.T) { // Finally, set the roasbeef->songoku bandwidth, but also set its // disable flag. - bandwidths[roasToSongoku] = 2 * payAmt + bandwidths.hints[roasToSongoku] = 2 * payAmt _, e1, e2, err := graph.graph.FetchChannelEdgesByID(roasToSongoku) if err != nil { t.Fatalf("unable to fetch edge: %v", err) @@ -2936,7 +2973,7 @@ type pathFindingTestContext struct { t *testing.T graph *channeldb.ChannelGraph restrictParams RestrictParams - bandwidthHints map[uint64]lnwire.MilliSatoshi + bandwidthHints bandwidthHints pathFindingConfig PathFindingConfig testGraphInstance *testGraphInstance source route.Vertex @@ -2964,6 +3001,7 @@ func newPathFindingTestContext(t *testing.T, testChannels []*testChannel, pathFindingConfig: *testPathFindingConfig, graph: testGraphInstance.graph, restrictParams: *noRestrictions, + bandwidthHints: &mockBandwidthHints{}, } return ctx @@ -3016,7 +3054,7 @@ func (c *pathFindingTestContext) assertPath(path []*channeldb.CachedEdgePolicy, // graph. func dbFindPath(graph *channeldb.ChannelGraph, additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy, - bandwidthHints map[uint64]lnwire.MilliSatoshi, + bandwidthHints bandwidthHints, r *RestrictParams, cfg *PathFindingConfig, source, target route.Vertex, amt lnwire.MilliSatoshi, finalHtlcExpiry int32) ([]*channeldb.CachedEdgePolicy, error) { diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index d233d8bde..64add4591 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -472,9 +472,6 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, Payer: payer, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, - QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(c.Capacity) - }, NextPaymentID: func() (uint64, error) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil diff --git a/routing/payment_session.go b/routing/payment_session.go index bbf9b6f96..8895d28fe 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -164,7 +164,7 @@ type PaymentSession interface { type paymentSession struct { additionalEdges map[route.Vertex][]*channeldb.CachedEdgePolicy - getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error) + getBandwidthHints func() (bandwidthHints, error) payment *LightningPayment @@ -192,7 +192,7 @@ type paymentSession struct { // newPaymentSession instantiates a new payment session. func newPaymentSession(p *LightningPayment, - getBandwidthHints func() (map[uint64]lnwire.MilliSatoshi, error), + getBandwidthHints func() (bandwidthHints, error), routingGraph routingGraph, missionControl MissionController, pathFindingConfig PathFindingConfig) ( *paymentSession, error) { diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index d688f9814..6889d0a17 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -19,12 +19,12 @@ type SessionSource struct { // and also to carry out path finding queries. Graph routingGraph - // QueryBandwidth is a method that allows querying the lower link layer + // GetLink is a method that allows querying the lower link layer // to determine the up to date available bandwidth at a prospective link // to be traversed. If the link isn't available, then a value of zero // should be returned. Otherwise, the current up to date knowledge of // the available bandwidth of the link should be returned. - QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi + GetLink getLinkQuery // MissionControl is a shared memory of sorts that executions of payment // path finding use in order to remember which vertexes/edges were @@ -47,12 +47,10 @@ type SessionSource struct { func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( PaymentSession, error) { - getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, - error) { + sourceNode := m.Graph.sourceNode() - return generateBandwidthHints( - m.Graph.sourceNode(), m.Graph, m.QueryBandwidth, - ) + getBandwidthHints := func() (bandwidthHints, error) { + return newBandwidthManager(m.Graph, sourceNode, m.GetLink) } session, err := newPaymentSession( diff --git a/routing/payment_session_test.go b/routing/payment_session_test.go index dae331f84..f177da730 100644 --- a/routing/payment_session_test.go +++ b/routing/payment_session_test.go @@ -116,10 +116,8 @@ func TestUpdateAdditionalEdge(t *testing.T) { // Create the paymentsession. session, err := newPaymentSession( payment, - func() (map[uint64]lnwire.MilliSatoshi, - error) { - - return nil, nil + func() (bandwidthHints, error) { + return &mockBandwidthHints{}, nil }, &sessionGraph{}, &MissionControl{}, @@ -196,10 +194,8 @@ func TestRequestRoute(t *testing.T) { session, err := newPaymentSession( payment, - func() (map[uint64]lnwire.MilliSatoshi, - error) { - - return nil, nil + func() (bandwidthHints, error) { + return &mockBandwidthHints{}, nil }, &sessionGraph{}, &MissionControl{}, diff --git a/routing/router.go b/routing/router.go index dd8a375a2..5baf6eccc 100644 --- a/routing/router.go +++ b/routing/router.go @@ -339,7 +339,7 @@ type Config struct { // a value of zero should be returned. Otherwise, the current up to // date knowledge of the available bandwidth of the link should be // returned. - QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi + GetLink getLinkQuery // NextPaymentID is a method that guarantees to return a new, unique ID // each time it is called. This is used by the router to generate a @@ -1741,8 +1741,8 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. - bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, + bandwidthHints, err := newBandwidthManager( + r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, ) if err != nil { return nil, err @@ -2652,41 +2652,6 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { return r.cfg.Graph.MarkEdgeLive(chanID.ToUint64()) } -// generateBandwidthHints is a helper function that's utilized the main -// findPath function in order to obtain hints from the lower layer w.r.t to the -// available bandwidth of edges on the network. Currently, we'll only obtain -// bandwidth hints for the edges we directly have open ourselves. Obtaining -// these hints allows us to reduce the number of extraneous attempts as we can -// skip channels that are inactive, or just don't have enough bandwidth to -// carry the payment. -func generateBandwidthHints(sourceNode route.Vertex, graph routingGraph, - queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( - map[uint64]lnwire.MilliSatoshi, error) { - - // First, we'll collect the set of outbound edges from the target - // source node. - var localChans []*channeldb.DirectedChannel - err := graph.forEachNodeChannel( - sourceNode, func(channel *channeldb.DirectedChannel) error { - localChans = append(localChans, channel) - return nil - }, - ) - if err != nil { - return nil, err - } - - // Now that we have all of our outbound edges, we'll populate the set - // of bandwidth hints, querying the lower switch layer for the most up - // to date values. - bandwidthHints := make(map[uint64]lnwire.MilliSatoshi) - for _, localChan := range localChans { - bandwidthHints[localChan.ChannelID] = queryBandwidth(localChan) - } - - return bandwidthHints, nil -} - // ErrNoChannel is returned when a route cannot be built because there are no // channels that satisfy all requirements. type ErrNoChannel struct { @@ -2723,8 +2688,8 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. - bandwidthHints, err := generateBandwidthHints( - r.selfNode.PubKeyBytes, r.cachedGraph, r.cfg.QueryBandwidth, + bandwidthHints, err := newBandwidthManager( + r.cachedGraph, r.selfNode.PubKeyBytes, r.cfg.GetLink, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 4b5dd505f..77681b2a6 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -133,12 +133,8 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, require.NoError(t, err) sessionSource := &SessionSource{ - Graph: cachedGraph, - QueryBandwidth: func( - c *channeldb.DirectedChannel) lnwire.MilliSatoshi { - - return lnwire.NewMSatFromSatoshis(c.Capacity) - }, + Graph: cachedGraph, + GetLink: graphInstance.getLink, PathFindingConfig: pathFindingConfig, MissionControl: mc, } @@ -160,11 +156,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, SessionSource: sessionSource, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, - QueryBandwidth: func( - e *channeldb.DirectedChannel) lnwire.MilliSatoshi { - - return lnwire.NewMSatFromSatoshis(e.Capacity) - }, + GetLink: graphInstance.getLink, NextPaymentID: func() (uint64, error) { next := atomic.AddUint64(&uniquePaymentID, 1) return next, nil @@ -2467,7 +2459,7 @@ func TestFindPathFeeWeighting(t *testing.T) { // the edge weighting, we should select the direct path over the 2 hop // path even though the direct path has a higher potential time lock. path, err := dbFindPath( - ctx.graph, nil, nil, + ctx.graph, nil, &mockBandwidthHints{}, noRestrictions, testPathFindingConfig, sourceNode.PubKeyBytes, target, amt, 0, diff --git a/routing/unified_policies.go b/routing/unified_policies.go index fe7cc1ec4..1df166b9b 100644 --- a/routing/unified_policies.go +++ b/routing/unified_policies.go @@ -133,7 +133,7 @@ type unifiedPolicy struct { // specific amount to send. It differentiates between local and network // channels. func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { + bandwidthHints bandwidthHints) *channeldb.CachedEdgePolicy { if u.localChan { return u.getPolicyLocal(amt, bandwidthHints) @@ -145,7 +145,7 @@ func (u *unifiedPolicy) getPolicy(amt lnwire.MilliSatoshi, // getPolicyLocal returns the optimal policy to use for this local connection // given a specific amount to send. func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, - bandwidthHints map[uint64]lnwire.MilliSatoshi) *channeldb.CachedEdgePolicy { + bandwidthHints bandwidthHints) *channeldb.CachedEdgePolicy { var ( bestPolicy *channeldb.CachedEdgePolicy @@ -169,7 +169,9 @@ func (u *unifiedPolicy) getPolicyLocal(amt lnwire.MilliSatoshi, // TODO(joostjager): Possibly change to skipping this // channel. The bandwidth hint is expected to be // available. - bandwidth, ok := bandwidthHints[edge.policy.ChannelID] + bandwidth, ok := bandwidthHints.availableChanBandwidth( + edge.policy.ChannelID, + ) if !ok { bandwidth = lnwire.MaxMilliSatoshi } diff --git a/routing/unified_policies_test.go b/routing/unified_policies_test.go index ac915f99a..abdc56b62 100644 --- a/routing/unified_policies_test.go +++ b/routing/unified_policies_test.go @@ -15,7 +15,7 @@ func TestUnifiedPolicies(t *testing.T) { toNode := route.Vertex{2} fromNode := route.Vertex{3} - bandwidthHints := map[uint64]lnwire.MilliSatoshi{} + bandwidthHints := &mockBandwidthHints{} u := newUnifiedPolicies(source, toNode, nil) diff --git a/server.go b/server.go index 0a2c410e5..356e056da 100644 --- a/server.go +++ b/server.go @@ -820,33 +820,6 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } - queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { - cid := lnwire.NewShortChanIDFromInt(c.ChannelID) - link, err := s.htlcSwitch.GetLinkByShortID(cid) - if err != nil { - // If the link isn't online, then we'll report - // that it has zero bandwidth to the router. - return 0 - } - - // If the link is found within the switch, but it isn't - // yet eligible to forward any HTLCs, then we'll treat - // it as if it isn't online in the first place. - if !link.EligibleToForward() { - return 0 - } - - // If our link isn't currently in a state where it can - // add another outgoing htlc, treat the link as unusable. - if err := link.MayAddOutgoingHtlc(); err != nil { - return 0 - } - - // Otherwise, we'll return the current best estimate - // for the available bandwidth for the link. - return link.Bandwidth() - } - // Instantiate mission control with config from the sub server. // // TODO(joostjager): When we are further in the process of moving to sub @@ -893,7 +866,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, paymentSessionSource := &routing.SessionSource{ Graph: cachedGraph, MissionControl: s.missionControl, - QueryBandwidth: queryBandwidth, + GetLink: s.htlcSwitch.GetLinkByShortID, PathFindingConfig: pathFindingConfig, } @@ -915,7 +888,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ChannelPruneExpiry: routing.DefaultChannelPruneExpiry, GraphPruneInterval: time.Hour, FirstTimePruneDelay: routing.DefaultFirstTimePruneDelay, - QueryBandwidth: queryBandwidth, + GetLink: s.htlcSwitch.GetLinkByShortID, AssumeChannelValid: cfg.Routing.AssumeChannelValid, NextPaymentID: sequencer.NextID, PathFindingConfig: pathFindingConfig,