graph/db+sqldb: implement ForEachNode

In this commit the `ForEachNode` method is added to the SQLStore. With
this, the `TestGraphCacheTraversal` unit test can be run against SQL
backends.
This commit is contained in:
Elle Mouton
2025-06-11 15:59:43 +02:00
parent d60761f79c
commit 8af32951c7
5 changed files with 183 additions and 3 deletions

View File

@ -1436,7 +1436,7 @@ func TestGraphTraversalCacheable(t *testing.T) {
func TestGraphCacheTraversal(t *testing.T) {
t.Parallel()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// We'd like to test some of the graph traversal capabilities within
// the DB, so we'll create a series of fake nodes to insert into the

View File

@ -28,6 +28,10 @@ import (
"github.com/lightningnetwork/lnd/tor"
)
// pageSize is the limit for the number of records that can be returned
// in a paginated query. This can be tuned after some benchmarks.
const pageSize = 2000
// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8
@ -53,6 +57,7 @@ type SQLQueries interface {
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
@ -715,6 +720,130 @@ func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint,
}, sqldb.NoOpReset)
}
// ForEachNode iterates through all the stored vertices/nodes in the graph,
// executing the passed callback with each node encountered. If the callback
// returns an error, then the transaction is aborted and the iteration stops
// early. Any operations performed on the NodeTx passed to the call-back are
// executed under the same read transaction and so, methods on the NodeTx object
// _MUST_ only be called from within the call-back.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error {
var (
ctx = context.TODO()
lastID int64 = 0
)
handleNode := func(db SQLQueries, dbNode sqlc.Node) error {
node, err := buildNode(ctx, db, &dbNode)
if err != nil {
return fmt.Errorf("unable to build node(id=%d): %w",
dbNode.ID, err)
}
err = cb(
newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node),
)
if err != nil {
return fmt.Errorf("callback failed for node(id=%d): %w",
dbNode.ID, err)
}
return nil
}
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
for {
nodes, err := db.ListNodesPaginated(
ctx, sqlc.ListNodesPaginatedParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: pageSize,
},
)
if err != nil {
return fmt.Errorf("unable to fetch nodes: %w",
err)
}
if len(nodes) == 0 {
break
}
for _, dbNode := range nodes {
err = handleNode(db, dbNode)
if err != nil {
return err
}
lastID = dbNode.ID
}
}
return nil
}, sqldb.NoOpReset)
}
// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the
// SQLStore and a SQL transaction.
type sqlGraphNodeTx struct {
db SQLQueries
id int64
node *models.LightningNode
chain chainhash.Hash
}
// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx
// interface.
var _ NodeRTx = (*sqlGraphNodeTx)(nil)
func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash,
id int64, node *models.LightningNode) *sqlGraphNodeTx {
return &sqlGraphNodeTx{
db: db,
chain: chain,
id: id,
node: node,
}
}
// Node returns the raw information of the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) Node() *models.LightningNode {
return s.node
}
// ForEachChannel can be used to iterate over the node's channels under the same
// transaction used to fetch the node.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error {
ctx := context.TODO()
return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb)
}
// FetchNode fetches the node with the given pub key under the same transaction
// used to fetch the current node. The returned node is also a NodeRTx and any
// operations on that NodeRTx will also be done under the same transaction.
//
// NOTE: This is a part of the NodeRTx interface.
func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
ctx := context.TODO()
id, node, err := getNodeByPubKey(ctx, s.db, nodePub)
if err != nil {
return nil, fmt.Errorf("unable to fetch V1 node(%x): %w",
nodePub, err)
}
return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
}
// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the

View File

@ -729,7 +729,6 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa
}
const listChannelsByNodeID = `-- name: ListChannelsByNodeID :many
SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,
@ -880,6 +879,51 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
return items, nil
}
const listNodesPaginated = `-- name: ListNodesPaginated :many
SELECT id, version, pub_key, alias, last_update, color, signature
FROM nodes
WHERE version = $1 AND id > $2
ORDER BY id
LIMIT $3
`
type ListNodesPaginatedParams struct {
Version int16
ID int64
Limit int32
}
func (q *Queries) ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error) {
rows, err := q.db.QueryContext(ctx, listNodesPaginated, arg.Version, arg.ID, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Node
for rows.Next() {
var i Node
if err := rows.Scan(
&i.ID,
&i.Version,
&i.PubKey,
&i.Alias,
&i.LastUpdate,
&i.Color,
&i.Signature,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one
/* ─────────────────────────────────────────────
channel_policies table queries

View File

@ -64,6 +64,7 @@ type Querier interface {
InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error

View File

@ -27,6 +27,13 @@ FROM nodes
WHERE pub_key = $1
AND version = $2;
-- name: ListNodesPaginated :many
SELECT *
FROM nodes
WHERE version = $1 AND id > $2
ORDER BY id
LIMIT $3;
-- name: DeleteNodeByPubKey :execresult
DELETE FROM nodes
WHERE pub_key = $1
@ -194,7 +201,6 @@ ORDER BY scid DESC
LIMIT 1;
-- name: ListChannelsByNodeID :many
SELECT sqlc.embed(c),
n1.pub_key AS node1_pubkey,
n2.pub_key AS node2_pubkey,