From 8af32951c734bc05aa78a0613f9bdf20d53a6b2f Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 15:59:43 +0200 Subject: [PATCH] graph/db+sqldb: implement ForEachNode In this commit the `ForEachNode` method is added to the SQLStore. With this, the `TestGraphCacheTraversal` unit test can be run against SQL backends. --- graph/db/graph_test.go | 2 +- graph/db/sql_store.go | 129 +++++++++++++++++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 46 ++++++++++++- sqldb/sqlc/querier.go | 1 + sqldb/sqlc/queries/graph.sql | 8 ++- 5 files changed, 183 insertions(+), 3 deletions(-) diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index b54b1fa4c..d48e229f3 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -1436,7 +1436,7 @@ func TestGraphTraversalCacheable(t *testing.T) { func TestGraphCacheTraversal(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test some of the graph traversal capabilities within // the DB, so we'll create a series of fake nodes to insert into the diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index ca8569aee..893575154 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -28,6 +28,10 @@ import ( "github.com/lightningnetwork/lnd/tor" ) +// pageSize is the limit for the number of records that can be returned +// in a paginated query. This can be tuned after some benchmarks. +const pageSize = 2000 + // ProtocolVersion is an enum that defines the gossip protocol version of a // message. type ProtocolVersion uint8 @@ -53,6 +57,7 @@ type SQLQueries interface { UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error) GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error) GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error) + ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error) DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error) @@ -715,6 +720,130 @@ func (s *SQLStore) ForEachSourceNodeChannel(cb func(chanPoint wire.OutPoint, }, sqldb.NoOpReset) } +// ForEachNode iterates through all the stored vertices/nodes in the graph, +// executing the passed callback with each node encountered. If the callback +// returns an error, then the transaction is aborted and the iteration stops +// early. Any operations performed on the NodeTx passed to the call-back are +// executed under the same read transaction and so, methods on the NodeTx object +// _MUST_ only be called from within the call-back. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) ForEachNode(cb func(tx NodeRTx) error) error { + var ( + ctx = context.TODO() + lastID int64 = 0 + ) + + handleNode := func(db SQLQueries, dbNode sqlc.Node) error { + node, err := buildNode(ctx, db, &dbNode) + if err != nil { + return fmt.Errorf("unable to build node(id=%d): %w", + dbNode.ID, err) + } + + err = cb( + newSQLGraphNodeTx(db, s.cfg.ChainHash, dbNode.ID, node), + ) + if err != nil { + return fmt.Errorf("callback failed for node(id=%d): %w", + dbNode.ID, err) + } + + return nil + } + + return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error { + for { + nodes, err := db.ListNodesPaginated( + ctx, sqlc.ListNodesPaginatedParams{ + Version: int16(ProtocolV1), + ID: lastID, + Limit: pageSize, + }, + ) + if err != nil { + return fmt.Errorf("unable to fetch nodes: %w", + err) + } + + if len(nodes) == 0 { + break + } + + for _, dbNode := range nodes { + err = handleNode(db, dbNode) + if err != nil { + return err + } + + lastID = dbNode.ID + } + } + + return nil + }, sqldb.NoOpReset) +} + +// sqlGraphNodeTx is an implementation of the NodeRTx interface backed by the +// SQLStore and a SQL transaction. +type sqlGraphNodeTx struct { + db SQLQueries + id int64 + node *models.LightningNode + chain chainhash.Hash +} + +// A compile-time constraint to ensure sqlGraphNodeTx implements the NodeRTx +// interface. +var _ NodeRTx = (*sqlGraphNodeTx)(nil) + +func newSQLGraphNodeTx(db SQLQueries, chain chainhash.Hash, + id int64, node *models.LightningNode) *sqlGraphNodeTx { + + return &sqlGraphNodeTx{ + db: db, + chain: chain, + id: id, + node: node, + } +} + +// Node returns the raw information of the node. +// +// NOTE: This is a part of the NodeRTx interface. +func (s *sqlGraphNodeTx) Node() *models.LightningNode { + return s.node +} + +// ForEachChannel can be used to iterate over the node's channels under the same +// transaction used to fetch the node. +// +// NOTE: This is a part of the NodeRTx interface. +func (s *sqlGraphNodeTx) ForEachChannel(cb func(*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy) error) error { + + ctx := context.TODO() + + return forEachNodeChannel(ctx, s.db, s.chain, s.id, cb) +} + +// FetchNode fetches the node with the given pub key under the same transaction +// used to fetch the current node. The returned node is also a NodeRTx and any +// operations on that NodeRTx will also be done under the same transaction. +// +// NOTE: This is a part of the NodeRTx interface. +func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) { + ctx := context.TODO() + + id, node, err := getNodeByPubKey(ctx, s.db, nodePub) + if err != nil { + return nil, fmt.Errorf("unable to fetch V1 node(%x): %w", + nodePub, err) + } + + return newSQLGraphNodeTx(s.db, s.chain, id, node), nil +} + // forEachNodeChannel iterates through all channels of a node, executing // the passed callback on each. The call-back is provided with the channel's // edge information, the outgoing policy and the incoming policy for the diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index 884d5c340..5b7f974a2 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -729,7 +729,6 @@ func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeaturePa } const listChannelsByNodeID = `-- name: ListChannelsByNodeID :many - SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, n1.pub_key AS node1_pubkey, n2.pub_key AS node2_pubkey, @@ -880,6 +879,51 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo return items, nil } +const listNodesPaginated = `-- name: ListNodesPaginated :many +SELECT id, version, pub_key, alias, last_update, color, signature +FROM nodes +WHERE version = $1 AND id > $2 +ORDER BY id +LIMIT $3 +` + +type ListNodesPaginatedParams struct { + Version int16 + ID int64 + Limit int32 +} + +func (q *Queries) ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error) { + rows, err := q.db.QueryContext(ctx, listNodesPaginated, arg.Version, arg.ID, arg.Limit) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Node + for rows.Next() { + var i Node + if err := rows.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const upsertEdgePolicy = `-- name: UpsertEdgePolicy :one /* ───────────────────────────────────────────── channel_policies table queries diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index d4a6a9dd0..35d7b2829 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -64,6 +64,7 @@ type Querier interface { InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error) + ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error) NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index c6dbb6e6b..7c98cc60e 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -27,6 +27,13 @@ FROM nodes WHERE pub_key = $1 AND version = $2; +-- name: ListNodesPaginated :many +SELECT * +FROM nodes +WHERE version = $1 AND id > $2 +ORDER BY id +LIMIT $3; + -- name: DeleteNodeByPubKey :execresult DELETE FROM nodes WHERE pub_key = $1 @@ -194,7 +201,6 @@ ORDER BY scid DESC LIMIT 1; -- name: ListChannelsByNodeID :many - SELECT sqlc.embed(c), n1.pub_key AS node1_pubkey, n2.pub_key AS node2_pubkey,