graph/db+sqldb: impl ForEachNodeCached and ForEachChannel

Which let's us run `TestGraphTraversal` against our SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 16:41:13 +02:00
parent 6aa2933379
commit ff84fa1cb2
5 changed files with 462 additions and 1 deletions

View File

@@ -1277,7 +1277,7 @@ func TestForEachSourceNodeChannel(t *testing.T) {
func TestGraphTraversal(t *testing.T) {
t.Parallel()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// We'd like to test some of the graph traversal capabilities within
// the DB, so we'll create a series of fake nodes to insert into the

View File

@@ -90,6 +90,7 @@ type SQLQueries interface {
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg sqlc.GetChannelsByPolicyLastUpdateRangeParams) ([]sqlc.GetChannelsByPolicyLastUpdateRangeRow, error)
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
@@ -1044,6 +1045,223 @@ func (s *SQLStore) ChanUpdatesInHorizon(startTime,
return edges, nil
}
// ForEachNodeCached is similar to forEachNode, but it returns DirectedChannel
// data to the call-back.
//
// NOTE: The callback contents MUST not be modified.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNodeCached(cb func(node route.Vertex,
chans map[uint64]*DirectedChannel) error) error {
var ctx = context.TODO()
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
return forEachNodeCacheable(ctx, db, func(nodeID int64,
nodePub route.Vertex) error {
features, err := getNodeFeatures(ctx, db, nodeID)
if err != nil {
return fmt.Errorf("unable to fetch "+
"node(id=%d) features: %w", nodeID, err)
}
toNodeCallback := func() route.Vertex {
return nodePub
}
rows, err := db.ListChannelsByNodeID(
ctx, sqlc.ListChannelsByNodeIDParams{
Version: int16(ProtocolV1),
NodeID1: nodeID,
},
)
if err != nil {
return fmt.Errorf("unable to fetch channels "+
"of node(id=%d): %w", nodeID, err)
}
channels := make(map[uint64]*DirectedChannel, len(rows))
for _, row := range rows {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return err
}
e, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash,
row.Channel.ID, row.Channel, node1,
node2,
)
if err != nil {
return fmt.Errorf("unable to build "+
"channel info: %w", err)
}
dbPol1, dbPol2, err := extractChannelPolicies(
row,
)
if err != nil {
return fmt.Errorf("unable to "+
"extract channel "+
"policies: %w", err)
}
p1, p2, err := getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, e.ChannelID,
node1, node2,
)
if err != nil {
return fmt.Errorf("unable to "+
"build channel policies: %w",
err)
}
// Determine the outgoing and incoming policy
// for this channel and node combo.
outPolicy, inPolicy := p1, p2
if p1 != nil && p1.ToNode == nodePub {
outPolicy, inPolicy = p2, p1
} else if p2 != nil && p2.ToNode != nodePub {
outPolicy, inPolicy = p2, p1
}
var cachedInPolicy *models.CachedEdgePolicy
if inPolicy != nil {
cachedInPolicy = models.NewCachedPolicy(
p2,
)
cachedInPolicy.ToNodePubKey =
toNodeCallback
cachedInPolicy.ToNodeFeatures =
features
}
var inboundFee lnwire.Fee
outPolicy.InboundFee.WhenSome(
func(fee lnwire.Fee) {
inboundFee = fee
},
)
directedChannel := &DirectedChannel{
ChannelID: e.ChannelID,
IsNode1: nodePub ==
e.NodeKey1Bytes,
OtherNode: e.NodeKey2Bytes,
Capacity: e.Capacity,
OutPolicySet: p1 != nil,
InPolicy: cachedInPolicy,
InboundFee: inboundFee,
}
if nodePub == e.NodeKey2Bytes {
directedChannel.OtherNode =
e.NodeKey1Bytes
}
channels[e.ChannelID] = directedChannel
}
return cb(nodePub, channels)
})
}, sqldb.NoOpReset)
}
// ForEachChannel iterates through all the channel edges stored within the
// graph and invokes the passed callback for each edge. The callback takes two
// edges as since this is a directed graph, both the in/out edges are visited.
// If the callback returns an error, then the transaction is aborted and the
// iteration stops early.
//
// NOTE: If an edge can't be found, or wasn't advertised, then a nil pointer
// for that particular channel edge routing policy will be passed into the
// callback.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachChannel(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
ctx := context.TODO()
handleChannel := func(db SQLQueries,
row sqlc.ListChannelsWithPoliciesPaginatedRow) error {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return fmt.Errorf("unable to build node vertices: %w",
err)
}
edge, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID, row.Channel,
node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel info: %w",
err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
p1, p2, err := getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
)
if err != nil {
return fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
err = cb(edge, p1, p2)
if err != nil {
return fmt.Errorf("callback failed for channel "+
"id=%d: %w", edge.ChannelID, err)
}
return nil
}
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var lastID int64
for {
//nolint:ll
rows, err := db.ListChannelsWithPoliciesPaginated(
ctx, sqlc.ListChannelsWithPoliciesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: pageSize,
},
)
if err != nil {
return err
}
if len(rows) == 0 {
break
}
for _, row := range rows {
err := handleChannel(db, row)
if err != nil {
return err
}
lastID = row.Channel.ID
}
}
return nil
}, sqldb.NoOpReset)
}
// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
@@ -2525,6 +2743,46 @@ func extractChannelPolicies(row any) (*sqlc.ChannelPolicy, *sqlc.ChannelPolicy,
}
}
return policy1, policy2, nil
case sqlc.ListChannelsWithPoliciesPaginatedRow:
if r.Policy1ID.Valid {
policy1 = &sqlc.ChannelPolicy{
ID: r.Policy1ID.Int64,
Version: r.Policy1Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy1NodeID.Int64,
Timelock: r.Policy1Timelock.Int32,
FeePpm: r.Policy1FeePpm.Int64,
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
LastUpdate: r.Policy1LastUpdate,
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
Disabled: r.Policy1Disabled,
Signature: r.Policy1Signature,
}
}
if r.Policy2ID.Valid {
policy2 = &sqlc.ChannelPolicy{
ID: r.Policy2ID.Int64,
Version: r.Policy2Version.Int16,
ChannelID: r.Channel.ID,
NodeID: r.Policy2NodeID.Int64,
Timelock: r.Policy2Timelock.Int32,
FeePpm: r.Policy2FeePpm.Int64,
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
LastUpdate: r.Policy2LastUpdate,
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
Disabled: r.Policy2Disabled,
Signature: r.Policy2Signature,
}
}
return policy1, policy2, nil
default:
return nil, nil, fmt.Errorf("unexpected row type in "+

View File

@@ -1070,6 +1070,159 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
return items, nil
}
const listChannelsWithPoliciesPaginated = `-- name: ListChannelsWithPoliciesPaginated :many
SELECT
c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
-- Join node pubkeys
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,
-- Node 1 policy
cp1.id AS policy_1_id,
cp1.node_id AS policy_1_node_id,
cp1.version AS policy_1_version,
cp1.timelock AS policy_1_timelock,
cp1.fee_ppm AS policy_1_fee_ppm,
cp1.base_fee_msat AS policy_1_base_fee_msat,
cp1.min_htlc_msat AS policy_1_min_htlc_msat,
cp1.max_htlc_msat AS policy_1_max_htlc_msat,
cp1.last_update AS policy_1_last_update,
cp1.disabled AS policy_1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy_1_signature,
-- Node 2 policy
cp2.id AS policy_2_id,
cp2.node_id AS policy_2_node_id,
cp2.version AS policy_2_version,
cp2.timelock AS policy_2_timelock,
cp2.fee_ppm AS policy_2_fee_ppm,
cp2.base_fee_msat AS policy_2_base_fee_msat,
cp2.min_htlc_msat AS policy_2_min_htlc_msat,
cp2.max_htlc_msat AS policy_2_max_htlc_msat,
cp2.last_update AS policy_2_last_update,
cp2.disabled AS policy_2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy_2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.version = $1 AND c.id > $2
ORDER BY c.id
LIMIT $3
`
type ListChannelsWithPoliciesPaginatedParams struct {
Version int16
ID int64
Limit int32
}
type ListChannelsWithPoliciesPaginatedRow struct {
Channel Channel
Node1Pubkey []byte
Node2Pubkey []byte
Policy1ID sql.NullInt64
Policy1NodeID sql.NullInt64
Policy1Version sql.NullInt16
Policy1Timelock sql.NullInt32
Policy1FeePpm sql.NullInt64
Policy1BaseFeeMsat sql.NullInt64
Policy1MinHtlcMsat sql.NullInt64
Policy1MaxHtlcMsat sql.NullInt64
Policy1LastUpdate sql.NullInt64
Policy1Disabled sql.NullBool
Policy1InboundBaseFeeMsat sql.NullInt64
Policy1InboundFeeRateMilliMsat sql.NullInt64
Policy1Signature []byte
Policy2ID sql.NullInt64
Policy2NodeID sql.NullInt64
Policy2Version sql.NullInt16
Policy2Timelock sql.NullInt32
Policy2FeePpm sql.NullInt64
Policy2BaseFeeMsat sql.NullInt64
Policy2MinHtlcMsat sql.NullInt64
Policy2MaxHtlcMsat sql.NullInt64
Policy2LastUpdate sql.NullInt64
Policy2Disabled sql.NullBool
Policy2InboundBaseFeeMsat sql.NullInt64
Policy2InboundFeeRateMilliMsat sql.NullInt64
Policy2Signature []byte
}
func (q *Queries) ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error) {
rows, err := q.db.QueryContext(ctx, listChannelsWithPoliciesPaginated, arg.Version, arg.ID, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListChannelsWithPoliciesPaginatedRow
for rows.Next() {
var i ListChannelsWithPoliciesPaginatedRow
if err := rows.Scan(
&i.Channel.ID,
&i.Channel.Version,
&i.Channel.Scid,
&i.Channel.NodeID1,
&i.Channel.NodeID2,
&i.Channel.Outpoint,
&i.Channel.Capacity,
&i.Channel.BitcoinKey1,
&i.Channel.BitcoinKey2,
&i.Channel.Node1Signature,
&i.Channel.Node2Signature,
&i.Channel.Bitcoin1Signature,
&i.Channel.Bitcoin2Signature,
&i.Node1Pubkey,
&i.Node2Pubkey,
&i.Policy1ID,
&i.Policy1NodeID,
&i.Policy1Version,
&i.Policy1Timelock,
&i.Policy1FeePpm,
&i.Policy1BaseFeeMsat,
&i.Policy1MinHtlcMsat,
&i.Policy1MaxHtlcMsat,
&i.Policy1LastUpdate,
&i.Policy1Disabled,
&i.Policy1InboundBaseFeeMsat,
&i.Policy1InboundFeeRateMilliMsat,
&i.Policy1Signature,
&i.Policy2ID,
&i.Policy2NodeID,
&i.Policy2Version,
&i.Policy2Timelock,
&i.Policy2FeePpm,
&i.Policy2BaseFeeMsat,
&i.Policy2MinHtlcMsat,
&i.Policy2MaxHtlcMsat,
&i.Policy2LastUpdate,
&i.Policy2Disabled,
&i.Policy2InboundBaseFeeMsat,
&i.Policy2InboundFeeRateMilliMsat,
&i.Policy2Signature,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listNodeIDsAndPubKeys = `-- name: ListNodeIDsAndPubKeys :many
SELECT id, pub_key
FROM nodes

View File

@@ -66,6 +66,7 @@ type Querier interface {
InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
ListChannelsWithPoliciesPaginated(ctx context.Context, arg ListChannelsWithPoliciesPaginatedParams) ([]ListChannelsWithPoliciesPaginatedRow, error)
ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)
ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
NextInvoiceSettleIndex(ctx context.Context) (int64, error)

View File

@@ -317,6 +317,55 @@ FROM channels c
WHERE c.version = $1
AND (c.node_id_1 = $2 OR c.node_id_2 = $2);
-- name: ListChannelsWithPoliciesPaginated :many
SELECT
sqlc.embed(c),
-- Join node pubkeys
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,
-- Node 1 policy
cp1.id AS policy_1_id,
cp1.node_id AS policy_1_node_id,
cp1.version AS policy_1_version,
cp1.timelock AS policy_1_timelock,
cp1.fee_ppm AS policy_1_fee_ppm,
cp1.base_fee_msat AS policy_1_base_fee_msat,
cp1.min_htlc_msat AS policy_1_min_htlc_msat,
cp1.max_htlc_msat AS policy_1_max_htlc_msat,
cp1.last_update AS policy_1_last_update,
cp1.disabled AS policy_1_disabled,
cp1.inbound_base_fee_msat AS policy1_inbound_base_fee_msat,
cp1.inbound_fee_rate_milli_msat AS policy1_inbound_fee_rate_milli_msat,
cp1.signature AS policy_1_signature,
-- Node 2 policy
cp2.id AS policy_2_id,
cp2.node_id AS policy_2_node_id,
cp2.version AS policy_2_version,
cp2.timelock AS policy_2_timelock,
cp2.fee_ppm AS policy_2_fee_ppm,
cp2.base_fee_msat AS policy_2_base_fee_msat,
cp2.min_htlc_msat AS policy_2_min_htlc_msat,
cp2.max_htlc_msat AS policy_2_max_htlc_msat,
cp2.last_update AS policy_2_last_update,
cp2.disabled AS policy_2_disabled,
cp2.inbound_base_fee_msat AS policy2_inbound_base_fee_msat,
cp2.inbound_fee_rate_milli_msat AS policy2_inbound_fee_rate_milli_msat,
cp2.signature AS policy_2_signature
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
LEFT JOIN channel_policies cp1
ON cp1.channel_id = c.id AND cp1.node_id = c.node_id_1 AND cp1.version = c.version
LEFT JOIN channel_policies cp2
ON cp2.channel_id = c.id AND cp2.node_id = c.node_id_2 AND cp2.version = c.version
WHERE c.version = $1 AND c.id > $2
ORDER BY c.id
LIMIT $3;
/* ─────────────────────────────────────────────
channel_features table queries
─────────────────────────────────────────────