From 8ad5f633bc213e52e5e68fe0f104005e5924185a Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 29 Jul 2025 12:46:58 +0200 Subject: [PATCH] sqldb: add channel data batch queries Also add the calling logic for these queries. This logic is not yet used. --- graph/db/sql_store.go | 163 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 134 ++++++++++++++++++++++++++++ sqldb/sqlc/querier.go | 3 + sqldb/sqlc/queries/graph.sql | 26 ++++++ 4 files changed, 326 insertions(+) diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 5c8c12e8e..ca7b8cdcd 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -115,7 +115,9 @@ type SQLQueries interface { DeleteChannels(ctx context.Context, ids []int64) error CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error + GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error) InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error + GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error) /* Channel Policy table queries. @@ -126,6 +128,7 @@ type SQLQueries interface { InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) + GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error) DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error /* @@ -4901,3 +4904,163 @@ func batchLoadNodeExtraTypesHelper(ctx context.Context, callback, ) } + +// batchChannelData holds all the related data for a batch of channels. +type batchChannelData struct { + // chanFeatures is a map from DB channel ID to a slice of feature bits. + chanfeatures map[int64][]int + + // chanExtras is a map from DB channel ID to a map of TLV type to + // extra signed field bytes. + chanExtraTypes map[int64]map[uint64][]byte + + // policyExtras is a map from DB channel policy ID to a map of TLV type + // to extra signed field bytes. + policyExtras map[int64]map[uint64][]byte +} + +// batchLoadChannelData loads all related data for batches of channels and +// policies. +func batchLoadChannelData(ctx context.Context, cfg *sqldb.PagedQueryConfig, + db SQLQueries, channelIDs []int64, + policyIDs []int64) (*batchChannelData, error) { + + batchData := &batchChannelData{ + chanfeatures: make(map[int64][]int), + chanExtraTypes: make(map[int64]map[uint64][]byte), + policyExtras: make(map[int64]map[uint64][]byte), + } + + // Batch load channel features and extras + var err error + if len(channelIDs) > 0 { + batchData.chanfeatures, err = batchLoadChannelFeaturesHelper( + ctx, cfg, db, channelIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load "+ + "channel features: %w", err) + } + + batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper( + ctx, cfg, db, channelIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load "+ + "channel extras: %w", err) + } + } + + if len(policyIDs) > 0 { + policyExtras, err := batchLoadChannelPolicyExtrasHelper( + ctx, cfg, db, policyIDs, + ) + if err != nil { + return nil, fmt.Errorf("unable to batch load "+ + "policy extras: %w", err) + } + batchData.policyExtras = policyExtras + } + + return batchData, nil +} + +// batchLoadChannelFeaturesHelper loads channel features for a batch of +// channel IDs using ExecutePagedQuery wrapper around the +// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a +// slice of feature bits. +func batchLoadChannelFeaturesHelper(ctx context.Context, + cfg *sqldb.PagedQueryConfig, db SQLQueries, + channelIDs []int64) (map[int64][]int, error) { + + features := make(map[int64][]int) + + return features, sqldb.ExecutePagedQuery( + ctx, cfg, channelIDs, + func(id int64) int64 { + return id + }, + func(ctx context.Context, + ids []int64) ([]sqlc.GraphChannelFeature, error) { + + return db.GetChannelFeaturesBatch(ctx, ids) + }, + func(ctx context.Context, + feature sqlc.GraphChannelFeature) error { + + features[feature.ChannelID] = append( + features[feature.ChannelID], + int(feature.FeatureBit), + ) + + return nil + }, + ) +} + +// batchLoadChannelExtrasHelper loads channel extra types for a batch of +// channel IDs using ExecutePagedQuery wrapper around the GetChannelExtrasBatch +// query. It returns a map from DB channel ID to a map of TLV type to extra +// signed field bytes. +func batchLoadChannelExtrasHelper(ctx context.Context, + cfg *sqldb.PagedQueryConfig, db SQLQueries, + channelIDs []int64) (map[int64]map[uint64][]byte, error) { + + extras := make(map[int64]map[uint64][]byte) + + cb := func(ctx context.Context, + extra sqlc.GraphChannelExtraType) error { + + if extras[extra.ChannelID] == nil { + extras[extra.ChannelID] = make(map[uint64][]byte) + } + extras[extra.ChannelID][uint64(extra.Type)] = extra.Value + + return nil + } + + return extras, sqldb.ExecutePagedQuery( + ctx, cfg, channelIDs, + func(id int64) int64 { + return id + }, + func(ctx context.Context, + ids []int64) ([]sqlc.GraphChannelExtraType, error) { + + return db.GetChannelExtrasBatch(ctx, ids) + }, cb, + ) +} + +// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a +// batch of policy IDs using ExecutePagedQuery wrapper around the +// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to +// a map of TLV type to extra signed field bytes. +func batchLoadChannelPolicyExtrasHelper(ctx context.Context, + cfg *sqldb.PagedQueryConfig, db SQLQueries, + policyIDs []int64) (map[int64]map[uint64][]byte, error) { + + extras := make(map[int64]map[uint64][]byte) + + return extras, sqldb.ExecutePagedQuery( + ctx, cfg, policyIDs, + func(id int64) int64 { + return id + }, + func(ctx context.Context, ids []int64) ( + []sqlc.GetChannelPolicyExtraTypesBatchRow, error) { + + return db.GetChannelPolicyExtraTypesBatch(ctx, ids) + }, + func(ctx context.Context, + row sqlc.GetChannelPolicyExtraTypesBatchRow) error { + + if extras[row.PolicyID] == nil { + extras[row.PolicyID] = make(map[uint64][]byte) + } + extras[row.PolicyID][uint64(row.Type)] = row.Value + + return nil + }, + ) +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 3122d1dab..9f847720e 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -701,6 +701,49 @@ func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChann return i, err } +const getChannelExtrasBatch = `-- name: GetChannelExtrasBatch :many +SELECT + channel_id, + type, + value +FROM graph_channel_extra_types +WHERE channel_id IN (/*SLICE:chan_ids*/?) +ORDER BY channel_id, type +` + +func (q *Queries) GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error) { + query := getChannelExtrasBatch + var queryParams []interface{} + if len(chanIds) > 0 { + for _, v := range chanIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphChannelExtraType + for rows.Next() { + var i GraphChannelExtraType + if err := rows.Scan(&i.ChannelID, &i.Type, &i.Value); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many SELECT cf.channel_id, @@ -760,6 +803,48 @@ func (q *Queries) GetChannelFeaturesAndExtras(ctx context.Context, channelID int return items, nil } +const getChannelFeaturesBatch = `-- name: GetChannelFeaturesBatch :many +SELECT + channel_id, + feature_bit +FROM graph_channel_features +WHERE channel_id IN (/*SLICE:chan_ids*/?) +ORDER BY channel_id, feature_bit +` + +func (q *Queries) GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]GraphChannelFeature, error) { + query := getChannelFeaturesBatch + var queryParams []interface{} + if len(chanIds) > 0 { + for _, v := range chanIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GraphChannelFeature + for rows.Next() { + var i GraphChannelFeature + if err := rows.Scan(&i.ChannelID, &i.FeatureBit); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelPolicyByChannelAndNode = `-- name: GetChannelPolicyByChannelAndNode :one SELECT id, version, channel_id, node_id, timelock, fee_ppm, base_fee_msat, min_htlc_msat, max_htlc_msat, last_update, disabled, inbound_base_fee_msat, inbound_fee_rate_milli_msat, message_flags, channel_flags, signature FROM graph_channel_policies @@ -853,6 +938,55 @@ func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannel return items, nil } +const getChannelPolicyExtraTypesBatch = `-- name: GetChannelPolicyExtraTypesBatch :many +SELECT + channel_policy_id as policy_id, + type, + value +FROM graph_channel_policy_extra_types +WHERE channel_policy_id IN (/*SLICE:policy_ids*/?) +ORDER BY channel_policy_id, type +` + +type GetChannelPolicyExtraTypesBatchRow struct { + PolicyID int64 + Type int64 + Value []byte +} + +func (q *Queries) GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error) { + query := getChannelPolicyExtraTypesBatch + var queryParams []interface{} + if len(policyIds) > 0 { + for _, v := range policyIds { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:policy_ids*/?", makeQueryParams(len(queryParams), len(policyIds)), 1) + } else { + query = strings.Replace(query, "/*SLICE:policy_ids*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelPolicyExtraTypesBatchRow + for rows.Next() { + var i GetChannelPolicyExtraTypesBatchRow + if err := rows.Scan(&i.PolicyID, &i.Type, &i.Value); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index e32383032..7f2faeb3b 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -38,9 +38,12 @@ type Querier interface { GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (GraphChannel, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error) + GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) + GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]GraphChannelFeature, error) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (GraphChannelPolicy, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) + GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error) GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index afa446f0b..cd7dfb467 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -647,6 +647,14 @@ INSERT INTO graph_channel_features ( $1, $2 ); +-- name: GetChannelFeaturesBatch :many +SELECT + channel_id, + feature_bit +FROM graph_channel_features +WHERE channel_id IN (sqlc.slice('chan_ids')/*SLICE:chan_ids*/) +ORDER BY channel_id, feature_bit; + /* ───────────────────────────────────────────── graph_channel_extra_types table queries ───────────────────────────────────────────── @@ -658,6 +666,15 @@ INSERT INTO graph_channel_extra_types ( ) VALUES ($1, $2, $3); +-- name: GetChannelExtrasBatch :many +SELECT + channel_id, + type, + value +FROM graph_channel_extra_types +WHERE channel_id IN (sqlc.slice('chan_ids')/*SLICE:chan_ids*/) +ORDER BY channel_id, type; + /* ───────────────────────────────────────────── graph_channel_policies table queries ───────────────────────────────────────────── @@ -760,6 +777,15 @@ INSERT INTO graph_channel_policy_extra_types ( ) VALUES ($1, $2, $3); +-- name: GetChannelPolicyExtraTypesBatch :many +SELECT + channel_policy_id as policy_id, + type, + value +FROM graph_channel_policy_extra_types +WHERE channel_policy_id IN (sqlc.slice('policy_ids')/*SLICE:policy_ids*/) +ORDER BY channel_policy_id, type; + -- name: GetChannelPolicyExtraTypes :many SELECT cp.id AS policy_id,