From 06079828865a362dbefeda9508ef69b6d1ad4f27 Mon Sep 17 00:00:00 2001
From: Elle Mouton
Date: Wed, 11 Jun 2025 16:12:08 +0200
Subject: [PATCH] graph/db: implement ForEachNodeDirectedChannel and
 ForEachNodeCacheable

Here we add the `ForEachNodeDirectedChannel` and `ForEachNodeCacheable`
SQLStore implementations, which then let us run `TestGraphTraversalCacheable`
and `TestGraphCacheForEachNodeChannel` against SQL backends.
---
 graph/db/graph_test.go       |  21 ++--
 graph/db/sql_store.go        | 236 ++++++++++++++++++++++++++++++++++-
 sqldb/sqlc/graph.sql.go      |  61 +++++++++
 sqldb/sqlc/querier.go        |   2 +
 sqldb/sqlc/queries/graph.sql |  13 ++
 5 files changed, 324 insertions(+), 9 deletions(-)

diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go
index d48e229f3..f681a1f75 100644
--- a/graph/db/graph_test.go
+++ b/graph/db/graph_test.go
@@ -23,6 +23,7 @@ import (
     "github.com/btcsuite/btcd/chaincfg"
     "github.com/btcsuite/btcd/chaincfg/chainhash"
     "github.com/btcsuite/btcd/wire"
+    "github.com/lightningnetwork/lnd/fn/v2"
     "github.com/lightningnetwork/lnd/graph/db/models"
     "github.com/lightningnetwork/lnd/kvdb"
     "github.com/lightningnetwork/lnd/lnwire"
@@ -1366,7 +1367,7 @@ func TestGraphTraversal(t *testing.T) {
 func TestGraphTraversalCacheable(t *testing.T) {
     t.Parallel()
 
-    graph := MakeTestGraph(t)
+    graph := MakeTestGraphNew(t)
 
     // We'd like to test some of the graph traversal capabilities within
     // the DB, so we'll create a series of fake nodes to insert into the
@@ -4162,7 +4163,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
     t.Parallel()
 
     ctx := context.Background()
-    graph := MakeTestGraph(t)
+    graph := MakeTestGraphNew(t)
 
     // Unset the channel graph cache to simulate the user running with the
     // option turned off.
@@ -4212,21 +4213,25 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
     edge1.ExtraOpaqueData = []byte{
         253, 217, 3, 8, 0, 0, 0, 10, 0, 0, 0, 20,
     }
+    inboundFee := lnwire.Fee{
+        BaseFee: 10,
+        FeeRate: 20,
+    }
+    edge1.InboundFee = fn.Some(inboundFee)
     require.NoError(t, graph.UpdateEdgePolicy(edge1))
     edge1 = copyEdgePolicy(edge1) // Avoid read/write race conditions.
 
     directedChan := getSingleChannel()
     require.NotNil(t, directedChan)
-    expectedInbound := lnwire.Fee{
-        BaseFee: 10,
-        FeeRate: 20,
-    }
-    require.Equal(t, expectedInbound, directedChan.InboundFee)
+    require.Equal(t, inboundFee, directedChan.InboundFee)
 
     // Set an invalid inbound fee and check that persistence fails.
     edge1.ExtraOpaqueData = []byte{
         253, 217, 3, 8, 0,
     }
+    // We need to update the timestamp so that we don't hit the DB conflict
+    // error when we try to update the edge policy.
+    edge1.LastUpdate = edge1.LastUpdate.Add(time.Second)
     require.ErrorIs(
         t, graph.UpdateEdgePolicy(edge1), ErrParsingExtraTLVBytes,
     )
@@ -4235,7 +4240,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
     // the previous result when we query the channel again.
     directedChan = getSingleChannel()
     require.NotNil(t, directedChan)
-    require.Equal(t, expectedInbound, directedChan.InboundFee)
+    require.Equal(t, inboundFee, directedChan.InboundFee)
 }
 
 // TestGraphLoading asserts that the cache is properly reconstructed after a
diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go
index 893575154..e7bb18904 100644
--- a/graph/db/sql_store.go
+++ b/graph/db/sql_store.go
@@ -56,8 +56,10 @@
     */
     UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
     GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
+    GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
     GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
     ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
+    ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
     DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
     GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
 
@@ -844,6 +846,224 @@ func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
     return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
 }
 
+// ForEachNodeDirectedChannel iterates through all channels of a given node,
+// executing the passed callback on the directed edge representing the channel
+// and its incoming policy. If the callback returns an error, then the iteration
+// is halted with the error propagated back up to the caller.
+//
+// Unknown policies are passed into the callback as nil values.
+//
+// NOTE: this is part of the graphdb.NodeTraverser interface.
+func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
+    cb func(channel *DirectedChannel) error) error {
+
+    var ctx = context.TODO()
+
+    return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
+        return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
+    }, sqldb.NoOpReset)
+}
+
+// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
+// graph, executing the passed callback with each node encountered. If the
+// callback returns an error, then the transaction is aborted and the iteration
+// stops early.
+//
+// NOTE: This is a part of the V1Store interface.
+func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
+    *lnwire.FeatureVector) error) error {
+
+    ctx := context.TODO()
+
+    err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
+        return forEachNodeCacheable(ctx, db, func(nodeID int64,
+            nodePub route.Vertex) error {
+
+            features, err := getNodeFeatures(ctx, db, nodeID)
+            if err != nil {
+                return fmt.Errorf("unable to fetch node "+
+                    "features: %w", err)
+            }
+
+            return cb(nodePub, features)
+        })
+    }, sqldb.NoOpReset)
+    if err != nil {
+        return fmt.Errorf("unable to fetch nodes: %w", err)
+    }
+
+    return nil
+}
+
+// forEachNodeDirectedChannel iterates through all channels of a given
+// node, executing the passed callback on the directed edge representing the
+// channel and its incoming policy. If the node is not found, no error is
+// returned.
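+//
+// The node's own feature vector is fetched once up front and is attached as
+// the ToNodeFeatures of each channel's cached incoming policy.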
+func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
+    nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
+
+    toNodeCallback := func() route.Vertex {
+        return nodePub
+    }
+
+    dbID, err := db.GetNodeIDByPubKey(
+        ctx, sqlc.GetNodeIDByPubKeyParams{
+            Version: int16(ProtocolV1),
+            PubKey:  nodePub[:],
+        },
+    )
+    if errors.Is(err, sql.ErrNoRows) {
+        return nil
+    } else if err != nil {
+        return fmt.Errorf("unable to fetch node: %w", err)
+    }
+
+    rows, err := db.ListChannelsByNodeID(
+        ctx, sqlc.ListChannelsByNodeIDParams{
+            Version: int16(ProtocolV1),
+            NodeID1: dbID,
+        },
+    )
+    if err != nil {
+        return fmt.Errorf("unable to fetch channels: %w", err)
+    }
+
+    // Exit early if there are no channels for this node so we don't
+    // do the unnecessary feature fetching.
+    if len(rows) == 0 {
+        return nil
+    }
+
+    features, err := getNodeFeatures(ctx, db, dbID)
+    if err != nil {
+        return fmt.Errorf("unable to fetch node features: %w", err)
+    }
+
+    for _, row := range rows {
+        node1, node2, err := buildNodeVertices(
+            row.Node1Pubkey, row.Node2Pubkey,
+        )
+        if err != nil {
+            return fmt.Errorf("unable to build node vertices: %w",
+                err)
+        }
+
+        edge, err := buildCacheableChannelInfo(
+            row.Channel, node1, node2,
+        )
+        if err != nil {
+            return err
+        }
+
+        dbPol1, dbPol2, err := extractChannelPolicies(row)
+        if err != nil {
+            return err
+        }
+
+        var p1, p2 *models.CachedEdgePolicy
+        if dbPol1 != nil {
+            policy1, err := buildChanPolicy(
+                *dbPol1, edge.ChannelID, nil, node2, true,
+            )
+            if err != nil {
+                return err
+            }
+
+            p1 = models.NewCachedPolicy(policy1)
+        }
+        if dbPol2 != nil {
+            policy2, err := buildChanPolicy(
+                *dbPol2, edge.ChannelID, nil, node1, false,
+            )
+            if err != nil {
+                return err
+            }
+
+            p2 = models.NewCachedPolicy(policy2)
+        }
+
+        // Determine the outgoing and incoming policy for this
+        // channel and node combo.
+        outPolicy, inPolicy := p1, p2
+        if p1 != nil && node2 == nodePub {
+            outPolicy, inPolicy = p2, p1
+        } else if p2 != nil && node1 != nodePub {
+            outPolicy, inPolicy = p2, p1
+        }
+
+        var cachedInPolicy *models.CachedEdgePolicy
+        if inPolicy != nil {
+            cachedInPolicy = inPolicy
+            cachedInPolicy.ToNodePubKey = toNodeCallback
+            cachedInPolicy.ToNodeFeatures = features
+        }
+
+        directedChannel := &DirectedChannel{
+            ChannelID:    edge.ChannelID,
+            IsNode1:      nodePub == edge.NodeKey1Bytes,
+            OtherNode:    edge.NodeKey2Bytes,
+            Capacity:     edge.Capacity,
+            OutPolicySet: outPolicy != nil,
+            InPolicy:     cachedInPolicy,
+        }
+        if outPolicy != nil {
+            outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
+                directedChannel.InboundFee = fee
+            })
+        }
+
+        if nodePub == edge.NodeKey2Bytes {
+            directedChannel.OtherNode = edge.NodeKey1Bytes
+        }
+
+        if err := cb(directedChannel); err != nil {
+            return err
+        }
+    }
+
+    return nil
+}
+
+// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
+// and executes the provided callback for each node.
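+//
+// Nodes are read in pages of pageSize entries, using the last node ID of the
+// previous page as the cursor for the next query, so the full node set is
+// never held in memory at once.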
+func forEachNodeCacheable(ctx context.Context, db SQLQueries,
+    cb func(nodeID int64, nodePub route.Vertex) error) error {
+
+    var lastID int64
+
+    for {
+        nodes, err := db.ListNodeIDsAndPubKeys(
+            ctx, sqlc.ListNodeIDsAndPubKeysParams{
+                Version: int16(ProtocolV1),
+                ID:      lastID,
+                Limit:   pageSize,
+            },
+        )
+        if err != nil {
+            return fmt.Errorf("unable to fetch nodes: %w", err)
+        }
+
+        if len(nodes) == 0 {
+            break
+        }
+
+        for _, node := range nodes {
+            var pub route.Vertex
+            copy(pub[:], node.PubKey)
+
+            if err := cb(node.ID, pub); err != nil {
+                return fmt.Errorf("forEachNodeCacheable "+
+                    "callback failed for node(id=%d): %w",
+                    node.ID, err)
+            }
+
+            lastID = node.ID
+        }
+    }
+
+    return nil
+}
+
 // forEachNodeChannel iterates through all channels of a node, executing
 // the passed callback on each. The call-back is provided with the channel's
 // edge information, the outgoing policy and the incoming policy for the
@@ -1033,6 +1253,20 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries,
     return dbNode.ID, node, nil
 }
 
+// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
+// provided database channel row and the public keys of the two nodes
+// involved in the channel.
+func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
+    node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
+
+    return &models.CachedEdgeInfo{
+        ChannelID:     byteOrder.Uint64(dbChan.Scid),
+        NodeKey1Bytes: node1Pub,
+        NodeKey2Bytes: node2Pub,
+        Capacity:      btcutil.Amount(dbChan.Capacity.Int64),
+    }, nil
+}
+
 // buildNode constructs a LightningNode instance from the given database node
 // record. The node's features, addresses and extra signed fields are also
 // fetched from the database and set on the node.
@@ -1589,7 +1823,7 @@ func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
     // pass it into the P2P decoding variant.
     parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
     if err != nil {
-        return nil, err
+        return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
     }
     if len(parsedTypes) == 0 {
         return nil, nil
diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go
index 5b7f974a2..606d40217 100644
--- a/sqldb/sqlc/graph.sql.go
+++ b/sqldb/sqlc/graph.sql.go
@@ -533,6 +533,25 @@ func (q *Queries) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeatur
     return items, nil
 }
 
+const getNodeIDByPubKey = `-- name: GetNodeIDByPubKey :one
+SELECT id
+FROM nodes
+WHERE pub_key = $1
+  AND version = $2
+`
+
+type GetNodeIDByPubKeyParams struct {
+    PubKey  []byte
+    Version int16
+}
+
+func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) {
+    row := q.db.QueryRowContext(ctx, getNodeIDByPubKey, arg.PubKey, arg.Version)
+    var id int64
+    err := row.Scan(&id)
+    return id, err
+}
+
 const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many
 SELECT id, version, pub_key, alias, last_update, color, signature
 FROM nodes
@@ -879,6 +898,48 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
     return items, nil
 }
 
+const listNodeIDsAndPubKeys = `-- name: ListNodeIDsAndPubKeys :many
+SELECT id, pub_key
+FROM nodes
+WHERE version = $1 AND id > $2
+ORDER BY id
+LIMIT $3
+`
+
+type ListNodeIDsAndPubKeysParams struct {
+    Version int16
+    ID      int64
+    Limit   int32
+}
+
+type ListNodeIDsAndPubKeysRow struct {
+    ID     int64
+    PubKey []byte
+}
+
+func (q *Queries) ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error) {
+    rows, err := q.db.QueryContext(ctx, listNodeIDsAndPubKeys, arg.Version, arg.ID, arg.Limit)
+    if err != nil {
+        return nil, err
+    }
+    defer rows.Close()
+    var items []ListNodeIDsAndPubKeysRow
+    for rows.Next() {
+        var i ListNodeIDsAndPubKeysRow
+        if err := rows.Scan(&i.ID, &i.PubKey); err != nil {
+            return nil, err
+        }
+        items = append(items, i)
+    }
+    if err := rows.Close(); err != nil {
+        return nil, err
+    }
+    if err := rows.Err(); err != nil {
+        return nil, err
+    }
+    return items, nil
+}
+
 const listNodesPaginated = `-- name: ListNodesPaginated :many
 SELECT id, version, pub_key, alias, last_update, color, signature
 FROM nodes
diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go
index 35d7b2829..c0abde3c2 100644
--- a/sqldb/sqlc/querier.go
+++ b/sqldb/sqlc/querier.go
@@ -48,6 +48,7 @@
     GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error)
     GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error)
     GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
+    GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
     GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
     GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
     HighestSCID(ctx context.Context, version int16) ([]byte, error)
@@ -64,6 +65,7 @@
     InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
     InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
     ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
+    ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)
     ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
     NextInvoiceSettleIndex(ctx context.Context) (int64, error)
     OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql
index 7c98cc60e..510195906 100644
--- a/sqldb/sqlc/queries/graph.sql
+++ b/sqldb/sqlc/queries/graph.sql
@@ -27,6 +27,12 @@ FROM nodes
 WHERE pub_key = $1
   AND version = $2;
 
+-- name: GetNodeIDByPubKey :one
+SELECT id
+FROM nodes
+WHERE pub_key = $1
+  AND version = $2;
+
 -- name: ListNodesPaginated :many
 SELECT *
 FROM nodes
@@ -34,6 +40,13 @@ WHERE version = $1 AND id > $2
 ORDER BY id
 LIMIT $3;
 
+-- name: ListNodeIDsAndPubKeys :many
+SELECT id, pub_key
+FROM nodes
+WHERE version = $1 AND id > $2
+ORDER BY id
+LIMIT $3;
+
 -- name: DeleteNodeByPubKey :execresult
 DELETE FROM nodes
 WHERE pub_key = $1
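
A minimal usage sketch of the new API (illustrative only, not part of the
patch): it assumes the graphdb import alias for graph/db, a caller-supplied
*SQLStore and node public key, and a hypothetical helper name
countNodeChannels.

    import (
        "github.com/btcsuite/btcd/btcutil"
        graphdb "github.com/lightningnetwork/lnd/graph/db"
        "github.com/lightningnetwork/lnd/routing/route"
    )

    // countNodeChannels is an illustrative helper (not part of lnd) that
    // tallies the channels of nodePub and sums their capacity by driving the
    // SQLStore's ForEachNodeDirectedChannel implementation added above.
    func countNodeChannels(store *graphdb.SQLStore,
        nodePub route.Vertex) (int, btcutil.Amount, error) {

        var (
            numChannels int
            totalCap    btcutil.Amount
        )
        err := store.ForEachNodeDirectedChannel(nodePub,
            func(c *graphdb.DirectedChannel) error {
                numChannels++
                totalCap += c.Capacity

                return nil
            },
        )

        return numChannels, totalCap, err
    }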