diff --git a/channeldb/graph.go b/channeldb/graph.go index e3ec83113..8806bcff1 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2315,6 +2315,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy, ) copy(fromNodePubKey[:], fromNode) copy(toNodePubKey[:], toNode) + // TODO(guggero): Fetch lightning nodes before updating the cache! graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) return isUpdate1, nil diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 17b423857..5f02d406c 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2052,6 +2052,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) { return link, nil } +// GetLinkByShortID attempts to return the link which possesses the target short +// channel ID. +func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink, + error) { + + s.indexMtx.RLock() + defer s.indexMtx.RUnlock() + + return s.getLinkByShortID(chanID) +} + // getLinkByShortID attempts to return the link which possesses the target // short channel ID. // diff --git a/routing/payment_lifecycle_test.go b/routing/payment_lifecycle_test.go index 6a3517023..d233d8bde 100644 --- a/routing/payment_lifecycle_test.go +++ b/routing/payment_lifecycle_test.go @@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase, Payer: payer, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, - QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(e.Capacity) + QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + return lnwire.NewMSatFromSatoshis(c.Capacity) }, NextPaymentID: func() (uint64, error) { next := atomic.AddUint64(&uniquePaymentID, 1) diff --git a/routing/payment_session_source.go b/routing/payment_session_source.go index f08005909..661d5861d 100644 --- a/routing/payment_session_source.go +++ b/routing/payment_session_source.go @@ -24,7 +24,7 @@ type SessionSource struct { // to be traversed. If the link isn't available, then a value of zero // should be returned. Otherwise, the current up to date knowledge of // the available bandwidth of the link should be returned. - QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // MissionControl is a shared memory of sorts that executions of payment // path finding use in order to remember which vertexes/edges were @@ -65,7 +65,9 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) ( getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, error) { - return generateBandwidthHints(sourceNode, m.QueryBandwidth) + return generateBandwidthHints( + sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth, + ) } session, err := newPaymentSession( diff --git a/routing/router.go b/routing/router.go index 9864a991d..aa034eea0 100644 --- a/routing/router.go +++ b/routing/router.go @@ -339,7 +339,7 @@ type Config struct { // a value of zero should be returned. Otherwise, the current up to // date knowledge of the available bandwidth of the link should be // returned. - QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi + QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi // NextPaymentID is a method that guarantees to return a new, unique ID // each time it is called. This is used by the router to generate a @@ -1735,7 +1735,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // We'll attempt to obtain a set of bandwidth hints that can help us // eliminate certain routes early on in the path finding process. bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err @@ -2657,19 +2657,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error { // these hints allows us to reduce the number of extraneous attempts as we can // skip channels that are inactive, or just don't have enough bandwidth to // carry the payment. -func generateBandwidthHints(sourceNode *channeldb.LightningNode, - queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) { +func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph, + queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) ( + map[uint64]lnwire.MilliSatoshi, error) { // First, we'll collect the set of outbound edges from the target // source node. - var localChans []*channeldb.ChannelEdgeInfo - err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx, - edgeInfo *channeldb.ChannelEdgeInfo, - _, _ *channeldb.ChannelEdgePolicy) error { - - localChans = append(localChans, edgeInfo) - return nil - }) + var localChans []*channeldb.DirectedChannel + err := graph.ForEachNodeChannel( + sourceNode, func(channel *channeldb.DirectedChannel) error { + localChans = append(localChans, channel) + return nil + }, + ) if err != nil { return nil, err } @@ -2722,7 +2722,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi, // We'll attempt to obtain a set of bandwidth hints that helps us select // the best outgoing channel to use in case no outgoing channel is set. bandwidthHints, err := generateBandwidthHints( - r.selfNode, r.cfg.QueryBandwidth, + r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth, ) if err != nil { return nil, err diff --git a/routing/router_test.go b/routing/router_test.go index 1633d3810..d263ce738 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -132,9 +132,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, sessionSource := &SessionSource{ Graph: graphInstance.graph, QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + c *channeldb.DirectedChannel) lnwire.MilliSatoshi { - return lnwire.NewMSatFromSatoshis(e.Capacity) + return lnwire.NewMSatFromSatoshis(c.Capacity) }, PathFindingConfig: pathFindingConfig, MissionControl: mc, @@ -158,7 +158,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, ChannelPruneExpiry: time.Hour * 24, GraphPruneInterval: time.Hour * 2, QueryBandwidth: func( - e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { + e *channeldb.DirectedChannel) lnwire.MilliSatoshi { return lnwire.NewMSatFromSatoshis(e.Capacity) }, diff --git a/server.go b/server.go index 0b1afe400..55e49d9f2 100644 --- a/server.go +++ b/server.go @@ -710,9 +710,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return nil, err } - queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { - cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) - link, err := s.htlcSwitch.GetLink(cid) + queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi { + cid := lnwire.NewShortChanIDFromInt(c.ChannelID) + link, err := s.htlcSwitch.GetLinkByShortID(cid) if err != nil { // If the link isn't online, then we'll report // that it has zero bandwidth to the router.