graph/db: implement ForEachNodeDirectedChannel and ForEachNodeCacheable

Here we add the `ForEachNodeDirectedChannel` and `ForEachNodeCacheable`
SQLStore implementations which then lets us run
`TestGraphTraversalCacheable` and `TestGraphCacheForEachNodeChannel`
against SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 16:12:08 +02:00
parent 8af32951c7
commit 0607982886
5 changed files with 324 additions and 9 deletions

View File

@@ -23,6 +23,7 @@ import (
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
@@ -1366,7 +1367,7 @@ func TestGraphTraversal(t *testing.T) {
func TestGraphTraversalCacheable(t *testing.T) {
t.Parallel()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// We'd like to test some of the graph traversal capabilities within
// the DB, so we'll create a series of fake nodes to insert into the
@@ -4162,7 +4163,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
t.Parallel()
ctx := context.Background()
graph := MakeTestGraph(t)
graph := MakeTestGraphNew(t)
// Unset the channel graph cache to simulate the user running with the
// option turned off.
@@ -4212,21 +4213,25 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
edge1.ExtraOpaqueData = []byte{
253, 217, 3, 8, 0, 0, 0, 10, 0, 0, 0, 20,
}
inboundFee := lnwire.Fee{
BaseFee: 10,
FeeRate: 20,
}
edge1.InboundFee = fn.Some(inboundFee)
require.NoError(t, graph.UpdateEdgePolicy(edge1))
edge1 = copyEdgePolicy(edge1) // Avoid read/write race conditions.
directedChan := getSingleChannel()
require.NotNil(t, directedChan)
expectedInbound := lnwire.Fee{
BaseFee: 10,
FeeRate: 20,
}
require.Equal(t, expectedInbound, directedChan.InboundFee)
require.Equal(t, inboundFee, directedChan.InboundFee)
// Set an invalid inbound fee and check that persistence fails.
edge1.ExtraOpaqueData = []byte{
253, 217, 3, 8, 0,
}
// We need to update the timestamp so that we don't hit the DB conflict
// error when we try to update the edge policy.
edge1.LastUpdate = edge1.LastUpdate.Add(time.Second)
require.ErrorIs(
t, graph.UpdateEdgePolicy(edge1), ErrParsingExtraTLVBytes,
)
@@ -4235,7 +4240,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
// the previous result when we query the channel again.
directedChan = getSingleChannel()
require.NotNil(t, directedChan)
require.Equal(t, expectedInbound, directedChan.InboundFee)
require.Equal(t, inboundFee, directedChan.InboundFee)
}
// TestGraphLoading asserts that the cache is properly reconstructed after a

View File

@@ -56,8 +56,10 @@ type SQLQueries interface {
*/
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
@@ -844,6 +846,224 @@ func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
}
// ForEachNodeDirectedChannel iterates through all channels of a given node,
// executing the passed callback on the directed edge representing the channel
// and its incoming policy. If the callback returns an error, then the iteration
// is halted with the error propagated back up to the caller.
//
// Unknown policies are passed into the callback as nil values.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
cb func(channel *DirectedChannel) error) error {
var ctx = context.TODO()
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
}, sqldb.NoOpReset)
}
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
// graph, executing the passed callback with each node encountered. If the
// callback returns an error, then the transaction is aborted and the iteration
// stops early.
//
// NOTE: This is a part of the V1Store interface.
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
*lnwire.FeatureVector) error) error {
ctx := context.TODO()
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
return forEachNodeCacheable(ctx, db, func(nodeID int64,
nodePub route.Vertex) error {
features, err := getNodeFeatures(ctx, db, nodeID)
if err != nil {
return fmt.Errorf("unable to fetch node "+
"features: %w", err)
}
return cb(nodePub, features)
})
}, sqldb.NoOpReset)
if err != nil {
return fmt.Errorf("unable to fetch nodes: %w", err)
}
return nil
}
// forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is
// returned.
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
toNodeCallback := func() route.Vertex {
return nodePub
}
dbID, err := db.GetNodeIDByPubKey(
ctx, sqlc.GetNodeIDByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: nodePub[:],
},
)
if errors.Is(err, sql.ErrNoRows) {
return nil
} else if err != nil {
return fmt.Errorf("unable to fetch node: %w", err)
}
rows, err := db.ListChannelsByNodeID(
ctx, sqlc.ListChannelsByNodeIDParams{
Version: int16(ProtocolV1),
NodeID1: dbID,
},
)
if err != nil {
return fmt.Errorf("unable to fetch channels: %w", err)
}
// Exit early if there are no channels for this node so we don't
// do the unnecessary feature fetching.
if len(rows) == 0 {
return nil
}
features, err := getNodeFeatures(ctx, db, dbID)
if err != nil {
return fmt.Errorf("unable to fetch node features: %w", err)
}
for _, row := range rows {
node1, node2, err := buildNodeVertices(
row.Node1Pubkey, row.Node2Pubkey,
)
if err != nil {
return fmt.Errorf("unable to build node vertices: %w",
err)
}
edge, err := buildCacheableChannelInfo(
row.Channel, node1, node2,
)
if err != nil {
return err
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return err
}
var p1, p2 *models.CachedEdgePolicy
if dbPol1 != nil {
policy1, err := buildChanPolicy(
*dbPol1, edge.ChannelID, nil, node2, true,
)
if err != nil {
return err
}
p1 = models.NewCachedPolicy(policy1)
}
if dbPol2 != nil {
policy2, err := buildChanPolicy(
*dbPol2, edge.ChannelID, nil, node1, false,
)
if err != nil {
return err
}
p2 = models.NewCachedPolicy(policy2)
}
// Determine the outgoing and incoming policy for this
// channel and node combo.
outPolicy, inPolicy := p1, p2
if p1 != nil && node2 == nodePub {
outPolicy, inPolicy = p2, p1
} else if p2 != nil && node1 != nodePub {
outPolicy, inPolicy = p2, p1
}
var cachedInPolicy *models.CachedEdgePolicy
if inPolicy != nil {
cachedInPolicy = inPolicy
cachedInPolicy.ToNodePubKey = toNodeCallback
cachedInPolicy.ToNodeFeatures = features
}
directedChannel := &DirectedChannel{
ChannelID: edge.ChannelID,
IsNode1: nodePub == edge.NodeKey1Bytes,
OtherNode: edge.NodeKey2Bytes,
Capacity: edge.Capacity,
OutPolicySet: outPolicy != nil,
InPolicy: cachedInPolicy,
}
if outPolicy != nil {
outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
directedChannel.InboundFee = fee
})
}
if nodePub == edge.NodeKey2Bytes {
directedChannel.OtherNode = edge.NodeKey1Bytes
}
if err := cb(directedChannel); err != nil {
return err
}
}
return nil
}
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
// and executes the provided callback for each node.
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
cb func(nodeID int64, nodePub route.Vertex) error) error {
var lastID int64
for {
nodes, err := db.ListNodeIDsAndPubKeys(
ctx, sqlc.ListNodeIDsAndPubKeysParams{
Version: int16(ProtocolV1),
ID: lastID,
Limit: pageSize,
},
)
if err != nil {
return fmt.Errorf("unable to fetch nodes: %w", err)
}
if len(nodes) == 0 {
break
}
for _, node := range nodes {
var pub route.Vertex
copy(pub[:], node.PubKey)
if err := cb(node.ID, pub); err != nil {
return fmt.Errorf("forEachNodeCacheable "+
"callback failed for node(id=%d): %w",
node.ID, err)
}
lastID = node.ID
}
}
return nil
}
// forEachNodeChannel iterates through all channels of a node, executing
// the passed callback on each. The call-back is provided with the channel's
// edge information, the outgoing policy and the incoming policy for the
@@ -1033,6 +1253,20 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries,
return dbNode.ID, node, nil
}
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
// provided database channel row and the public keys of the two nodes
// involved in the channel.
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
return &models.CachedEdgeInfo{
ChannelID: byteOrder.Uint64(dbChan.Scid),
NodeKey1Bytes: node1Pub,
NodeKey2Bytes: node2Pub,
Capacity: btcutil.Amount(dbChan.Capacity.Int64),
}, nil
}
// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
@@ -1589,7 +1823,7 @@ func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
// pass it into the P2P decoding variant.
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
if err != nil {
return nil, err
return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
}
if len(parsedTypes) == 0 {
return nil, nil

View File

@@ -533,6 +533,25 @@ func (q *Queries) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeatur
return items, nil
}
const getNodeIDByPubKey = `-- name: GetNodeIDByPubKey :one
SELECT id
FROM nodes
WHERE pub_key = $1
AND version = $2
`
type GetNodeIDByPubKeyParams struct {
PubKey []byte
Version int16
}
func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) {
row := q.db.QueryRowContext(ctx, getNodeIDByPubKey, arg.PubKey, arg.Version)
var id int64
err := row.Scan(&id)
return id, err
}
const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many
SELECT id, version, pub_key, alias, last_update, color, signature
FROM nodes
@@ -879,6 +898,48 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
return items, nil
}
const listNodeIDsAndPubKeys = `-- name: ListNodeIDsAndPubKeys :many
SELECT id, pub_key
FROM nodes
WHERE version = $1 AND id > $2
ORDER BY id
LIMIT $3
`
type ListNodeIDsAndPubKeysParams struct {
Version int16
ID int64
Limit int32
}
type ListNodeIDsAndPubKeysRow struct {
ID int64
PubKey []byte
}
func (q *Queries) ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error) {
rows, err := q.db.QueryContext(ctx, listNodeIDsAndPubKeys, arg.Version, arg.ID, arg.Limit)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ListNodeIDsAndPubKeysRow
for rows.Next() {
var i ListNodeIDsAndPubKeysRow
if err := rows.Scan(&i.ID, &i.PubKey); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const listNodesPaginated = `-- name: ListNodesPaginated :many
SELECT id, version, pub_key, alias, last_update, color, signature
FROM nodes

View File

@@ -48,6 +48,7 @@ type Querier interface {
GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error)
GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error)
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
@@ -64,6 +65,7 @@ type Querier interface {
InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)
ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error

View File

@@ -27,6 +27,12 @@ FROM nodes
WHERE pub_key = $1
AND version = $2;
-- name: GetNodeIDByPubKey :one
SELECT id
FROM nodes
WHERE pub_key = $1
AND version = $2;
-- name: ListNodesPaginated :many
SELECT *
FROM nodes
@@ -34,6 +40,13 @@ WHERE version = $1 AND id > $2
ORDER BY id
LIMIT $3;
-- name: ListNodeIDsAndPubKeys :many
SELECT id, pub_key
FROM nodes
WHERE version = $1 AND id > $2
ORDER BY id
LIMIT $3;
-- name: DeleteNodeByPubKey :execresult
DELETE FROM nodes
WHERE pub_key = $1