sqldb+graph/db: use pagination for FetchChanInfos

This commit is contained in:
Elle Mouton
2025-07-15 16:17:53 +02:00
parent 88e9a21d63
commit e269d57ffa
4 changed files with 330 additions and 24 deletions

View File

@@ -97,6 +97,7 @@ type SQLQueries interface {
GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]sqlc.GetChannelsByOutpointsRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
GetChannelsBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelsBySCIDWithPoliciesParams) ([]sqlc.GetChannelsBySCIDWithPoliciesRow, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
@@ -2170,27 +2171,11 @@ func (s *SQLStore) IsPublicNode(pubKey [33]byte) (bool, error) {
func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
var (
ctx = context.TODO()
edges []ChannelEdge
edges = make(map[uint64]ChannelEdge)
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for _, chanID := range chanIDs {
chanIDB := channelIDToBytes(chanID)
// TODO(elle): potentially optimize this by using
// sqlc.slice() once that works for both SQLite and
// Postgres.
row, err := db.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: chanIDB,
Version: int16(ProtocolV1),
},
)
if errors.Is(err, sql.ErrNoRows) {
continue
} else if err != nil {
return fmt.Errorf("unable to fetch channel: %w",
err)
}
chanCallBack := func(ctx context.Context,
row sqlc.GetChannelsBySCIDWithPoliciesRow) error {
node1, node2, err := buildNodes(
ctx, db, row.GraphNode, row.GraphNode_2,
@@ -2225,24 +2210,64 @@ func (s *SQLStore) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
"policies: %w", err)
}
edges = append(edges, ChannelEdge{
edges[edge.ChannelID] = ChannelEdge{
Info: edge,
Policy1: p1,
Policy2: p2,
Node1: node1,
Node2: node2,
})
}
return nil
}
return nil
return s.forEachChanWithPoliciesInSCIDList(
ctx, db, chanCallBack, chanIDs,
)
}, func() {
edges = nil
clear(edges)
})
if err != nil {
return nil, fmt.Errorf("unable to fetch channels: %w", err)
}
return edges, nil
res := make([]ChannelEdge, 0, len(edges))
for _, chanID := range chanIDs {
edge, ok := edges[chanID]
if !ok {
continue
}
res = append(res, edge)
}
return res, nil
}
// forEachChanWithPoliciesInSCIDList is a wrapper around the
// GetChannelsBySCIDWithPolicies query that allows us to iterate through
// channels in a paginated manner.
func (s *SQLStore) forEachChanWithPoliciesInSCIDList(ctx context.Context,
db SQLQueries, cb func(ctx context.Context,
row sqlc.GetChannelsBySCIDWithPoliciesRow) error,
chanIDs []uint64) error {
queryWrapper := func(ctx context.Context,
scids [][]byte) ([]sqlc.GetChannelsBySCIDWithPoliciesRow,
error) {
return db.GetChannelsBySCIDWithPolicies(
ctx, sqlc.GetChannelsBySCIDWithPoliciesParams{
Version: int16(ProtocolV1),
Scids: scids,
},
)
}
return sqldb.ExecutePagedQuery(
ctx, s.cfg.PaginationCfg, chanIDs, channelIDToBytes,
queryWrapper, cb,
)
}
// FilterKnownChanIDs takes a set of channel IDs and return the subset of chan
@@ -4300,6 +4325,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
var policy1, policy2 *sqlc.GraphChannelPolicy
switch r := row.(type) {
case sqlc.GetChannelsBySCIDWithPoliciesRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.GraphChannelPolicy{
ID: r.Policy1ID.Int64,
Version: r.Policy1Version.Int16,
ChannelID: r.GraphChannel.ID,
NodeID: r.Policy1NodeID.Int64,
Timelock: r.Policy1Timelock.Int32,
FeePpm: r.Policy1FeePpm.Int64,
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
LastUpdate: r.Policy1LastUpdate,
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
Disabled: r.Policy1Disabled,
MessageFlags: r.Policy1MessageFlags,
ChannelFlags: r.Policy1ChannelFlags,
Signature: r.Policy1Signature,
}
}
if r.Policy2ID.Valid {
policy2 = &sqlc.GraphChannelPolicy{
ID: r.Policy2ID.Int64,
Version: r.Policy2Version.Int16,
ChannelID: r.GraphChannel.ID,
NodeID: r.Policy2NodeID.Int64,
Timelock: r.Policy2Timelock.Int32,
FeePpm: r.Policy2FeePpm.Int64,
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
LastUpdate: r.Policy2LastUpdate,
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
Disabled: r.Policy2Disabled,
MessageFlags: r.Policy2MessageFlags,
ChannelFlags: r.Policy2ChannelFlags,
Signature: r.Policy2Signature,
}
}
return policy1, policy2, nil
case sqlc.GetChannelByOutpointWithPoliciesRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.GraphChannelPolicy{

View File

@@ -1154,6 +1154,191 @@ func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsByS
return items, nil
}
const getChannelsBySCIDWithPolicies = `-- name: GetChannelsBySCIDWithPolicies :many
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.id, n1.version, n1.pub_key, n1.alias, n1.last_update, n1.color, n1.signature,
n2.id, n2.version, n2.pub_key, n2.alias, n2.last_update, n2.color, n2.signature,
-- Policy 1
cp1.id AS policy1_id,
cp1.node_id AS policy1_node_id,
cp1.version AS policy1_version,
cp1.timelock AS policy1_timelock,
cp1.fee_ppm AS policy1_fee_ppm,
cp1.base_fee_msat AS policy1_base_fee_msat,
cp1.min_htlc_msat AS policy1_min_htlc_msat,
cp1.max_htlc_msat AS policy1_max_htlc_msat,
cp1.last_update AS policy1_last_update,
cp1.disabled AS policy1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.message_flags AS policy1_message_flags,
cp1.channel_flags AS policy1_channel_flags,
cp1.signature AS policy1_signature,
-- Policy 2
cp2.id AS policy2_id,
cp2.node_id AS policy2_node_id,
cp2.version AS policy2_version,
cp2.timelock AS policy2_timelock,
cp2.fee_ppm AS policy2_fee_ppm,
cp2.base_fee_msat AS policy2_base_fee_msat,
cp2.min_htlc_msat AS policy2_min_htlc_msat,
cp2.max_htlc_msat AS policy2_max_htlc_msat,
cp2.last_update AS policy2_last_update,
cp2.disabled AS policy2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.message_flags AS policy_2_message_flags,
cp2.channel_flags AS policy_2_channel_flags,
cp2.signature AS policy2_signature
FROM graph_channels c
JOIN graph_nodes n1 ON c.node_id_1 = n1.id
JOIN graph_nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN graph_channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN graph_channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE
c.version = $1
AND c.scid IN (/*SLICE:scids*/?)
`
type GetChannelsBySCIDWithPoliciesParams struct {
Version int16
Scids [][]byte
}
type GetChannelsBySCIDWithPoliciesRow struct {
GraphChannel GraphChannel
GraphNode GraphNode
GraphNode_2 GraphNode
Policy1ID sql.NullInt64
Policy1NodeID sql.NullInt64
Policy1Version sql.NullInt16
Policy1Timelock sql.NullInt32
Policy1FeePpm sql.NullInt64
Policy1BaseFeeMsat sql.NullInt64
Policy1MinHtlcMsat sql.NullInt64
Policy1MaxHtlcMsat sql.NullInt64
Policy1LastUpdate sql.NullInt64
Policy1Disabled sql.NullBool
Policy1InboundBaseFeeMsat sql.NullInt64
Policy1InboundFeeRateMilliMsat sql.NullInt64
Policy1MessageFlags sql.NullInt16
Policy1ChannelFlags sql.NullInt16
Policy1Signature []byte
Policy2ID sql.NullInt64
Policy2NodeID sql.NullInt64
Policy2Version sql.NullInt16
Policy2Timelock sql.NullInt32
Policy2FeePpm sql.NullInt64
Policy2BaseFeeMsat sql.NullInt64
Policy2MinHtlcMsat sql.NullInt64
Policy2MaxHtlcMsat sql.NullInt64
Policy2LastUpdate sql.NullInt64
Policy2Disabled sql.NullBool
Policy2InboundBaseFeeMsat sql.NullInt64
Policy2InboundFeeRateMilliMsat sql.NullInt64
Policy2MessageFlags sql.NullInt16
Policy2ChannelFlags sql.NullInt16
Policy2Signature []byte
}
func (q *Queries) GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error) {
query := getChannelsBySCIDWithPolicies
var queryParams []interface{}
queryParams = append(queryParams, arg.Version)
if len(arg.Scids) > 0 {
for _, v := range arg.Scids {
queryParams = append(queryParams, v)
}
query = strings.Replace(query, "/*SLICE:scids*/?", makeQueryParams(len(queryParams), len(arg.Scids)), 1)
} else {
query = strings.Replace(query, "/*SLICE:scids*/?", "NULL", 1)
}
rows, err := q.db.QueryContext(ctx, query, queryParams...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChannelsBySCIDWithPoliciesRow
for rows.Next() {
var i GetChannelsBySCIDWithPoliciesRow
if err := rows.Scan(
&i.GraphChannel.ID,
&i.GraphChannel.Version,
&i.GraphChannel.Scid,
&i.GraphChannel.NodeID1,
&i.GraphChannel.NodeID2,
&i.GraphChannel.Outpoint,
&i.GraphChannel.Capacity,
&i.GraphChannel.BitcoinKey1,
&i.GraphChannel.BitcoinKey2,
&i.GraphChannel.Node1Signature,
&i.GraphChannel.Node2Signature,
&i.GraphChannel.Bitcoin1Signature,
&i.GraphChannel.Bitcoin2Signature,
&i.GraphNode.ID,
&i.GraphNode.Version,
&i.GraphNode.PubKey,
&i.GraphNode.Alias,
&i.GraphNode.LastUpdate,
&i.GraphNode.Color,
&i.GraphNode.Signature,
&i.GraphNode_2.ID,
&i.GraphNode_2.Version,
&i.GraphNode_2.PubKey,
&i.GraphNode_2.Alias,
&i.GraphNode_2.LastUpdate,
&i.GraphNode_2.Color,
&i.GraphNode_2.Signature,
&i.Policy1ID,
&i.Policy1NodeID,
&i.Policy1Version,
&i.Policy1Timelock,
&i.Policy1FeePpm,
&i.Policy1BaseFeeMsat,
&i.Policy1MinHtlcMsat,
&i.Policy1MaxHtlcMsat,
&i.Policy1LastUpdate,
&i.Policy1Disabled,
&i.Policy1InboundBaseFeeMsat,
&i.Policy1InboundFeeRateMilliMsat,
&i.Policy1MessageFlags,
&i.Policy1ChannelFlags,
&i.Policy1Signature,
&i.Policy2ID,
&i.Policy2NodeID,
&i.Policy2Version,
&i.Policy2Timelock,
&i.Policy2FeePpm,
&i.Policy2BaseFeeMsat,
&i.Policy2MinHtlcMsat,
&i.Policy2MaxHtlcMsat,
&i.Policy2LastUpdate,
&i.Policy2Disabled,
&i.Policy2InboundBaseFeeMsat,
&i.Policy2InboundFeeRateMilliMsat,
&i.Policy2MessageFlags,
&i.Policy2ChannelFlags,
&i.Policy2Signature,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChannelsBySCIDs = `-- name: GetChannelsBySCIDs :many
SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature FROM graph_channels
WHERE version = $1

View File

@@ -44,6 +44,7 @@ type Querier interface {
GetChannelsByOutpoints(ctx context.Context, outpoints []string) ([]GetChannelsByOutpointsRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error)
GetChannelsBySCIDWithPolicies(ctx context.Context, arg GetChannelsBySCIDWithPoliciesParams) ([]GetChannelsBySCIDWithPoliciesRow, error)
GetChannelsBySCIDs(ctx context.Context, arg GetChannelsBySCIDsParams) ([]GraphChannel, error)
GetDatabaseVersion(ctx context.Context) (int32, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]GraphNodeExtraType, error)

View File

@@ -283,6 +283,57 @@ WHERE cet.channel_id = $1;
SELECT scid from graph_channels
WHERE outpoint = $1 AND version = $2;
-- name: GetChannelsBySCIDWithPolicies :many
SELECT
sqlc.embed(c),
sqlc.embed(n1),
sqlc.embed(n2),
-- Policy 1
cp1.id AS policy1_id,
cp1.node_id AS policy1_node_id,
cp1.version AS policy1_version,
cp1.timelock AS policy1_timelock,
cp1.fee_ppm AS policy1_fee_ppm,
cp1.base_fee_msat AS policy1_base_fee_msat,
cp1.min_htlc_msat AS policy1_min_htlc_msat,
cp1.max_htlc_msat AS policy1_max_htlc_msat,
cp1.last_update AS policy1_last_update,
cp1.disabled AS policy1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.message_flags AS policy1_message_flags,
cp1.channel_flags AS policy1_channel_flags,
cp1.signature AS policy1_signature,
-- Policy 2
cp2.id AS policy2_id,
cp2.node_id AS policy2_node_id,
cp2.version AS policy2_version,
cp2.timelock AS policy2_timelock,
cp2.fee_ppm AS policy2_fee_ppm,
cp2.base_fee_msat AS policy2_base_fee_msat,
cp2.min_htlc_msat AS policy2_min_htlc_msat,
cp2.max_htlc_msat AS policy2_max_htlc_msat,
cp2.last_update AS policy2_last_update,
cp2.disabled AS policy2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.message_flags AS policy_2_message_flags,
cp2.channel_flags AS policy_2_channel_flags,
cp2.signature AS policy2_signature
FROM graph_channels c
JOIN graph_nodes n1 ON c.node_id_1 = n1.id
JOIN graph_nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN graph_channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN graph_channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE
c.version = @version
AND c.scid IN (sqlc.slice('scids')/*SLICE:scids*/);
-- name: GetChannelsByPolicyLastUpdateRange :many
SELECT
sqlc.embed(c),