sqldb+graph/db: implement FilterChannelRange

This lets us run `TestFilterChannelRange` against the SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 16:47:07 +02:00
parent ff84fa1cb2
commit 3687171cd5
6 changed files with 235 additions and 1 deletions

View File

@ -85,6 +85,7 @@ circuit. The indices are only available for forwarding events saved after v0.20.
* [3](https://github.com/lightningnetwork/lnd/pull/9887) * [3](https://github.com/lightningnetwork/lnd/pull/9887)
* [4](https://github.com/lightningnetwork/lnd/pull/9931) * [4](https://github.com/lightningnetwork/lnd/pull/9931)
* [5](https://github.com/lightningnetwork/lnd/pull/9935) * [5](https://github.com/lightningnetwork/lnd/pull/9935)
* [6](https://github.com/lightningnetwork/lnd/pull/9936)
## RPC Updates ## RPC Updates

View File

@ -2727,7 +2727,7 @@ func TestFilterChannelRange(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
graph := MakeTestGraph(t) graph := MakeTestGraphNew(t)
// We'll first populate our graph with two nodes. All channels created // We'll first populate our graph with two nodes. All channels created
// below will be made between these two nodes. // below will be made between these two nodes.

View File

@ -7,8 +7,10 @@ import (
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"maps"
"math" "math"
"net" "net"
"slices"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@ -92,6 +94,7 @@ type SQLQueries interface {
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error) ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error) ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg sqlc.GetPublicV1ChannelsBySCIDParams) ([]sqlc.Channel, error)
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
@ -100,6 +103,7 @@ type SQLQueries interface {
Channel Policy table queries. Channel Policy table queries.
*/ */
UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error) UpsertEdgePolicy(ctx context.Context, arg sqlc.UpsertEdgePolicyParams) (int64, error)
GetChannelPolicyByChannelAndNode(ctx context.Context, arg sqlc.GetChannelPolicyByChannelAndNodeParams) (sqlc.ChannelPolicy, error)
InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error InsertChanPolicyExtraType(ctx context.Context, arg sqlc.InsertChanPolicyExtraTypeParams) error
GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg sqlc.GetChannelPolicyExtraTypesParams) ([]sqlc.GetChannelPolicyExtraTypesRow, error)
@ -1262,6 +1266,133 @@ func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
}, sqldb.NoOpReset) }, sqldb.NoOpReset)
} }
// FilterChannelRange returns the channel ID's of all known channels which were
// mined in a block height within the passed range. The channel IDs are grouped
// by their common block height. This method can be used to quickly share with a
// peer the set of channels we know of within a particular range to catch them
// up after a period of time offline. If withTimestamps is true then the
// timestamp info of the latest received channel update messages of the channel
// will be included in the response.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) FilterChannelRange(startHeight, endHeight uint32,
withTimestamps bool) ([]BlockChannelRange, error) {
var (
ctx = context.TODO()
startSCID = &lnwire.ShortChannelID{
BlockHeight: startHeight,
}
endSCID = lnwire.ShortChannelID{
BlockHeight: endHeight,
TxIndex: math.MaxUint32 & 0x00ffffff,
TxPosition: math.MaxUint16,
}
)
var chanIDStart [8]byte
byteOrder.PutUint64(chanIDStart[:], startSCID.ToUint64())
var chanIDEnd [8]byte
byteOrder.PutUint64(chanIDEnd[:], endSCID.ToUint64())
// 1) get all channels where channelID is between start and end chan ID.
// 2) skip if not public (ie, no channel_proof)
// 3) collect that channel.
// 4) if timestamps are wanted, fetch both policies for node 1 and node2
// and add those timestamps to the collected channel.
channelsPerBlock := make(map[uint32][]ChannelUpdateInfo)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbChans, err := db.GetPublicV1ChannelsBySCID(
ctx, sqlc.GetPublicV1ChannelsBySCIDParams{
StartScid: chanIDStart[:],
EndScid: chanIDEnd[:],
},
)
if err != nil {
return fmt.Errorf("unable to fetch channel range: %w",
err)
}
for _, dbChan := range dbChans {
cid := lnwire.NewShortChanIDFromInt(
byteOrder.Uint64(dbChan.Scid),
)
chanInfo := NewChannelUpdateInfo(
cid, time.Time{}, time.Time{},
)
if !withTimestamps {
channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight],
chanInfo,
)
continue
}
//nolint:ll
node1Policy, err := db.GetChannelPolicyByChannelAndNode(
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
Version: int16(ProtocolV1),
ChannelID: dbChan.ID,
NodeID: dbChan.NodeID1,
},
)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch node1 "+
"policy: %w", err)
} else if err == nil {
chanInfo.Node1UpdateTimestamp = time.Unix(
node1Policy.LastUpdate.Int64, 0,
)
}
//nolint:ll
node2Policy, err := db.GetChannelPolicyByChannelAndNode(
ctx, sqlc.GetChannelPolicyByChannelAndNodeParams{
Version: int16(ProtocolV1),
ChannelID: dbChan.ID,
NodeID: dbChan.NodeID2,
},
)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch node2 "+
"policy: %w", err)
} else if err == nil {
chanInfo.Node2UpdateTimestamp = time.Unix(
node2Policy.LastUpdate.Int64, 0,
)
}
channelsPerBlock[cid.BlockHeight] = append(
channelsPerBlock[cid.BlockHeight], chanInfo,
)
}
return nil
}, func() {
channelsPerBlock = make(map[uint32][]ChannelUpdateInfo)
})
if err != nil {
return nil, fmt.Errorf("unable to fetch channel range: %w", err)
}
if len(channelsPerBlock) == 0 {
return nil, nil
}
// Return the channel ranges in ascending block height order.
blocks := slices.Collect(maps.Keys(channelsPerBlock))
slices.Sort(blocks)
return fn.Map(blocks, func(block uint32) BlockChannelRange {
return BlockChannelRange{
Height: block,
Channels: channelsPerBlock[block],
}
}), nil
}
// forEachNodeDirectedChannel iterates through all channels of a given // forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the // node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is // channel and its incoming policy. If the node is not found, no error is

