multi: use cache for source channels

This commit is contained in:
Oliver Gugger 2021-09-21 19:18:21 +02:00
parent 369c09be61
commit 15d3f62d5e
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
7 changed files with 37 additions and 23 deletions

View File

@ -2315,6 +2315,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *ChannelEdgePolicy,
) )
copy(fromNodePubKey[:], fromNode) copy(fromNodePubKey[:], fromNode)
copy(toNodePubKey[:], toNode) copy(toNodePubKey[:], toNode)
// TODO(guggero): Fetch lightning nodes before updating the cache!
graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1) graphCache.UpdatePolicy(edge, fromNodePubKey, toNodePubKey, isUpdate1)
return isUpdate1, nil return isUpdate1, nil

View File

@ -2052,6 +2052,17 @@ func (s *Switch) getLink(chanID lnwire.ChannelID) (ChannelLink, error) {
return link, nil return link, nil
} }
// GetLinkByShortID attempts to return the link which possesses the target short
// channel ID.
func (s *Switch) GetLinkByShortID(chanID lnwire.ShortChannelID) (ChannelLink,
error) {
s.indexMtx.RLock()
defer s.indexMtx.RUnlock()
return s.getLinkByShortID(chanID)
}
// getLinkByShortID attempts to return the link which possesses the target // getLinkByShortID attempts to return the link which possesses the target
// short channel ID. // short channel ID.
// //

View File

@ -472,8 +472,8 @@ func testPaymentLifecycle(t *testing.T, test paymentLifecycleTestCase,
Payer: payer, Payer: payer,
ChannelPruneExpiry: time.Hour * 24, ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2, GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func(e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { QueryBandwidth: func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(c.Capacity)
}, },
NextPaymentID: func() (uint64, error) { NextPaymentID: func() (uint64, error) {
next := atomic.AddUint64(&uniquePaymentID, 1) next := atomic.AddUint64(&uniquePaymentID, 1)

View File

@ -24,7 +24,7 @@ type SessionSource struct {
// to be traversed. If the link isn't available, then a value of zero // to be traversed. If the link isn't available, then a value of zero
// should be returned. Otherwise, the current up to date knowledge of // should be returned. Otherwise, the current up to date knowledge of
// the available bandwidth of the link should be returned. // the available bandwidth of the link should be returned.
QueryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// MissionControl is a shared memory of sorts that executions of payment // MissionControl is a shared memory of sorts that executions of payment
// path finding use in order to remember which vertexes/edges were // path finding use in order to remember which vertexes/edges were
@ -65,7 +65,9 @@ func (m *SessionSource) NewPaymentSession(p *LightningPayment) (
getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi, getBandwidthHints := func() (map[uint64]lnwire.MilliSatoshi,
error) { error) {
return generateBandwidthHints(sourceNode, m.QueryBandwidth) return generateBandwidthHints(
sourceNode.PubKeyBytes, m.Graph, m.QueryBandwidth,
)
} }
session, err := newPaymentSession( session, err := newPaymentSession(

View File

@ -339,7 +339,7 @@ type Config struct {
// a value of zero should be returned. Otherwise, the current up to // a value of zero should be returned. Otherwise, the current up to
// date knowledge of the available bandwidth of the link should be // date knowledge of the available bandwidth of the link should be
// returned. // returned.
QueryBandwidth func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi QueryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi
// NextPaymentID is a method that guarantees to return a new, unique ID // NextPaymentID is a method that guarantees to return a new, unique ID
// each time it is called. This is used by the router to generate a // each time it is called. This is used by the router to generate a
@ -1735,7 +1735,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// We'll attempt to obtain a set of bandwidth hints that can help us // We'll attempt to obtain a set of bandwidth hints that can help us
// eliminate certain routes early on in the path finding process. // eliminate certain routes early on in the path finding process.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -2657,19 +2657,19 @@ func (r *ChannelRouter) MarkEdgeLive(chanID lnwire.ShortChannelID) error {
// these hints allows us to reduce the number of extraneous attempts as we can // these hints allows us to reduce the number of extraneous attempts as we can
// skip channels that are inactive, or just don't have enough bandwidth to // skip channels that are inactive, or just don't have enough bandwidth to
// carry the payment. // carry the payment.
func generateBandwidthHints(sourceNode *channeldb.LightningNode, func generateBandwidthHints(sourceNode route.Vertex, graph *channeldb.ChannelGraph,
queryBandwidth func(*channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi) (map[uint64]lnwire.MilliSatoshi, error) { queryBandwidth func(*channeldb.DirectedChannel) lnwire.MilliSatoshi) (
map[uint64]lnwire.MilliSatoshi, error) {
// First, we'll collect the set of outbound edges from the target // First, we'll collect the set of outbound edges from the target
// source node. // source node.
var localChans []*channeldb.ChannelEdgeInfo var localChans []*channeldb.DirectedChannel
err := sourceNode.ForEachChannel(nil, func(tx kvdb.RTx, err := graph.ForEachNodeChannel(
edgeInfo *channeldb.ChannelEdgeInfo, sourceNode, func(channel *channeldb.DirectedChannel) error {
_, _ *channeldb.ChannelEdgePolicy) error { localChans = append(localChans, channel)
return nil
localChans = append(localChans, edgeInfo) },
return nil )
})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -2722,7 +2722,7 @@ func (r *ChannelRouter) BuildRoute(amt *lnwire.MilliSatoshi,
// We'll attempt to obtain a set of bandwidth hints that helps us select // We'll attempt to obtain a set of bandwidth hints that helps us select
// the best outgoing channel to use in case no outgoing channel is set. // the best outgoing channel to use in case no outgoing channel is set.
bandwidthHints, err := generateBandwidthHints( bandwidthHints, err := generateBandwidthHints(
r.selfNode, r.cfg.QueryBandwidth, r.selfNode.PubKeyBytes, r.cfg.Graph, r.cfg.QueryBandwidth,
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -132,9 +132,9 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
sessionSource := &SessionSource{ sessionSource := &SessionSource{
Graph: graphInstance.graph, Graph: graphInstance.graph,
QueryBandwidth: func( QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(c.Capacity)
}, },
PathFindingConfig: pathFindingConfig, PathFindingConfig: pathFindingConfig,
MissionControl: mc, MissionControl: mc,
@ -158,7 +158,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T,
ChannelPruneExpiry: time.Hour * 24, ChannelPruneExpiry: time.Hour * 24,
GraphPruneInterval: time.Hour * 2, GraphPruneInterval: time.Hour * 2,
QueryBandwidth: func( QueryBandwidth: func(
e *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { e *channeldb.DirectedChannel) lnwire.MilliSatoshi {
return lnwire.NewMSatFromSatoshis(e.Capacity) return lnwire.NewMSatFromSatoshis(e.Capacity)
}, },

View File

@ -710,9 +710,9 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
return nil, err return nil, err
} }
queryBandwidth := func(edge *channeldb.ChannelEdgeInfo) lnwire.MilliSatoshi { queryBandwidth := func(c *channeldb.DirectedChannel) lnwire.MilliSatoshi {
cid := lnwire.NewChanIDFromOutPoint(&edge.ChannelPoint) cid := lnwire.NewShortChanIDFromInt(c.ChannelID)
link, err := s.htlcSwitch.GetLink(cid) link, err := s.htlcSwitch.GetLinkByShortID(cid)
if err != nil { if err != nil {
// If the link isn't online, then we'll report // If the link isn't online, then we'll report
// that it has zero bandwidth to the router. // that it has zero bandwidth to the router.