diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 886d1ddcc..02a3f6372 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1297,115 +1297,8 @@ func (s *SQLStore) ForEachChannel(ctx context.Context, cb func(*models.ChannelEdgeInfo, *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error, reset func()) error { - handleChannel := func(db SQLQueries, batchData *batchChannelData, - row sqlc.ListChannelsWithPoliciesPaginatedRow) error { - - node1, node2, err := buildNodeVertices( - row.Node1Pubkey, row.Node2Pubkey, - ) - if err != nil { - return fmt.Errorf("unable to build node vertices: %w", - err) - } - - edge, err := buildEdgeInfoWithBatchData( - s.cfg.ChainHash, row.GraphChannel, node1, node2, - batchData, - ) - if err != nil { - return fmt.Errorf("unable to build channel info: %w", - err) - } - - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) - } - - p1, p2, err := buildChanPoliciesWithBatchData( - dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) - } - - err = cb(edge, p1, p2) - if err != nil { - return fmt.Errorf("callback failed for channel "+ - "id=%d: %w", edge.ChannelID, err) - } - - return nil - } - return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - lastID := int64(-1) - for { - //nolint:ll - rows, err := db.ListChannelsWithPoliciesPaginated( - ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{ - Version: int16(ProtocolV1), - ID: lastID, - Limit: s.cfg.QueryCfg.MaxPageSize, - }, - ) - if err != nil { - return err - } - - if len(rows) == 0 { - break - } - - // Collect the channel & policy IDs that we want to - // do a batch collection for. - var ( - channelIDs = make([]int64, len(rows)) - policyIDs = make([]int64, 0, len(rows)*2) - ) - for i, row := range rows { - channelIDs[i] = row.GraphChannel.ID - - // Extract policy IDs from the row - dbPol1, dbPol2, err := extractChannelPolicies( - row, - ) - if err != nil { - return fmt.Errorf("unable to extract "+ - "channel policies: %w", err) - } - - if dbPol1 != nil { - policyIDs = append(policyIDs, dbPol1.ID) - } - - if dbPol2 != nil { - policyIDs = append(policyIDs, dbPol2.ID) - } - } - - batchData, err := batchLoadChannelData( - ctx, s.cfg.QueryCfg, db, channelIDs, - policyIDs, - ) - if err != nil { - return fmt.Errorf("unable to batch load "+ - "channel data: %w", err) - } - - for _, row := range rows { - err := handleChannel(db, batchData, row) - if err != nil { - return err - } - - lastID = row.GraphChannel.ID - } - } - - return nil + return forEachChannelWithPolicies(ctx, db, s.cfg, cb) }, reset) } @@ -5103,3 +4996,117 @@ func forEachNodePaginated(ctx context.Context, cfg *sqldb.QueryConfig, collectFunc, batchQueryFunc, processItem, ) } + +// forEachChannelWithPolicies executes a paginated query to process each channel +// with policies in the graph. +func forEachChannelWithPolicies(ctx context.Context, db SQLQueries, + cfg *SQLStoreConfig, processChannel func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, + *models.ChannelEdgePolicy) error) error { + + type channelBatchIDs struct { + channelID int64 + policyIDs []int64 + } + + pageQueryFunc := func(ctx context.Context, lastID int64, + limit int32) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, + error) { + + return db.ListChannelsWithPoliciesPaginated( + ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{ + Version: int16(ProtocolV1), + ID: lastID, + Limit: limit, + }, + ) + } + + extractPageCursor := func( + row sqlc.ListChannelsWithPoliciesPaginatedRow) int64 { + + return row.GraphChannel.ID + } + + collectFunc := func(row sqlc.ListChannelsWithPoliciesPaginatedRow) ( + channelBatchIDs, error) { + + ids := channelBatchIDs{ + channelID: row.GraphChannel.ID, + } + + // Extract policy IDs from the row. + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return ids, err + } + + if dbPol1 != nil { + ids.policyIDs = append(ids.policyIDs, dbPol1.ID) + } + if dbPol2 != nil { + ids.policyIDs = append(ids.policyIDs, dbPol2.ID) + } + + return ids, nil + } + + batchDataFunc := func(ctx context.Context, + allIDs []channelBatchIDs) (*batchChannelData, error) { + + // Separate channel IDs from policy IDs. + var ( + channelIDs = make([]int64, len(allIDs)) + policyIDs = make([]int64, 0, len(allIDs)*2) + ) + + for i, ids := range allIDs { + channelIDs[i] = ids.channelID + policyIDs = append(policyIDs, ids.policyIDs...) + } + + return batchLoadChannelData( + ctx, cfg.QueryCfg, db, channelIDs, policyIDs, + ) + } + + processItem := func(ctx context.Context, + row sqlc.ListChannelsWithPoliciesPaginatedRow, + batchData *batchChannelData) error { + + node1, node2, err := buildNodeVertices( + row.Node1Pubkey, row.Node2Pubkey, + ) + if err != nil { + return err + } + + edge, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.GraphChannel, node1, node2, + batchData, + ) + if err != nil { + return fmt.Errorf("unable to build channel info: %w", + err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return err + } + + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, edge.ChannelID, node1, node2, batchData, + ) + if err != nil { + return err + } + + return processChannel(edge, p1, p2) + } + + return sqldb.ExecuteCollectAndBatchWithSharedDataQuery( + ctx, cfg.QueryCfg, int64(-1), pageQueryFunc, extractPageCursor, + collectFunc, batchDataFunc, processItem, + ) +}