graph/db: batch fetch channel data in forEachNodeChannel

Update the forEachNodeChannel helper to batch fetch channel data.
This commit is contained in:
Elle Mouton
2025-08-01 15:30:33 +02:00
parent dc6f9256bc
commit 0850bf4781

View File

@@ -757,7 +757,7 @@ func (s *SQLStore) ForEachSourceNodeChannel(ctx context.Context,
}
return forEachNodeChannel(
ctx, db, s.cfg.ChainHash, nodeID,
ctx, db, s.cfg, nodeID,
func(info *models.ChannelEdgeInfo,
outPolicy *models.ChannelEdgePolicy,
_ *models.ChannelEdgePolicy) error {
@@ -815,7 +815,7 @@ func (s *SQLStore) ForEachNode(ctx context.Context,
node *models.LightningNode) error {
return cb(newSQLGraphNodeTx(
db, s.cfg.ChainHash, dbNodeID, node,
db, s.cfg, dbNodeID, node,
))
},
)
@@ -825,24 +825,24 @@ func (s *SQLStore) ForEachNode(ctx context.Context,
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
type sqlGraphNodeTx struct {
db SQLQueries
id int64
node *models.LightningNode
chain chainhash.Hash
db SQLQueries
id int64
node *models.LightningNode
cfg *SQLStoreConfig
}
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
func newSQLGraphNodeTx(db SQLQueries, cfg *SQLStoreConfig,
id int64, node *models.LightningNode) *sqlGraphNodeTx {
return &sqlGraphNodeTx{
db: db,
chain: chain,
id: id,
node: node,
db: db,
cfg: cfg,
id: id,
node: node,
}
}
@@ -862,7 +862,7 @@ func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
ctx := context.TODO()
return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
return forEachNodeChannel(ctx, s.db, s.cfg, s.id, cb)
}
// FetchNode fetches the node with the given pub key under the same transaction
@@ -879,7 +879,7 @@ func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
nodePub, err)
}
return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
return newSQLGraphNodeTx(s.db, s.cfg, id, node), nil
}
// ForEachNodeDirectedChannel iterates through all channels of a given node,
@@ -952,9 +952,7 @@ func (s *SQLStore) ForEachNodeChannel(ctx context.Context, nodePub route.Vertex,
return fmt.Errorf("unable to fetch node: %w", err)
}
return forEachNodeChannel(
ctx, db, s.cfg.ChainHash, dbNode.ID, cb,
)
return forEachNodeChannel(ctx, db, s.cfg, dbNode.ID, cb)
}, reset)
}
@@ -3073,11 +3071,11 @@ func forEachNodeCacheable(ctx context.Context, cfg *sqldb.QueryConfig,
// edge information, the outgoing policy and the incoming policy for the
// channel and node combo.
func forEachNodeChannel(ctx context.Context, db SQLQueries,
chain chainhash.Hash, id int64, cb func(*models.ChannelEdgeInfo,
cfg *SQLStoreConfig, id int64, cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error) error {
// Get all the V1 channels for this node.Add commentMore actions
// Get all the V1 channels for this node.
rows, err := db.ListChannelsByNodeID(
ctx, sqlc.ListChannelsByNodeIDParams{
Version: int16(ProtocolV1),
@@ -3088,6 +3086,29 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries,
return fmt.Errorf("unable to fetch channels: %w", err)
}
// Collect all the channel and policy IDs.
var (
chanIDs = make([]int64, 0, len(rows))
policyIDs = make([]int64, 0, 2*len(rows))
)
for _, row := range rows {
chanIDs = append(chanIDs, row.GraphChannel.ID)
if row.Policy1ID.Valid {
policyIDs = append(policyIDs, row.Policy1ID.Int64)
}
if row.Policy2ID.Valid {
policyIDs = append(policyIDs, row.Policy2ID.Int64)
}
}
batchData, err := batchLoadChannelData(
ctx, cfg.QueryCfg, db, chanIDs, policyIDs,
)
if err != nil {
return fmt.Errorf("unable to batch load channel data: %w", err)
}
// Call the call-back for each channel and its known policies.
for _, row := range rows {
node1, node2, err := buildNodeVertices(
@@ -3098,8 +3119,9 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries,
err)
}
edge, err := getAndBuildEdgeInfo(
ctx, db, chain, row.GraphChannel, node1, node2,
edge, err := buildEdgeInfoWithBatchData(
cfg.ChainHash, row.GraphChannel, node1, node2,
batchData,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
@@ -3112,8 +3134,8 @@ func forEachNodeChannel(ctx context.Context, db SQLQueries,
"policies: %w", err)
}
p1, p2, err := getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
p1, p2, err := buildChanPoliciesWithBatchData(
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
)
if err != nil {
return fmt.Errorf("unable to build channel "+