diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 95b4ad2ad..d6d7c54ad 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -911,6 +911,7 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime, edges []ChannelEdge hits int ) + err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { rows, err := db.GetChannelsByPolicyLastUpdateRange( ctx, sqlc.GetChannelsByPolicyLastUpdateRangeParams{ @@ -923,72 +924,61 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime, return err } + if len(rows) == 0 { + return nil + } + + // We'll pre-allocate the slices and maps here with a best + // effort size in order to avoid unnecessary allocations later + // on. + uncachedRows := make( + []sqlc.GetChannelsByPolicyLastUpdateRangeRow, 0, + len(rows), + ) + edgesToCache = make(map[uint64]ChannelEdge, len(rows)) + edgesSeen = make(map[uint64]struct{}, len(rows)) + edges = make([]ChannelEdge, 0, len(rows)) + + // Separate cached from non-cached channels since we will only + // batch load the data for the ones we haven't cached yet. for _, row := range rows { - // If we've already retrieved the info and policies for - // this edge, then we can skip it as we don't need to do - // so again. chanIDInt := byteOrder.Uint64(row.GraphChannel.Scid) + + // Skip duplicates. if _, ok := edgesSeen[chanIDInt]; ok { continue } + edgesSeen[chanIDInt] = struct{}{} + // Check cache first. if channel, ok := s.chanCache.get(chanIDInt); ok { hits++ - edgesSeen[chanIDInt] = struct{}{} edges = append(edges, channel) - continue } - node1, node2, err := buildNodes( - ctx, db, row.GraphNode, row.GraphNode_2, - ) - if err != nil { - return err - } - - channel, err := getAndBuildEdgeInfo( - ctx, db, s.cfg.ChainHash, row.GraphChannel, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "info: %w", err) - } - - dbPol1, dbPol2, err := extractChannelPolicies(row) - if err != nil { - return fmt.Errorf("unable to extract channel "+ - "policies: %w", err) - } - - p1, p2, err := getAndBuildChanPolicies( - ctx, db, dbPol1, dbPol2, channel.ChannelID, - node1.PubKeyBytes, node2.PubKeyBytes, - ) - if err != nil { - return fmt.Errorf("unable to build channel "+ - "policies: %w", err) - } - - edgesSeen[chanIDInt] = struct{}{} - chanEdge := ChannelEdge{ - Info: channel, - Policy1: p1, - Policy2: p2, - Node1: node1, - Node2: node2, - } - edges = append(edges, chanEdge) - edgesToCache[chanIDInt] = chanEdge + // Mark this row as one we need to batch load data for. + uncachedRows = append(uncachedRows, row) } + // If there are no uncached rows, then we can return early. + if len(uncachedRows) == 0 { + return nil + } + + // Batch load data for all uncached channels. + newEdges, err := batchBuildChannelEdges( + ctx, s.cfg, db, uncachedRows, + ) + if err != nil { + return fmt.Errorf("unable to batch build channel "+ + "edges: %w", err) + } + + edges = append(edges, newEdges...) + return nil - }, func() { - edgesSeen = make(map[uint64]struct{}) - edgesToCache = make(map[uint64]ChannelEdge) - edges = nil - }) + }, sqldb.NoOpReset) if err != nil { return nil, fmt.Errorf("unable to fetch channels: %w", err) } @@ -5298,3 +5288,121 @@ func buildDirectedChannel(chain chainhash.Hash, nodeID int64, return directedChannel, nil } + +// batchBuildChannelEdges builds a slice of ChannelEdge instances from the +// provided rows. It uses batch loading for channels, policies, and nodes. +func batchBuildChannelEdges[T sqlc.ChannelAndNodes](ctx context.Context, + cfg *SQLStoreConfig, db SQLQueries, rows []T) ([]ChannelEdge, error) { + + var ( + channelIDs = make([]int64, len(rows)) + policyIDs = make([]int64, 0, len(rows)*2) + nodeIDs = make([]int64, 0, len(rows)*2) + + // nodeIDSet is used to ensure we only collect unique node IDs. + nodeIDSet = make(map[int64]bool) + + // edges will hold the final channel edges built from the rows. + edges = make([]ChannelEdge, 0, len(rows)) + ) + + // Collect all IDs needed for batch loading. + for i, row := range rows { + channelIDs[i] = row.Channel().ID + + // Collect policy IDs + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return nil, fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + if dbPol1 != nil { + policyIDs = append(policyIDs, dbPol1.ID) + } + if dbPol2 != nil { + policyIDs = append(policyIDs, dbPol2.ID) + } + + var ( + node1ID = row.Node1().ID + node2ID = row.Node2().ID + ) + + // Collect unique node IDs. + if !nodeIDSet[node1ID] { + nodeIDs = append(nodeIDs, node1ID) + nodeIDSet[node1ID] = true + } + + if !nodeIDSet[node2ID] { + nodeIDs = append(nodeIDs, node2ID) + nodeIDSet[node2ID] = true + } + } + + // Batch the data for all the channels and policies. + channelBatchData, err := batchLoadChannelData( + ctx, cfg.QueryCfg, db, channelIDs, policyIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load channel and "+ + "policy data: %w", err) + } + + // Batch the data for all the nodes. + nodeBatchData, err := batchLoadNodeData(ctx, cfg.QueryCfg, db, nodeIDs) + if err != nil { + return nil, fmt.Errorf("unable to batch load node data: %w", + err) + } + + // Build all channel edges using batch data. + for _, row := range rows { + // Build nodes using batch data. + node1, err := buildNodeWithBatchData(row.Node1(), nodeBatchData) + if err != nil { + return nil, fmt.Errorf("unable to build node1: %w", err) + } + + node2, err := buildNodeWithBatchData(row.Node2(), nodeBatchData) + if err != nil { + return nil, fmt.Errorf("unable to build node2: %w", err) + } + + // Build channel info using batch data. + channel, err := buildEdgeInfoWithBatchData( + cfg.ChainHash, row.Channel(), node1.PubKeyBytes, + node2.PubKeyBytes, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel "+ + "info: %w", err) + } + + // Extract and build policies using batch data. + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return nil, fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + p1, p2, err := buildChanPoliciesWithBatchData( + dbPol1, dbPol2, channel.ChannelID, + node1.PubKeyBytes, node2.PubKeyBytes, channelBatchData, + ) + if err != nil { + return nil, fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + edges = append(edges, ChannelEdge{ + Info: channel, + Policy1: p1, + Policy2: p2, + Node1: node1, + Node2: node2, + }) + } + + return edges, nil +} diff --git a/sqldb/sqlc/db_custom.go b/sqldb/sqlc/db_custom.go index 2490e5feb..64440a262 100644 --- a/sqldb/sqlc/db_custom.go +++ b/sqldb/sqlc/db_custom.go @@ -37,3 +37,37 @@ func makeQueryParams(numTotalArgs, numListArgs int) string { return b.String() } + +// ChannelAndNodes is an interface that provides access to a channel and its +// two nodes. +type ChannelAndNodes interface { + // Channel returns the GraphChannel associated with this interface. + Channel() GraphChannel + + // Node1 returns the first GraphNode associated with this channel. + Node1() GraphNode + + // Node2 returns the second GraphNode associated with this channel. + Node2() GraphNode +} + +// Channel returns the GraphChannel associated with this interface. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Channel() GraphChannel { + return r.GraphChannel +} + +// Node1 returns the first GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Node1() GraphNode { + return r.GraphNode +} + +// Node2 returns the second GraphNode associated with this channel. +// +// NOTE: This method is part of the ChannelAndNodes interface. +func (r GetChannelsByPolicyLastUpdateRangeRow) Node2() GraphNode { + return r.GraphNode_2 +}