From dc6f9256bc0b9cc437688fcc25df2d40a8f5ca89 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 1 Aug 2025 15:06:36 +0200 Subject: [PATCH] graph/db: batch fetch channels in ForEachNodeCached Previously, ForEachNodeCached would batch fetch node _feature_ data but would still fetch the channel set of each node in a node-by-node fashion which is not ideal. So this commit updates this method to make use of the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper. It lets us batch load channel data for a range of node IDs. This _greatly_ improves the performance of the method. --- graph/db/sql_store.go | 400 +++++++++++++++++++++++++++++++----------- 1 file changed, 299 insertions(+), 101 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 770481217..7e57f9619 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -100,6 +100,7 @@ type SQLQueries interface { GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) + ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error) ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error) ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error) @@ -1085,112 +1086,185 @@ func (s *SQLStore) ForEachNodeCached(ctx context.Context, cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error, reset func()) error { - handleNode := func(db SQLQueries, nodeID int64, - nodePub route.Vertex, features *lnwire.FeatureVector) error { - - toNodeCallback := func() route.Vertex { - return nodePub - } - - rows, err := db.ListChannelsByNodeID( - ctx, sqlc.ListChannelsByNodeIDParams{ - Version: int16(ProtocolV1), - NodeID1: nodeID, - }, - ) - if err != nil { - return fmt.Errorf("unable to fetch channels of "+ - "node(id=%d): %w", nodeID, err) - } - - channels := make(map[uint64]*DirectedChannel, len(rows)) - for _, row := range rows { - node1, node2, err := buildNodeVertices( - row.Node1Pubkey, row.Node2Pubkey, - ) - if err != nil { - return err - } - - e, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1, node2, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "info: %w", err) - } - - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) - } - - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, e.ChannelID, node1, - node2, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) - } - - // Determine the outgoing and incoming policy - // for this channel and node combo. - outPolicy, inPolicy := p1, p2 - if p1 != nil && p1.ToNode == nodePub { - outPolicy, inPolicy = p2, p1 - } else if p2 != nil && p2.ToNode != nodePub { - outPolicy, inPolicy = p2, p1 - } - - var cachedInPolicy *models.CachedEdgePolicy - if inPolicy != nil { - cachedInPolicy = models.NewCachedPolicy( - inPolicy, - ) - cachedInPolicy.ToNodePubKey = toNodeCallback - cachedInPolicy.ToNodeFeatures = features - } - - var inboundFee lnwire.Fee - if outPolicy != nil { - outPolicy.InboundFee.WhenSome( - func(fee lnwire.Fee) { - inboundFee = fee - }, - ) - } - - directedChannel := &DirectedChannel{ - ChannelID: e.ChannelID, - IsNode1: nodePub == e.NodeKey1Bytes, - OtherNode: e.NodeKey2Bytes, - Capacity: e.Capacity, - OutPolicySet: outPolicy != nil, - InPolicy: cachedInPolicy, - InboundFee: inboundFee, - } - - if nodePub == e.NodeKey2Bytes { - directedChannel.OtherNode = e.NodeKey1Bytes - } - - channels[e.ChannelID] = directedChannel - } - - return cb(nodePub, channels) + type nodeCachedBatchData struct { + features map[int64][]int + chanBatchData *batchChannelData + chanMap map[int64][]sqlc.ListChannelsForNodeIDsRow } return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - return forEachNodeCacheable( - ctx, s.cfg.QueryCfg, db, - func(nodeID int64, nodePub route.Vertex, - features *lnwire.FeatureVector) error { + // pageQueryFunc is used to query the next page of nodes. + pageQueryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) { - return handleNode(db, nodeID, nodePub, features) + return db.ListNodeIDsAndPubKeys( + ctx, sqlc.ListNodeIDsAndPubKeysParams{ + Version: int16(ProtocolV1), + ID: lastID, + Limit: limit, + }, + ) + } + + // batchDataFunc is then used to batch load the data required + // for each page of nodes. + batchDataFunc := func(ctx context.Context, + nodeIDs []int64) (*nodeCachedBatchData, error) { + + // Batch load node features. + nodeFeatures, err := batchLoadNodeFeaturesHelper( + ctx, s.cfg.QueryCfg, db, nodeIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load "+ + "node features: %w", err) + } + + // Batch load ALL unique channels for ALL nodes in this + // page. + allChannels, err := db.ListChannelsForNodeIDs( + ctx, sqlc.ListChannelsForNodeIDsParams{ + Version: int16(ProtocolV1), + Node1Ids: nodeIDs, + Node2Ids: nodeIDs, + }, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch "+ + "fetch channels for nodes: %w", err) + } + + // Deduplicate channels and collect IDs. + var ( + allChannelIDs []int64 + allPolicyIDs []int64 + ) + uniqueChannels := make( + map[int64]sqlc.ListChannelsForNodeIDsRow, + ) + + for _, channel := range allChannels { + channelID := channel.GraphChannel.ID + + // Only process each unique channel once. + _, exists := uniqueChannels[channelID] + if exists { + continue + } + + uniqueChannels[channelID] = channel + allChannelIDs = append(allChannelIDs, channelID) + + if channel.Policy1ID.Valid { + allPolicyIDs = append( + allPolicyIDs, + channel.Policy1ID.Int64, + ) + } + if channel.Policy2ID.Valid { + allPolicyIDs = append( + allPolicyIDs, + channel.Policy2ID.Int64, + ) + } + } + + // Batch load channel data for all unique channels. + channelBatchData, err := batchLoadChannelData( + ctx, s.cfg.QueryCfg, db, allChannelIDs, + allPolicyIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch "+ + "load channel data: %w", err) + } + + // Create map of node ID to channels that involve this + // node. + nodeIDSet := make(map[int64]bool) + for _, nodeID := range nodeIDs { + nodeIDSet[nodeID] = true + } + + nodeChannelMap := make( + map[int64][]sqlc.ListChannelsForNodeIDsRow, + ) + for _, channel := range uniqueChannels { + // Add channel to both nodes if they're in our + // current page. + node1 := channel.GraphChannel.NodeID1 + if nodeIDSet[node1] { + nodeChannelMap[node1] = append( + nodeChannelMap[node1], channel, + ) + } + node2 := channel.GraphChannel.NodeID2 + if nodeIDSet[node2] { + nodeChannelMap[node2] = append( + nodeChannelMap[node2], channel, + ) + } + } + + return &nodeCachedBatchData{ + features: nodeFeatures, + chanBatchData: channelBatchData, + chanMap: nodeChannelMap, + }, nil + } + + // processItem is used to process each node in the current page. + processItem := func(ctx context.Context, + nodeData sqlc.ListNodeIDsAndPubKeysRow, + batchData *nodeCachedBatchData) error { + + // Build feature vector for this node. + fv := lnwire.EmptyFeatureVector() + features, exists := batchData.features[nodeData.ID] + if exists { + for _, bit := range features { + fv.Set(lnwire.FeatureBit(bit)) + } + } + + var nodePub route.Vertex + copy(nodePub[:], nodeData.PubKey) + + nodeChannels := batchData.chanMap[nodeData.ID] + + toNodeCallback := func() route.Vertex { + return nodePub + } + + // Build cached channels map for this node. + channels := make(map[uint64]*DirectedChannel) + for _, channelRow := range nodeChannels { + directedChan, err := buildDirectedChannel( + s.cfg.ChainHash, nodeData.ID, nodePub, + channelRow, batchData.chanBatchData, fv, + toNodeCallback, + ) + if err != nil { + return err + } + + channels[directedChan.ChannelID] = directedChan + } + + return cb(nodePub, channels) + } + + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc, + func(node sqlc.ListNodeIDsAndPubKeysRow) int64 { + return node.ID }, + func(node sqlc.ListNodeIDsAndPubKeysRow) (int64, + error) { + + return node.ID, nil + }, + batchDataFunc, processItem, ) }, reset) } @@ -4411,6 +4485,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, return policy1, policy2, nil + case sqlc.ListChannelsForNodeIDsRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.ListChannelsByNodeIDRow: if r.Policy1ID.Valid { policy1 = &sqlc.GraphChannelPolicy{ @@ -5118,3 +5236,83 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, collectFunc, batchDataFunc, processItem, ) } + +// buildDirectedChannel builds a DirectedChannel instance from the provided +// data. +func buildDirectedChannel(chain chainhash.Hash, nodeID int64, + nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow, + channelBatchData *batchChannelData, features *lnwire.FeatureVector, + toNodeCallback func() route.Vertex) (*DirectedChannel, error) { + + node1, node2, err := buildNodeVertices( + channelRow.Node1Pubkey, channelRow.Node2Pubkey, + ) + if err != nil { + return nil, fmt.Errorf("unable to build node vertices: %w", err) + } + + edge, err := buildEdgeInfoWithBatchData( + chain, channelRow.GraphChannel, node1, node2, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel info: %w", err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(channelRow) + if err != nil { + return nil, fmt.Errorf("unable to extract channel policies: %w", + err) + } + + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, edge.ChannelID, node1, node2, + channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel policies: %w", + err) + } + + // Determine outgoing and incoming policy for this specific node. + p1ToNode := channelRow.GraphChannel.NodeID2 + p2ToNode := channelRow.GraphChannel.NodeID1 + outPolicy, inPolicy := p1, p2 + if (p1 != nil && p1ToNode == nodeID) || + (p2 != nil && p2ToNode != nodeID) { + + outPolicy, inPolicy = p2, p1 + } + + // Build cached policy. + var cachedInPolicy *models.CachedEdgePolicy + if inPolicy != nil { + cachedInPolicy = models.NewCachedPolicy(inPolicy) + cachedInPolicy.ToNodePubKey = toNodeCallback + cachedInPolicy.ToNodeFeatures = features + } + + // Extract inbound fee. + var inboundFee lnwire.Fee + if outPolicy != nil { + outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) { + inboundFee = fee + }) + } + + // Build directed channel. + directedChannel := &DirectedChannel{ + ChannelID: edge.ChannelID, + IsNode1: nodePub == edge.NodeKey1Bytes, + OtherNode: edge.NodeKey2Bytes, + Capacity: edge.Capacity, + OutPolicySet: outPolicy != nil, + InPolicy: cachedInPolicy, + InboundFee: inboundFee, + } + + if nodePub == edge.NodeKey2Bytes { + directedChannel.OtherNode = edge.NodeKey1Bytes + } + + return directedChannel, nil +}