View File

@ -316,6 +316,42 @@ func (q *Queries) GetChannelFeaturesAndExtras(ctx context.Context, channelID int
return items, nil return items, nil
} }
const getChannelPolicyByChannelAndNode = `-- name: GetChannelPolicyByChannelAndNode :one
SELECT id, version, channel_id, node_id, timelock, fee_ppm, base_fee_msat, min_htlc_msat, max_htlc_msat, last_update, disabled, inbound_base_fee_msat, inbound_fee_rate_milli_msat, signature
FROM channel_policies
WHERE channel_id = $1
AND node_id = $2
AND version = $3
`
type GetChannelPolicyByChannelAndNodeParams struct {
ChannelID int64
NodeID int64
Version int16
}
func (q *Queries) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error) {
row := q.db.QueryRowContext(ctx, getChannelPolicyByChannelAndNode, arg.ChannelID, arg.NodeID, arg.Version)
var i ChannelPolicy
err := row.Scan(
&i.ID,
&i.Version,
&i.ChannelID,
&i.NodeID,
&i.Timelock,
&i.FeePpm,
&i.BaseFeeMsat,
&i.MinHtlcMsat,
&i.MaxHtlcMsat,
&i.LastUpdate,
&i.Disabled,
&i.InboundBaseFeeMsat,
&i.InboundFeeRateMilliMsat,
&i.Signature,
)
return i, err
}
const getChannelPolicyExtraTypes = `-- name: GetChannelPolicyExtraTypes :many const getChannelPolicyExtraTypes = `-- name: GetChannelPolicyExtraTypes :many
SELECT SELECT
cp.id AS policy_id, cp.id AS policy_id,
@ -767,6 +803,56 @@ func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByL
return items, nil return items, nil
} }
const getPublicV1ChannelsBySCID = `-- name: GetPublicV1ChannelsBySCID :many
SELECT id, version, scid, node_id_1, node_id_2, outpoint, capacity, bitcoin_key_1, bitcoin_key_2, node_1_signature, node_2_signature, bitcoin_1_signature, bitcoin_2_signature
FROM channels
WHERE node_1_signature IS NOT NULL
AND scid >= $1
AND scid < $2
`
type GetPublicV1ChannelsBySCIDParams struct {
StartScid []byte
EndScid []byte
}
func (q *Queries) GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error) {
rows, err := q.db.QueryContext(ctx, getPublicV1ChannelsBySCID, arg.StartScid, arg.EndScid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Channel
for rows.Next() {
var i Channel
if err := rows.Scan(
&i.ID,
&i.Version,
&i.Scid,
&i.NodeID1,
&i.NodeID2,
&i.Outpoint,
&i.Capacity,
&i.BitcoinKey1,
&i.BitcoinKey2,
&i.Node1Signature,
&i.Node2Signature,
&i.Bitcoin1Signature,
&i.Bitcoin2Signature,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many
SELECT sn.node_id, n.pub_key SELECT sn.node_id, n.pub_key
FROM source_nodes sn FROM source_nodes sn

View File

@ -30,6 +30,7 @@ type Querier interface {
GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg GetChannelAndNodesBySCIDParams) (GetChannelAndNodesBySCIDRow, error)
GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error) GetChannelBySCID(ctx context.Context, arg GetChannelBySCIDParams) (Channel, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]GetChannelFeaturesAndExtrasRow, error)
GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error)
GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error)
GetDatabaseVersion(ctx context.Context) (int32, error) GetDatabaseVersion(ctx context.Context) (int32, error)
@ -51,6 +52,7 @@ type Querier interface {
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
GetPublicV1ChannelsBySCID(ctx context.Context, arg GetPublicV1ChannelsBySCIDParams) ([]Channel, error)
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error) HighestSCID(ctx context.Context, version int16) ([]byte, error)
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error

View File

@ -317,6 +317,13 @@ FROM channels c
WHERE c.version = $1 WHERE c.version = $1
AND (c.node_id_1 = $2 OR c.node_id_2 = $2); AND (c.node_id_1 = $2 OR c.node_id_2 = $2);
-- name: GetPublicV1ChannelsBySCID :many
SELECT *
FROM channels
WHERE node_1_signature IS NOT NULL
AND scid >= @start_scid
AND scid < @end_scid;
-- name: ListChannelsWithPoliciesPaginated :many -- name: ListChannelsWithPoliciesPaginated :many
SELECT SELECT
sqlc.embed(c), sqlc.embed(c),
@ -420,6 +427,13 @@ ON CONFLICT (channel_id, node_id, version)
WHERE EXCLUDED.last_update > channel_policies.last_update WHERE EXCLUDED.last_update > channel_policies.last_update
RETURNING id; RETURNING id;
-- name: GetChannelPolicyByChannelAndNode :one
SELECT *
FROM channel_policies
WHERE channel_id = $1
AND node_id = $2
AND version = $3;
/* ───────────────────────────────────────────── /* ─────────────────────────────────────────────
channel_policy_extra_types table queries channel_policy_extra_types table queries
───────────────────────────────────────────── ─────────────────────────────────────────────