mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-31 17:51:33 +02:00
graph/db: update ForEachChannel to use new sqldb helper
Refactor to let ForEachChannel make use of the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper.
This commit is contained in:
@@ -1297,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context,
|
|||||||
cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
|
cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy,
|
||||||
*models.ChannelEdgePolicy) error, reset func()) error {
|
*models.ChannelEdgePolicy) error, reset func()) error {
|
||||||
|
|
||||||
handleChannel := func(db SQLQueries, batchData *batchChannelData,
|
|
||||||
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
|
|
||||||
|
|
||||||
node1, node2, err := buildNodeVertices(
|
|
||||||
row.Node1Pubkey, row.Node2Pubkey,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to build node vertices: %w",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
|
|
||||||
edge, err := buildEdgeInfoWithBatchData(
|
|
||||||
s.cfg.ChainHash, row.GraphChannel, node1, node2,
|
|
||||||
batchData,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to build channel info: %w",
|
|
||||||
err)
|
|
||||||
}
|
|
||||||
|
|
||||||
dbPol1, dbPol2, err := extractChannelPolicies(row)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to extract channel "+
|
|
||||||
"policies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
p1, p2, err := buildChanPoliciesWithBatchData(
|
|
||||||
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to build channel "+
|
|
||||||
"policies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = cb(edge, p1, p2)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("callback failed for channel "+
|
|
||||||
"id=%d: %w", edge.ChannelID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||||
lastID := int64(-1)
|
return forEachChannelWithPolicies(ctx, db, s.cfg, cb)
|
||||||
for {
|
|
||||||
//nolint:ll
|
|
||||||
rows, err := db.ListChannelsWithPoliciesPaginated(
|
|
||||||
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
|
|
||||||
Version: int16(ProtocolV1),
|
|
||||||
ID: lastID,
|
|
||||||
Limit: s.cfg.QueryCfg.MaxPageSize,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(rows) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// Collect the channel & policy IDs that we want to
|
|
||||||
// do a batch collection for.
|
|
||||||
var (
|
|
||||||
channelIDs = make([]int64, len(rows))
|
|
||||||
policyIDs = make([]int64, 0, len(rows)*2)
|
|
||||||
)
|
|
||||||
for i, row := range rows {
|
|
||||||
channelIDs[i] = row.GraphChannel.ID
|
|
||||||
|
|
||||||
// Extract policy IDs from the row
|
|
||||||
dbPol1, dbPol2, err := extractChannelPolicies(
|
|
||||||
row,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to extract "+
|
|
||||||
"channel policies: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbPol1 != nil {
|
|
||||||
policyIDs = append(policyIDs, dbPol1.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dbPol2 != nil {
|
|
||||||
policyIDs = append(policyIDs, dbPol2.ID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
batchData, err := batchLoadChannelData(
|
|
||||||
ctx, s.cfg.QueryCfg, db, channelIDs,
|
|
||||||
policyIDs,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to batch load "+
|
|
||||||
"channel data: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, row := range rows {
|
|
||||||
err := handleChannel(db, batchData, row)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
lastID = row.GraphChannel.ID
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}, reset)
|
}, reset)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -5103,3 +4996,117 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig,
|
|||||||
collectFunc, batchQueryFunc, processItem,
|
collectFunc, batchQueryFunc, processItem,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// forEachChannelWithPolicies executes a paginated query to process each channel
|
||||||
|
// with policies in the graph.
|
||||||
|
func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
|
||||||
|
cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo,
|
||||||
|
*models.ChannelEdgePolicy,
|
||||||
|
*models.ChannelEdgePolicy) error) error {
|
||||||
|
|
||||||
|
type channelBatchIDs struct {
|
||||||
|
channelID int64
|
||||||
|
policyIDs []int64
|
||||||
|
}
|
||||||
|
|
||||||
|
pageQueryFunc := func(ctx context.Context, lastID int64,
|
||||||
|
limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow,
|
||||||
|
error) {
|
||||||
|
|
||||||
|
return db.ListChannelsWithPoliciesPaginated(
|
||||||
|
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
|
||||||
|
Version: int16(ProtocolV1),
|
||||||
|
ID: lastID,
|
||||||
|
Limit: limit,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
extractPageCursor := func(
|
||||||
|
row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 {
|
||||||
|
|
||||||
|
return row.GraphChannel.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) (
|
||||||
|
channelBatchIDs, error) {
|
||||||
|
|
||||||
|
ids := channelBatchIDs{
|
||||||
|
channelID: row.GraphChannel.ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract policy IDs from the row.
|
||||||
|
dbPol1, dbPol2, err := extractChannelPolicies(row)
|
||||||
|
if err != nil {
|
||||||
|
return ids, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if dbPol1 != nil {
|
||||||
|
ids.policyIDs = append(ids.policyIDs, dbPol1.ID)
|
||||||
|
}
|
||||||
|
if dbPol2 != nil {
|
||||||
|
ids.policyIDs = append(ids.policyIDs, dbPol2.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ids, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
batchDataFunc := func(ctx context.Context,
|
||||||
|
allIDs []channelBatchIDs) (*batchChannelData, error) {
|
||||||
|
|
||||||
|
// Separate channel IDs from policy IDs.
|
||||||
|
var (
|
||||||
|
channelIDs = make([]int64, len(allIDs))
|
||||||
|
policyIDs = make([]int64, 0, len(allIDs)*2)
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, ids := range allIDs {
|
||||||
|
channelIDs[i] = ids.channelID
|
||||||
|
policyIDs = append(policyIDs, ids.policyIDs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return batchLoadChannelData(
|
||||||
|
ctx, cfg.QueryCfg, db, channelIDs, policyIDs,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
processItem := func(ctx context.Context,
|
||||||
|
row sqlc.ListChannelsWithPoliciesPaginatedRow,
|
||||||
|
batchData *batchChannelData) error {
|
||||||
|
|
||||||
|
node1, node2, err := buildNodeVertices(
|
||||||
|
row.Node1Pubkey, row.Node2Pubkey,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
edge, err := buildEdgeInfoWithBatchData(
|
||||||
|
cfg.ChainHash, row.GraphChannel, node1, node2,
|
||||||
|
batchData,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to build channel info: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dbPol1, dbPol2, err := extractChannelPolicies(row)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
p1, p2, err := buildChanPoliciesWithBatchData(
|
||||||
|
dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return processChannel(edge, p1, p2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
|
||||||
|
ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor,
|
||||||
|
collectFunc, batchDataFunc, processItem,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
Reference in New Issue
Block a user