graph/db: update ForEachChannel to use new sqldb helper

Refactor to let ForEachChannel make use of the new
sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper.
This commit is contained in:
Elle Mouton
2025-07-31 12:54:13 +02:00
parent 1219cdb7f1
commit ae13158b68

View File

@@ -1297,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context,
cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error, reset func()) error {
handleChannel := func(db SQLQueries, batchData *batchChannelData,
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return fmt.Errorf("unable to build node vertices: %w",
err)
}
edge, err := buildEdgeInfoWithBatchData(
s.cfg.ChainHash, row.GraphChannel, node1, node2,
batchData,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
p1, p2, err := buildChanPoliciesWithBatchData(
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
err = cb(edge, p1, p2)
if err != nil {
return fmt.Errorf("callback failed for channel "+
"id=%d: %w", edge.ChannelID, err)
}
return nil
}
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
lastID := int64(-1)
for {
//nolint:ll
rows, err := db.ListChannelsWithPoliciesPaginated(
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: s.cfg.QueryCfg.MaxPageSize,
},
)
if err != nil {
return err
}
if len(rows) == 0 {
break
}
// Collect the channel & policy IDs that we want to
// do a batch collection for.
var (
channelIDs = make([]int64, len(rows))
policyIDs = make([]int64, 0, len(rows)*2)
)
for i, row := range rows {
channelIDs[i] = row.GraphChannel.ID
// Extract policy IDs from the row
dbPol1, dbPol2, err := extractChannelPolicies(
row,
)
if err != nil {
return fmt.Errorf("unable to extract "+
"channel policies: %w", err)
}
if dbPol1 != nil {
policyIDs = append(policyIDs, dbPol1.ID)
}
if dbPol2 != nil {
policyIDs = append(policyIDs, dbPol2.ID)
}
}
batchData, err := batchLoadChannelData(
ctx, s.cfg.QueryCfg, db, channelIDs,
policyIDs,
)
if err != nil {
return fmt.Errorf("unable to batch load "+
"channel data: %w", err)
}
for _, row := range rows {
err := handleChannel(db, batchData, row)
if err != nil {
return err
}
lastID = row.GraphChannel.ID
}
}
return nil
return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
}, reset)
}
@@ -5103,3 +4996,117 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
collectFunc, batchQueryFunc, processItem,
)
}
// forEachChannelWithPolicies executes a paginated query to process each channel
// with policies in the graph.
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy,
*models.ChannelEdgePolicy) error) error {
type channelBatchIDs struct {
channelID int64
policyIDs []int64
}
pageQueryFunc := func(ctx context.Context, lastID int64,
limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
error) {
return db.ListChannelsWithPoliciesPaginated(
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: limit,
},
)
}
extractPageCursor := func(
row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
return row.GraphChannel.ID
}
collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
channelBatchIDs, error) {
ids := channelBatchIDs{
channelID: row.GraphChannel.ID,
}
// Extract policy IDs from the row.
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return ids, err
}
if dbPol1 != nil {
ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
}
if dbPol2 != nil {
ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
}
return ids, nil
}
batchDataFunc := func(ctx context.Context,
allIDs []channelBatchIDs) (*batchChannelData, error) {
// Separate channel IDs from policy IDs.
var (
channelIDs = make([]int64, len(allIDs))
policyIDs = make([]int64, 0, len(allIDs)*2)
)
for i, ids := range allIDs {
channelIDs[i] = ids.channelID
policyIDs = append(policyIDs, ids.policyIDs...)
}
return batchLoadChannelData(
ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
)
}
processItem := func(ctx context.Context,
row sqlc.ListChannelsWithPoliciesPaginatedRow,
batchData *batchChannelData) error {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}
edge, err := buildEdgeInfoWithBatchData(
cfg.ChainHash, row.GraphChannel, node1, node2,
batchData,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return err
}
p1, p2, err := buildChanPoliciesWithBatchData(
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
)
if err != nil {
return err
}
return processChannel(edge, p1, p2)
}
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
collectFunc, batchDataFunc, processItem,
)
}