From e269d57ffa0f55198ab5772425b1d98a5c41f918 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 15 Jul 2025 16:17:53 +0200 Subject: [PATCH] sqldb+graph/db: use pagination for FetchChanInfos --- graph/db/sql_store.go | 117 +++++++++++++++++----- sqldb/sqlc/graph.sql.go | 185 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 51 ++++++++++ 4 files changed, 330 insertions(+), 24 deletions(-) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 5e9305489..a782f4b45 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -97,6 +97,7 @@ type SQLQueries interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error) GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) + GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) HighestSCID(ctx context.Context, version int16) ([]byte, error) @@ -2170,27 +2171,11 @@ func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) { func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { var ( ctx = context.TODO() - edges []ChannelEdge + edges = make(map[uint64]ChannelEdge) ) err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { - for _, chanID := range chanIDs { - chanIDB := channelIDToBytes(chanID) - - // TODO(elle): potentially optimize this by using - // sqlc.slice() once that works for both SQLite and - // Postgres. - row, err := db.GetChannelBySCIDWithPolicies( - ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ - Scid: chanIDB, - Version: int16(ProtocolV1), - }, - ) - if errors.Is(err, sql.ErrNoRows) { - continue - } else if err != nil { - return fmt.Errorf("unable to fetch channel: %w", - err) - } + chanCallBack := func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error { node1, node2, err := buildNodes( ctx, db, row.GraphNode, row.GraphNode_2, @@ -2225,24 +2210,64 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { "policies: %w", err) } - edges = append(edges, ChannelEdge{ + edges[edge.ChannelID] = ChannelEdge{ Info: edge, Policy1: p1, Policy2: p2, Node1: node1, Node2: node2, - }) + } + + return nil } - return nil + return s.forEachChanWithPoliciesInSCIDList( + ctx, db, chanCallBack, chanIDs, + ) }, func() { - edges = nil + clear(edges) }) if err != nil { return nil, fmt.Errorf("unable to fetch channels: %w", err) } - return edges, nil + res := make([]ChannelEdge, 0, len(edges)) + for _, chanID := range chanIDs { + edge, ok := edges[chanID] + if !ok { + continue + } + + res = append(res, edge) + } + + return res, nil +} + +// forEachChanWithPoliciesInSCIDList is a wrapper around the +// GetChannelsBySCIDWithPolicies query that allows us to iterate through +// channels in a paginated manner. +func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context, + db SQLQueries, cb func(ctx context.Context, + row sqlc.GetChannelsBySCIDWithPoliciesRow) error, + chanIDs []uint64) error { + + queryWrapper := func(ctx context.Context, + scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, + error) { + + return db.GetChannelsBySCIDWithPolicies( + ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{ + Version: int16(ProtocolV1), + Scids: scids, + }, + ) + } + + return sqldb.ExecutePagedQuery( + ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes, + queryWrapper, cb, + ) } // FilterKnownChanIDs takes a set of channel IDs and return the subset of chan @@ -4300,6 +4325,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy, var policy1, policy2 *sqlc.GraphChannelPolicy switch r := row.(type) { + case sqlc.GetChannelsBySCIDWithPoliciesRow: + if r.Policy1ID.Valid { + policy1 = &sqlc.GraphChannelPolicy{ + ID: r.Policy1ID.Int64, + Version: r.Policy1Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy1NodeID.Int64, + Timelock: r.Policy1Timelock.Int32, + FeePpm: r.Policy1FeePpm.Int64, + BaseFeeMsat: r.Policy1BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy1MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy1MaxHtlcMsat, + LastUpdate: r.Policy1LastUpdate, + InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat, + Disabled: r.Policy1Disabled, + MessageFlags: r.Policy1MessageFlags, + ChannelFlags: r.Policy1ChannelFlags, + Signature: r.Policy1Signature, + } + } + if r.Policy2ID.Valid { + policy2 = &sqlc.GraphChannelPolicy{ + ID: r.Policy2ID.Int64, + Version: r.Policy2Version.Int16, + ChannelID: r.GraphChannel.ID, + NodeID: r.Policy2NodeID.Int64, + Timelock: r.Policy2Timelock.Int32, + FeePpm: r.Policy2FeePpm.Int64, + BaseFeeMsat: r.Policy2BaseFeeMsat.Int64, + MinHtlcMsat: r.Policy2MinHtlcMsat.Int64, + MaxHtlcMsat: r.Policy2MaxHtlcMsat, + LastUpdate: r.Policy2LastUpdate, + InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat, + InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat, + Disabled: r.Policy2Disabled, + MessageFlags: r.Policy2MessageFlags, + ChannelFlags: r.Policy2ChannelFlags, + Signature: r.Policy2Signature, + } + } + + return policy1, policy2, nil + case sqlc.GetChannelByOutpointWithPoliciesRow: if r.Policy1ID.Valid { policy1 = &sqlc.GraphChannelPolicy{ diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 09df90264..2f0ba8d67 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -1154,6 +1154,191 @@ func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsByS return items, nil } +const getChannelsBySCIDWithPolicies = `-- name: GetChannelsBySCIDWithPolicies :many +SELECT + c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature, + n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature, + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE + c.version = $1 + AND c.scid IN (/*SLICE:scids*/?) +` + +type GetChannelsBySCIDWithPoliciesParams struct { + Version int16 + Scids [][]byte +} + +type GetChannelsBySCIDWithPoliciesRow struct { + GraphChannel GraphChannel + GraphNode GraphNode + GraphNode_2 GraphNode + Policy1ID sql.NullInt64 + Policy1NodeID sql.NullInt64 + Policy1Version sql.NullInt16 + Policy1Timelock sql.NullInt32 + Policy1FeePpm sql.NullInt64 + Policy1BaseFeeMsat sql.NullInt64 + Policy1MinHtlcMsat sql.NullInt64 + Policy1MaxHtlcMsat sql.NullInt64 + Policy1LastUpdate sql.NullInt64 + Policy1Disabled sql.NullBool + Policy1InboundBaseFeeMsat sql.NullInt64 + Policy1InboundFeeRateMilliMsat sql.NullInt64 + Policy1MessageFlags sql.NullInt16 + Policy1ChannelFlags sql.NullInt16 + Policy1Signature []byte + Policy2ID sql.NullInt64 + Policy2NodeID sql.NullInt64 + Policy2Version sql.NullInt16 + Policy2Timelock sql.NullInt32 + Policy2FeePpm sql.NullInt64 + Policy2BaseFeeMsat sql.NullInt64 + Policy2MinHtlcMsat sql.NullInt64 + Policy2MaxHtlcMsat sql.NullInt64 + Policy2LastUpdate sql.NullInt64 + Policy2Disabled sql.NullBool + Policy2InboundBaseFeeMsat sql.NullInt64 + Policy2InboundFeeRateMilliMsat sql.NullInt64 + Policy2MessageFlags sql.NullInt16 + Policy2ChannelFlags sql.NullInt16 + Policy2Signature []byte +} + +func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) { + query := getChannelsBySCIDWithPolicies + var queryParams []interface{} + queryParams = append(queryParams, arg.Version) + if len(arg.Scids) > 0 { + for _, v := range arg.Scids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1) + } else { + query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsBySCIDWithPoliciesRow + for rows.Next() { + var i GetChannelsBySCIDWithPoliciesRow + if err := rows.Scan( + &i.GraphChannel.ID, + &i.GraphChannel.Version, + &i.GraphChannel.Scid, + &i.GraphChannel.NodeID1, + &i.GraphChannel.NodeID2, + &i.GraphChannel.Outpoint, + &i.GraphChannel.Capacity, + &i.GraphChannel.BitcoinKey1, + &i.GraphChannel.BitcoinKey2, + &i.GraphChannel.Node1Signature, + &i.GraphChannel.Node2Signature, + &i.GraphChannel.Bitcoin1Signature, + &i.GraphChannel.Bitcoin2Signature, + &i.GraphNode.ID, + &i.GraphNode.Version, + &i.GraphNode.PubKey, + &i.GraphNode.Alias, + &i.GraphNode.LastUpdate, + &i.GraphNode.Color, + &i.GraphNode.Signature, + &i.GraphNode_2.ID, + &i.GraphNode_2.Version, + &i.GraphNode_2.PubKey, + &i.GraphNode_2.Alias, + &i.GraphNode_2.LastUpdate, + &i.GraphNode_2.Color, + &i.GraphNode_2.Signature, + &i.Policy1ID, + &i.Policy1NodeID, + &i.Policy1Version, + &i.Policy1Timelock, + &i.Policy1FeePpm, + &i.Policy1BaseFeeMsat, + &i.Policy1MinHtlcMsat, + &i.Policy1MaxHtlcMsat, + &i.Policy1LastUpdate, + &i.Policy1Disabled, + &i.Policy1InboundBaseFeeMsat, + &i.Policy1InboundFeeRateMilliMsat, + &i.Policy1MessageFlags, + &i.Policy1ChannelFlags, + &i.Policy1Signature, + &i.Policy2ID, + &i.Policy2NodeID, + &i.Policy2Version, + &i.Policy2Timelock, + &i.Policy2FeePpm, + &i.Policy2BaseFeeMsat, + &i.Policy2MinHtlcMsat, + &i.Policy2MaxHtlcMsat, + &i.Policy2LastUpdate, + &i.Policy2Disabled, + &i.Policy2InboundBaseFeeMsat, + &i.Policy2InboundFeeRateMilliMsat, + &i.Policy2MessageFlags, + &i.Policy2ChannelFlags, + &i.Policy2Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelsBySCIDs = `-- name: GetChannelsBySCIDs :many SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM graph_channels WHERE version = $1 diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 01d45734a..cd32dc75b 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -44,6 +44,7 @@ type Querier interface { GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) + GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5216ddea5..52c09e23d 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -283,6 +283,57 @@ WHERE cet.channel_id = $1; SELECT scid from graph_channels WHERE outpoint = $1 AND version = $2; +-- name: GetChannelsBySCIDWithPolicies :many +SELECT + sqlc.embed(c), + sqlc.embed(n1), + sqlc.embed(n2), + + -- Policy 1 + cp1.id AS policy1_id, + cp1.node_id AS policy1_node_id, + cp1.version AS policy1_version, + cp1.timelock AS policy1_timelock, + cp1.fee_ppm AS policy1_fee_ppm, + cp1.base_fee_msat AS policy1_base_fee_msat, + cp1.min_htlc_msat AS policy1_min_htlc_msat, + cp1.max_htlc_msat AS policy1_max_htlc_msat, + cp1.last_update AS policy1_last_update, + cp1.disabled AS policy1_disabled, + cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat, + cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat, + cp1.message_flags AS policy1_message_flags, + cp1.channel_flags AS policy1_channel_flags, + cp1.signature AS policy1_signature, + + -- Policy 2 + cp2.id AS policy2_id, + cp2.node_id AS policy2_node_id, + cp2.version AS policy2_version, + cp2.timelock AS policy2_timelock, + cp2.fee_ppm AS policy2_fee_ppm, + cp2.base_fee_msat AS policy2_base_fee_msat, + cp2.min_htlc_msat AS policy2_min_htlc_msat, + cp2.max_htlc_msat AS policy2_max_htlc_msat, + cp2.last_update AS policy2_last_update, + cp2.disabled AS policy2_disabled, + cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat, + cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat, + cp2.message_flags AS policy_2_message_flags, + cp2.channel_flags AS policy_2_channel_flags, + cp2.signature AS policy2_signature + +FROM graph_channels c + JOIN graph_nodes n1 ON c.node_id_1 = n1.id + JOIN graph_nodes n2 ON c.node_id_2 = n2.id + LEFT JOIN graph_channel_policies cp1 + ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version + LEFT JOIN graph_channel_policies cp2 + ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version +WHERE + c.version = @version + AND c.scid IN (sqlc.slice('scids')/*SLICE:scids*/); + -- name: GetChannelsByPolicyLastUpdateRange :many SELECT sqlc.embed(c),