sqldb: add channel data batch queries

Also add the calling logic for these queries. This logic is not yet
used.
This commit is contained in:
Elle Mouton
2025-07-29 12:46:58 +02:00
parent 23dd01cb35
commit 8ad5f633bc
4 changed files with 326 additions and 0 deletions

View File

@@ -115,7 +115,9 @@ type SQLQueries interface {
DeleteChannels(ctx context.Context, ids []int64) error
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelExtraType, error)
InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]sqlc.GraphChannelFeature, error)
/*
Channel Policy table queries.
@@ -126,6 +128,7 @@ type SQLQueries interface {
InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]sqlc.GetChannelPolicyExtraTypesBatchRow, error)
DeleteChannelPolicyExtraTypes(ctx context.Context, channelPolicyID int64) error
/*
@@ -4901,3 +4904,163 @@ func batchLoadNodeExtraTypesHelper(ctx context.Context,
callback,
)
}
// batchChannelData holds all the related data for a batch of channels.
type batchChannelData struct {
// chanFeatures is a map from DB channel ID to a slice of feature bits.
chanfeatures map[int64][]int
// chanExtras is a map from DB channel ID to a map of TLV type to
// extra signed field bytes.
chanExtraTypes map[int64]map[uint64][]byte
// policyExtras is a map from DB channel policy ID to a map of TLV type
// to extra signed field bytes.
policyExtras map[int64]map[uint64][]byte
}
// batchLoadChannelData loads all related data for batches of channels and
// policies.
func batchLoadChannelData(ctx context.Context, cfg *sqldb.PagedQueryConfig,
db SQLQueries, channelIDs []int64,
policyIDs []int64) (*batchChannelData, error) {
batchData := &batchChannelData{
chanfeatures: make(map[int64][]int),
chanExtraTypes: make(map[int64]map[uint64][]byte),
policyExtras: make(map[int64]map[uint64][]byte),
}
// Batch load channel features and extras
var err error
if len(channelIDs) > 0 {
batchData.chanfeatures, err = batchLoadChannelFeaturesHelper(
ctx, cfg, db, channelIDs,
)
if err != nil {
return nil, fmt.Errorf("unable to batch load "+
"channel features: %w", err)
}
batchData.chanExtraTypes, err = batchLoadChannelExtrasHelper(
ctx, cfg, db, channelIDs,
)
if err != nil {
return nil, fmt.Errorf("unable to batch load "+
"channel extras: %w", err)
}
}
if len(policyIDs) > 0 {
policyExtras, err := batchLoadChannelPolicyExtrasHelper(
ctx, cfg, db, policyIDs,
)
if err != nil {
return nil, fmt.Errorf("unable to batch load "+
"policy extras: %w", err)
}
batchData.policyExtras = policyExtras
}
return batchData, nil
}
// batchLoadChannelFeaturesHelper loads channel features for a batch of
// channel IDs using ExecutePagedQuery wrapper around the
// GetChannelFeaturesBatch query. It returns a map from DB channel ID to a
// slice of feature bits.
func batchLoadChannelFeaturesHelper(ctx context.Context,
cfg *sqldb.PagedQueryConfig, db SQLQueries,
channelIDs []int64) (map[int64][]int, error) {
features := make(map[int64][]int)
return features, sqldb.ExecutePagedQuery(
ctx, cfg, channelIDs,
func(id int64) int64 {
return id
},
func(ctx context.Context,
ids []int64) ([]sqlc.GraphChannelFeature, error) {
return db.GetChannelFeaturesBatch(ctx, ids)
},
func(ctx context.Context,
feature sqlc.GraphChannelFeature) error {
features[feature.ChannelID] = append(
features[feature.ChannelID],
int(feature.FeatureBit),
)
return nil
},
)
}
// batchLoadChannelExtrasHelper loads channel extra types for a batch of
// channel IDs using ExecutePagedQuery wrapper around the GetChannelExtrasBatch
// query. It returns a map from DB channel ID to a map of TLV type to extra
// signed field bytes.
func batchLoadChannelExtrasHelper(ctx context.Context,
cfg *sqldb.PagedQueryConfig, db SQLQueries,
channelIDs []int64) (map[int64]map[uint64][]byte, error) {
extras := make(map[int64]map[uint64][]byte)
cb := func(ctx context.Context,
extra sqlc.GraphChannelExtraType) error {
if extras[extra.ChannelID] == nil {
extras[extra.ChannelID] = make(map[uint64][]byte)
}
extras[extra.ChannelID][uint64(extra.Type)] = extra.Value
return nil
}
return extras, sqldb.ExecutePagedQuery(
ctx, cfg, channelIDs,
func(id int64) int64 {
return id
},
func(ctx context.Context,
ids []int64) ([]sqlc.GraphChannelExtraType, error) {
return db.GetChannelExtrasBatch(ctx, ids)
}, cb,
)
}
// batchLoadChannelPolicyExtrasHelper loads channel policy extra types for a
// batch of policy IDs using ExecutePagedQuery wrapper around the
// GetChannelPolicyExtraTypesBatch query. It returns a map from DB policy ID to
// a map of TLV type to extra signed field bytes.
func batchLoadChannelPolicyExtrasHelper(ctx context.Context,
cfg *sqldb.PagedQueryConfig, db SQLQueries,
policyIDs []int64) (map[int64]map[uint64][]byte, error) {
extras := make(map[int64]map[uint64][]byte)
return extras, sqldb.ExecutePagedQuery(
ctx, cfg, policyIDs,
func(id int64) int64 {
return id
},
func(ctx context.Context, ids []int64) (
[]sqlc.GetChannelPolicyExtraTypesBatchRow, error) {
return db.GetChannelPolicyExtraTypesBatch(ctx, ids)
},
func(ctx context.Context,
row sqlc.GetChannelPolicyExtraTypesBatchRow) error {
if extras[row.PolicyID] == nil {
extras[row.PolicyID] = make(map[uint64][]byte)
}
extras[row.PolicyID][uint64(row.Type)] = row.Value
return nil
},
)
}

View File

@@ -701,6 +701,49 @@ func (q *Queries) GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChann
return i, err
}
const getChannelExtrasBatch = `-- name: GetChannelExtrasBatch :many
SELECT
channel_id,
type,
value
FROM graph_channel_extra_types
WHERE channel_id IN (/*SLICE:chan_ids*/?)
ORDER BY channel_id, type
`
func (q *Queries) GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error) {
query := getChannelExtrasBatch
var queryParams []interface{}
if len(chanIds) > 0 {
for _, v := range chanIds {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1)
} else {
query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GraphChannelExtraType
for rows.Next() {
var i GraphChannelExtraType
if err := rows.Scan(&i.ChannelID, &i.Type, &i.Value); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChannelFeaturesAndExtras = `-- name: GetChannelFeaturesAndExtras :many
SELECT
cf.channel_id,
@@ -760,6 +803,48 @@ func (q *Queries) GetChannelFeaturesAndExtras(ctx context.Context, channelID int
return items, nil
}
const getChannelFeaturesBatch = `-- name: GetChannelFeaturesBatch :many
SELECT
channel_id,
feature_bit
FROM graph_channel_features
WHERE channel_id IN (/*SLICE:chan_ids*/?)
ORDER BY channel_id, feature_bit
`
func (q *Queries) GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]GraphChannelFeature, error) {
query := getChannelFeaturesBatch
var queryParams []interface{}
if len(chanIds) > 0 {
for _, v := range chanIds {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:chan_ids*/?", makeQueryParams(len(queryParams), len(chanIds)), 1)
} else {
query = strings.Replace(query, "/*SLICE:chan_ids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GraphChannelFeature
for rows.Next() {
var i GraphChannelFeature
if err := rows.Scan(&i.ChannelID, &i.FeatureBit); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChannelPolicyByChannelAndNode = `-- name: GetChannelPolicyByChannelAndNode :one
SELECT id, version, channel_id, node_id, timelock, fee_ppm, base_fee_msat, min_htlc_msat, max_htlc_msat, last_update, disabled, inbound_base_fee_msat, inbound_fee_rate_milli_msat, message_flags, channel_flags, signature
FROM graph_channel_policies
@@ -853,6 +938,55 @@ func (q *Queries) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannel
return items, nil
}
const getChannelPolicyExtraTypesBatch = `-- name: GetChannelPolicyExtraTypesBatch :many
SELECT
channel_policy_id as policy_id,
type,
value
FROM graph_channel_policy_extra_types
WHERE channel_policy_id IN (/*SLICE:policy_ids*/?)
ORDER BY channel_policy_id, type
`
type GetChannelPolicyExtraTypesBatchRow struct {
PolicyID int64
Type int64
Value []byte
}
func (q *Queries) GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error) {
query := getChannelPolicyExtraTypesBatch
var queryParams []interface{}
if len(policyIds) > 0 {
for _, v := range policyIds {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:policy_ids*/?", makeQueryParams(len(queryParams), len(policyIds)), 1)
} else {
query = strings.Replace(query, "/*SLICE:policy_ids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChannelPolicyExtraTypesBatchRow
for rows.Next() {
var i GetChannelPolicyExtraTypesBatchRow
if err := rows.Scan(&i.PolicyID, &i.Type, &i.Value); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChannelsByOutpoints = `-- name: GetChannelsByOutpoints :many
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,

View File

@@ -38,9 +38,12 @@ type Querier interface {
GetChannelByOutpointWithPolicies(ctx context.Context, arg GetChannelByOutpointWithPoliciesParams) (GetChannelByOutpointWithPoliciesRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (GraphChannel, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg GetChannelBySCIDWithPoliciesParams) (GetChannelBySCIDWithPoliciesRow, error)
GetChannelExtrasBatch(ctx context.Context, chanIds []int64) ([]GraphChannelExtraType, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error)
GetChannelFeaturesBatch(ctx context.Context, chanIds []int64) ([]GraphChannelFeature, error)
GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (GraphChannelPolicy, error)
GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error)
GetChannelPolicyExtraTypesBatch(ctx context.Context, policyIds []int64) ([]GetChannelPolicyExtraTypesBatchRow, error)
GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error)

View File

@@ -647,6 +647,14 @@ INSERT INTO graph_channel_features (
$1, $2
);
-- name: GetChannelFeaturesBatch :many
SELECT
channel_id,
feature_bit
FROM graph_channel_features
WHERE channel_id IN (sqlc.slice('chan_ids')/*SLICE:chan_ids*/)
ORDER BY channel_id, feature_bit;
/* ─────────────────────────────────────────────
graph_channel_extra_types table queries
─────────────────────────────────────────────
@@ -658,6 +666,15 @@ INSERT INTO graph_channel_extra_types (
)
VALUES ($1, $2, $3);
-- name: GetChannelExtrasBatch :many
SELECT
channel_id,
type,
value
FROM graph_channel_extra_types
WHERE channel_id IN (sqlc.slice('chan_ids')/*SLICE:chan_ids*/)
ORDER BY channel_id, type;
/* ─────────────────────────────────────────────
graph_channel_policies table queries
─────────────────────────────────────────────
@@ -760,6 +777,15 @@ INSERT INTO graph_channel_policy_extra_types (
)
VALUES ($1, $2, $3);
-- name: GetChannelPolicyExtraTypesBatch :many
SELECT
channel_policy_id as policy_id,
type,
value
FROM graph_channel_policy_extra_types
WHERE channel_policy_id IN (sqlc.slice('policy_ids')/*SLICE:policy_ids*/)
ORDER BY channel_policy_id, type;
-- name: GetChannelPolicyExtraTypes :many
SELECT
cp.id AS policy_id,