graph/db: batch fetch channels in ForEachNodeCached

Previously, ForEachNodeCached would batch fetch node _feature_ data but
would still fetch the channel set of each node in a node-by-node fashion
which is not ideal. So this commit updates this method to make use of
the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper. It lets
us batch load channel data for a range of node IDs.

This _greatly_ improves the performance of the method.
This commit is contained in:
Elle Mouton
2025-08-01 15:06:36 +02:00
parent 3b60d33ac8
commit dc6f9256bc

View File

@@ -100,6 +100,7 @@ type SQLQueries interface {
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
@@ -1085,112 +1086,185 @@ func (s *SQLStore) ForEachNodeCached(ctx context.Context,
cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
reset func()) error {
handleNode := func(db SQLQueries, nodeID int64,
nodePub route.Vertex, features *lnwire.FeatureVector) error {
toNodeCallback := func() route.Vertex {
return nodePub
}
rows, err := db.ListChannelsByNodeID(
ctx, sqlc.ListChannelsByNodeIDParams{
Version: int16(ProtocolV1),
NodeID1: nodeID,
},
)
if err != nil {
return fmt.Errorf("unable to fetch channels of "+
"node(id=%d): %w", nodeID, err)
}
channels := make(map[uint64]*DirectedChannel, len(rows))
for _, row := range rows {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}
e, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.GraphChannel,
node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"info: %w", err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
p1, p2, err := getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, e.ChannelID, node1,
node2,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
// Determine the outgoing and incoming policy
// for this channel and node combo.
outPolicy, inPolicy := p1, p2
if p1 != nil && p1.ToNode == nodePub {
outPolicy, inPolicy = p2, p1
} else if p2 != nil && p2.ToNode != nodePub {
outPolicy, inPolicy = p2, p1
}
var cachedInPolicy *models.CachedEdgePolicy
if inPolicy != nil {
cachedInPolicy = models.NewCachedPolicy(
inPolicy,
)
cachedInPolicy.ToNodePubKey = toNodeCallback
cachedInPolicy.ToNodeFeatures = features
}
var inboundFee lnwire.Fee
if outPolicy != nil {
outPolicy.InboundFee.WhenSome(
func(fee lnwire.Fee) {
inboundFee = fee
},
)
}
directedChannel := &DirectedChannel{
ChannelID: e.ChannelID,
IsNode1: nodePub == e.NodeKey1Bytes,
OtherNode: e.NodeKey2Bytes,
Capacity: e.Capacity,
OutPolicySet: outPolicy != nil,
InPolicy: cachedInPolicy,
InboundFee: inboundFee,
}
if nodePub == e.NodeKey2Bytes {
directedChannel.OtherNode = e.NodeKey1Bytes
}
channels[e.ChannelID] = directedChannel
}
return cb(nodePub, channels)
type nodeCachedBatchData struct {
features map[int64][]int
chanBatchData *batchChannelData
chanMap map[int64][]sqlc.ListChannelsForNodeIDsRow
}
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
return forEachNodeCacheable(
ctx, s.cfg.QueryCfg, db,
func(nodeID int64, nodePub route.Vertex,
features *lnwire.FeatureVector) error {
// pageQueryFunc is used to query the next page of nodes.
pageQueryFunc := func(ctx context.Context, lastID int64,
limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
return handleNode(db, nodeID, nodePub, features)
return db.ListNodeIDsAndPubKeys(
ctx, sqlc.ListNodeIDsAndPubKeysParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: limit,
},
)
}
// batchDataFunc is then used to batch load the data required
// for each page of nodes.
batchDataFunc := func(ctx context.Context,
nodeIDs []int64) (*nodeCachedBatchData, error) {
// Batch load node features.
nodeFeatures, err := batchLoadNodeFeaturesHelper(
ctx, s.cfg.QueryCfg, db, nodeIDs,
)
if err != nil {
return nil, fmt.Errorf("unable to batch load "+
"node features: %w", err)
}
// Batch load ALL unique channels for ALL nodes in this
// page.
allChannels, err := db.ListChannelsForNodeIDs(
ctx, sqlc.ListChannelsForNodeIDsParams{
Version: int16(ProtocolV1),
Node1Ids: nodeIDs,
Node2Ids: nodeIDs,
},
)
if err != nil {
return nil, fmt.Errorf("unable to batch "+
"fetch channels for nodes: %w", err)
}
// Deduplicate channels and collect IDs.
var (
allChannelIDs []int64
allPolicyIDs []int64
)
uniqueChannels := make(
map[int64]sqlc.ListChannelsForNodeIDsRow,
)
for _, channel := range allChannels {
channelID := channel.GraphChannel.ID
// Only process each unique channel once.
_, exists := uniqueChannels[channelID]
if exists {
continue
}
uniqueChannels[channelID] = channel
allChannelIDs = append(allChannelIDs, channelID)
if channel.Policy1ID.Valid {
allPolicyIDs = append(
allPolicyIDs,
channel.Policy1ID.Int64,
)
}
if channel.Policy2ID.Valid {
allPolicyIDs = append(
allPolicyIDs,
channel.Policy2ID.Int64,
)
}
}
// Batch load channel data for all unique channels.
channelBatchData, err := batchLoadChannelData(
ctx, s.cfg.QueryCfg, db, allChannelIDs,
allPolicyIDs,
)
if err != nil {
return nil, fmt.Errorf("unable to batch "+
"load channel data: %w", err)
}
// Create map of node ID to channels that involve this
// node.
nodeIDSet := make(map[int64]bool)
for _, nodeID := range nodeIDs {
nodeIDSet[nodeID] = true
}
nodeChannelMap := make(
map[int64][]sqlc.ListChannelsForNodeIDsRow,
)
for _, channel := range uniqueChannels {
// Add channel to both nodes if they're in our
// current page.
node1 := channel.GraphChannel.NodeID1
if nodeIDSet[node1] {
nodeChannelMap[node1] = append(
nodeChannelMap[node1], channel,
)
}
node2 := channel.GraphChannel.NodeID2
if nodeIDSet[node2] {
nodeChannelMap[node2] = append(
nodeChannelMap[node2], channel,
)
}
}
return &nodeCachedBatchData{
features: nodeFeatures,
chanBatchData: channelBatchData,
chanMap: nodeChannelMap,
}, nil
}
// processItem is used to process each node in the current page.
processItem := func(ctx context.Context,
nodeData sqlc.ListNodeIDsAndPubKeysRow,
batchData *nodeCachedBatchData) error {
// Build feature vector for this node.
fv := lnwire.EmptyFeatureVector()
features, exists := batchData.features[nodeData.ID]
if exists {
for _, bit := range features {
fv.Set(lnwire.FeatureBit(bit))
}
}
var nodePub route.Vertex
copy(nodePub[:], nodeData.PubKey)
nodeChannels := batchData.chanMap[nodeData.ID]
toNodeCallback := func() route.Vertex {
return nodePub
}
// Build cached channels map for this node.
channels := make(map[uint64]*DirectedChannel)
for _, channelRow := range nodeChannels {
directedChan, err := buildDirectedChannel(
s.cfg.ChainHash, nodeData.ID, nodePub,
channelRow, batchData.chanBatchData, fv,
toNodeCallback,
)
if err != nil {
return err
}
channels[directedChan.ChannelID] = directedChan
}
return cb(nodePub, channels)
}
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
return node.ID
},
func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
error) {
return node.ID, nil
},
batchDataFunc, processItem,
)
}, reset)
}
@@ -4411,6 +4485,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
return policy1, policy2, nil
case sqlc.ListChannelsForNodeIDsRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.GraphChannelPolicy{
ID: r.Policy1ID.Int64,
Version: r.Policy1Version.Int16,
ChannelID: r.GraphChannel.ID,
NodeID: r.Policy1NodeID.Int64,
Timelock: r.Policy1Timelock.Int32,
FeePpm: r.Policy1FeePpm.Int64,
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
LastUpdate: r.Policy1LastUpdate,
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
Disabled: r.Policy1Disabled,
MessageFlags: r.Policy1MessageFlags,
ChannelFlags: r.Policy1ChannelFlags,
Signature: r.Policy1Signature,
}
}
if r.Policy2ID.Valid {
policy2 = &sqlc.GraphChannelPolicy{
ID: r.Policy2ID.Int64,
Version: r.Policy2Version.Int16,
ChannelID: r.GraphChannel.ID,
NodeID: r.Policy2NodeID.Int64,
Timelock: r.Policy2Timelock.Int32,
FeePpm: r.Policy2FeePpm.Int64,
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
LastUpdate: r.Policy2LastUpdate,
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
Disabled: r.Policy2Disabled,
MessageFlags: r.Policy2MessageFlags,
ChannelFlags: r.Policy2ChannelFlags,
Signature: r.Policy2Signature,
}
}
return policy1, policy2, nil
case sqlc.ListChannelsByNodeIDRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.GraphChannelPolicy{
@@ -5118,3 +5236,83 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
collectFunc, batchDataFunc, processItem,
)
}
// buildDirectedChannel builds a DirectedChannel instance from the provided
// data.
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
channelBatchData *batchChannelData, features *lnwire.FeatureVector,
toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
node1, node2, err := buildNodeVertices(
channelRow.Node1Pubkey, channelRow.Node2Pubkey,
)
if err != nil {
return nil, fmt.Errorf("unable to build node vertices: %w", err)
}
edge, err := buildEdgeInfoWithBatchData(
chain, channelRow.GraphChannel, node1, node2, channelBatchData,
)
if err != nil {
return nil, fmt.Errorf("unable to build channel info: %w", err)
}
dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
if err != nil {
return nil, fmt.Errorf("unable to extract channel policies: %w",
err)
}
p1, p2, err := buildChanPoliciesWithBatchData(
dbPol1, dbPol2, edge.ChannelID, node1, node2,
channelBatchData,
)
if err != nil {
return nil, fmt.Errorf("unable to build channel policies: %w",
err)
}
// Determine outgoing and incoming policy for this specific node.
p1ToNode := channelRow.GraphChannel.NodeID2
p2ToNode := channelRow.GraphChannel.NodeID1
outPolicy, inPolicy := p1, p2
if (p1 != nil && p1ToNode == nodeID) ||
(p2 != nil && p2ToNode != nodeID) {
outPolicy, inPolicy = p2, p1
}
// Build cached policy.
var cachedInPolicy *models.CachedEdgePolicy
if inPolicy != nil {
cachedInPolicy = models.NewCachedPolicy(inPolicy)
cachedInPolicy.ToNodePubKey = toNodeCallback
cachedInPolicy.ToNodeFeatures = features
}
// Extract inbound fee.
var inboundFee lnwire.Fee
if outPolicy != nil {
outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
inboundFee = fee
})
}
// Build directed channel.
directedChannel := &DirectedChannel{
ChannelID: edge.ChannelID,
IsNode1: nodePub == edge.NodeKey1Bytes,
OtherNode: edge.NodeKey2Bytes,
Capacity: edge.Capacity,
OutPolicySet: outPolicy != nil,
InPolicy: cachedInPolicy,
InboundFee: inboundFee,
}
if nodePub == edge.NodeKey2Bytes {
directedChannel.OtherNode = edge.NodeKey1Bytes
}
return directedChannel, nil
}