Mirror of https://github.com/lightningnetwork/lnd.git, synced 2025-08-25 13:12:11 +02:00
graph/db: update ForEachChannel to use new sqldb helper
Refactor to let ForEachChannel make use of the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper.
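Note: the helper itself is not part of this diff. As a rough guide to what the code below plugs into, here is a minimal sketch of how a generic collect-and-batch pagination helper of this shape could look. Every name, signature, and the QueryConfig type in this sketch are assumptions for illustration, not the actual lnd/sqldb API.

// Minimal sketch only: assumes a generic helper of roughly this shape
// exists in sqldb. Names, type parameters, and QueryConfig are illustrative
// assumptions, not the real API.
package sqldb

import "context"

// QueryConfig is assumed to carry the page size used when paging.
type QueryConfig struct {
	MaxPageSize int32
}

// ExecuteCollectAndBatchWithSharedDataQuery pages through rows starting at
// startCursor, collects one ID value per row, batch-loads shared data for
// the whole page, and then processes every row of the page with that shared
// data.
func ExecuteCollectAndBatchWithSharedDataQuery[Row, Cursor, ID, Batch any](
	ctx context.Context, cfg *QueryConfig, startCursor Cursor,
	pageQuery func(ctx context.Context, cursor Cursor,
		limit int32) ([]Row, error),
	extractCursor func(Row) Cursor,
	collect func(Row) (ID, error),
	batchLoad func(ctx context.Context, ids []ID) (Batch, error),
	process func(ctx context.Context, row Row, batch Batch) error) error {

	cursor := startCursor
	for {
		// Fetch the next page of rows after the current cursor.
		rows, err := pageQuery(ctx, cursor, cfg.MaxPageSize)
		if err != nil {
			return err
		}
		if len(rows) == 0 {
			return nil
		}

		// Collect the per-row IDs needed for the shared batch load.
		ids := make([]ID, 0, len(rows))
		for _, row := range rows {
			id, err := collect(row)
			if err != nil {
				return err
			}
			ids = append(ids, id)
		}

		// Load the shared data for the whole page in one go.
		batch, err := batchLoad(ctx, ids)
		if err != nil {
			return err
		}

		// Process each row with the shared data and advance the
		// cursor past the last row seen.
		for _, row := range rows {
			if err := process(ctx, row, batch); err != nil {
				return err
			}
			cursor = extractCursor(row)
		}
	}
}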
@@ -1297,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context,
 	cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
 		*models.ChannelEdgePolicy) error, reset func()) error {
 
-	handleChannel := func(db SQLQueries, batchData *batchChannelData,
-		row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
-
-		node1, node2, err := buildNodeVertices(
-			row.Node1Pubkey, row.Node2Pubkey,
-		)
-		if err != nil {
-			return fmt.Errorf("unable to build node vertices: %w",
-				err)
-		}
-
-		edge, err := buildEdgeInfoWithBatchData(
-			s.cfg.ChainHash, row.GraphChannel, node1, node2,
-			batchData,
-		)
-		if err != nil {
-			return fmt.Errorf("unable to build channel info: %w",
-				err)
-		}
-
-		dbPol1, dbPol2, err := extractChannelPolicies(row)
-		if err != nil {
-			return fmt.Errorf("unable to extract channel "+
-				"policies: %w", err)
-		}
-
-		p1, p2, err := buildChanPoliciesWithBatchData(
-			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
-		)
-		if err != nil {
-			return fmt.Errorf("unable to build channel "+
-				"policies: %w", err)
-		}
-
-		err = cb(edge, p1, p2)
-		if err != nil {
-			return fmt.Errorf("callback failed for channel "+
-				"id=%d: %w", edge.ChannelID, err)
-		}
-
-		return nil
-	}
-
 	return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
-		lastID := int64(-1)
-		for {
-			//nolint:ll
-			rows, err := db.ListChannelsWithPoliciesPaginated(
-				ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
-					Version: int16(ProtocolV1),
-					ID:      lastID,
-					Limit:   s.cfg.QueryCfg.MaxPageSize,
-				},
-			)
-			if err != nil {
-				return err
-			}
-
-			if len(rows) == 0 {
-				break
-			}
-
-			// Collect the channel & policy IDs that we want to
-			// do a batch collection for.
-			var (
-				channelIDs = make([]int64, len(rows))
-				policyIDs  = make([]int64, 0, len(rows)*2)
-			)
-			for i, row := range rows {
-				channelIDs[i] = row.GraphChannel.ID
-
-				// Extract policy IDs from the row
-				dbPol1, dbPol2, err := extractChannelPolicies(
-					row,
-				)
-				if err != nil {
-					return fmt.Errorf("unable to extract "+
-						"channel policies: %w", err)
-				}
-
-				if dbPol1 != nil {
-					policyIDs = append(policyIDs, dbPol1.ID)
-				}
-
-				if dbPol2 != nil {
-					policyIDs = append(policyIDs, dbPol2.ID)
-				}
-			}
-
-			batchData, err := batchLoadChannelData(
-				ctx, s.cfg.QueryCfg, db, channelIDs,
-				policyIDs,
-			)
-			if err != nil {
-				return fmt.Errorf("unable to batch load "+
-					"channel data: %w", err)
-			}
-
-			for _, row := range rows {
-				err := handleChannel(db, batchData, row)
-				if err != nil {
-					return err
-				}
-
-				lastID = row.GraphChannel.ID
-			}
-		}
-
-		return nil
+		return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
 	}, reset)
 }
 
@@ -5103,3 +4996,117 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
 		collectFunc, batchQueryFunc, processItem,
 	)
 }
+
+// forEachChannelWithPolicies executes a paginated query to process each channel
+// with policies in the graph.
+func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
+	cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
+		*models.ChannelEdgePolicy,
+		*models.ChannelEdgePolicy) error) error {
+
+	type channelBatchIDs struct {
+		channelID int64
+		policyIDs []int64
+	}
+
+	pageQueryFunc := func(ctx context.Context, lastID int64,
+		limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
+		error) {
+
+		return db.ListChannelsWithPoliciesPaginated(
+			ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
+				Version: int16(ProtocolV1),
+				ID:      lastID,
+				Limit:   limit,
+			},
+		)
+	}
+
+	extractPageCursor := func(
+		row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
+
+		return row.GraphChannel.ID
+	}
+
+	collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
+		channelBatchIDs, error) {
+
+		ids := channelBatchIDs{
+			channelID: row.GraphChannel.ID,
+		}
+
+		// Extract policy IDs from the row.
+		dbPol1, dbPol2, err := extractChannelPolicies(row)
+		if err != nil {
+			return ids, err
+		}
+
+		if dbPol1 != nil {
+			ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
+		}
+		if dbPol2 != nil {
+			ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
+		}
+
+		return ids, nil
+	}
+
+	batchDataFunc := func(ctx context.Context,
+		allIDs []channelBatchIDs) (*batchChannelData, error) {
+
+		// Separate channel IDs from policy IDs.
+		var (
+			channelIDs = make([]int64, len(allIDs))
+			policyIDs  = make([]int64, 0, len(allIDs)*2)
+		)
+
+		for i, ids := range allIDs {
+			channelIDs[i] = ids.channelID
+			policyIDs = append(policyIDs, ids.policyIDs...)
+		}
+
+		return batchLoadChannelData(
+			ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
+		)
+	}
+
+	processItem := func(ctx context.Context,
+		row sqlc.ListChannelsWithPoliciesPaginatedRow,
+		batchData *batchChannelData) error {
+
+		node1, node2, err := buildNodeVertices(
+			row.Node1Pubkey, row.Node2Pubkey,
+		)
+		if err != nil {
+			return err
+		}
+
+		edge, err := buildEdgeInfoWithBatchData(
+			cfg.ChainHash, row.GraphChannel, node1, node2,
+			batchData,
+		)
+		if err != nil {
+			return fmt.Errorf("unable to build channel info: %w",
+				err)
+		}
+
+		dbPol1, dbPol2, err := extractChannelPolicies(row)
+		if err != nil {
+			return err
+		}
+
+		p1, p2, err := buildChanPoliciesWithBatchData(
+			dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
+		)
+		if err != nil {
+			return err
+		}
+
+		return processChannel(edge, p1, p2)
+	}
+
+	return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
+		ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
+		collectFunc, batchDataFunc, processItem,
+	)
+}
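For orientation, a hedged usage sketch of the unchanged ForEachChannel entry point, assuming it lives alongside SQLStore in the graph/db package. The countChannels wrapper is hypothetical; only the callback and reset signatures come from the code above, and the retry semantics of reset are an assumption.

// countChannels is a hypothetical example, not part of this change: it
// counts every channel in the graph by iterating with ForEachChannel.
func countChannels(ctx context.Context, store *SQLStore) (int, error) {
	var count int

	err := store.ForEachChannel(
		ctx,
		func(_ *models.ChannelEdgeInfo, _,
			_ *models.ChannelEdgePolicy) error {

			// Invoked once per channel edge together with its
			// two directed policies.
			count++

			return nil
		},
		func() {
			// reset: assumed to be called when the underlying
			// read transaction is retried, so partial results
			// are discarded.
			count = 0
		},
	)

	return count, err
}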