mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-30 16:10:01 +01:00
graph/db: implement ForEachNodeDirectedChannel and ForEachNodeCacheable
Here we add the `ForEachNodeDirectedChannel` and `ForEachNodeCacheable` SQLStore implementations which then lets us run `TestGraphTraversalCacheable` and `TestGraphCacheForEachNodeChannel` against SQL backends.
This commit is contained in:
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/btcsuite/btcd/chaincfg"
|
||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/lightningnetwork/lnd/fn/v2"
|
||||
"github.com/lightningnetwork/lnd/graph/db/models"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
@@ -1366,7 +1367,7 @@ func TestGraphTraversal(t *testing.T) {
|
||||
func TestGraphTraversalCacheable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
graph := MakeTestGraph(t)
|
||||
graph := MakeTestGraphNew(t)
|
||||
|
||||
// We'd like to test some of the graph traversal capabilities within
|
||||
// the DB, so we'll create a series of fake nodes to insert into the
|
||||
@@ -4162,7 +4163,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := context.Background()
|
||||
|
||||
graph := MakeTestGraph(t)
|
||||
graph := MakeTestGraphNew(t)
|
||||
|
||||
// Unset the channel graph cache to simulate the user running with the
|
||||
// option turned off.
|
||||
@@ -4212,21 +4213,25 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
|
||||
edge1.ExtraOpaqueData = []byte{
|
||||
253, 217, 3, 8, 0, 0, 0, 10, 0, 0, 0, 20,
|
||||
}
|
||||
inboundFee := lnwire.Fee{
|
||||
BaseFee: 10,
|
||||
FeeRate: 20,
|
||||
}
|
||||
edge1.InboundFee = fn.Some(inboundFee)
|
||||
require.NoError(t, graph.UpdateEdgePolicy(edge1))
|
||||
edge1 = copyEdgePolicy(edge1) // Avoid read/write race conditions.
|
||||
|
||||
directedChan := getSingleChannel()
|
||||
require.NotNil(t, directedChan)
|
||||
expectedInbound := lnwire.Fee{
|
||||
BaseFee: 10,
|
||||
FeeRate: 20,
|
||||
}
|
||||
require.Equal(t, expectedInbound, directedChan.InboundFee)
|
||||
require.Equal(t, inboundFee, directedChan.InboundFee)
|
||||
|
||||
// Set an invalid inbound fee and check that persistence fails.
|
||||
edge1.ExtraOpaqueData = []byte{
|
||||
253, 217, 3, 8, 0,
|
||||
}
|
||||
// We need to update the timestamp so that we don't hit the DB conflict
|
||||
// error when we try to update the edge policy.
|
||||
edge1.LastUpdate = edge1.LastUpdate.Add(time.Second)
|
||||
require.ErrorIs(
|
||||
t, graph.UpdateEdgePolicy(edge1), ErrParsingExtraTLVBytes,
|
||||
)
|
||||
@@ -4235,7 +4240,7 @@ func TestGraphCacheForEachNodeChannel(t *testing.T) {
|
||||
// the previous result when we query the channel again.
|
||||
directedChan = getSingleChannel()
|
||||
require.NotNil(t, directedChan)
|
||||
require.Equal(t, expectedInbound, directedChan.InboundFee)
|
||||
require.Equal(t, inboundFee, directedChan.InboundFee)
|
||||
}
|
||||
|
||||
// TestGraphLoading asserts that the cache is properly reconstructed after a
|
||||
|
||||
@@ -56,8 +56,10 @@ type SQLQueries interface {
|
||||
*/
|
||||
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
|
||||
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
|
||||
GetNodeIDByPubKey(ctx context.Context, arg sqlc.GetNodeIDByPubKeyParams) (int64, error)
|
||||
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
|
||||
ListNodesPaginated(ctx context.Context, arg sqlc.ListNodesPaginatedParams) ([]sqlc.Node, error)
|
||||
ListNodeIDsAndPubKeys(ctx context.Context, arg sqlc.ListNodeIDsAndPubKeysParams) ([]sqlc.ListNodeIDsAndPubKeysRow, error)
|
||||
DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
|
||||
|
||||
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
|
||||
@@ -844,6 +846,224 @@ func (s *sqlGraphNodeTx) FetchNode(nodePub route.Vertex) (NodeRTx, error) {
|
||||
return newSQLGraphNodeTx(s.db, s.chain, id, node), nil
|
||||
}
|
||||
|
||||
// ForEachNodeDirectedChannel iterates through all channels of a given node,
|
||||
// executing the passed callback on the directed edge representing the channel
|
||||
// and its incoming policy. If the callback returns an error, then the iteration
|
||||
// is halted with the error propagated back up to the caller.
|
||||
//
|
||||
// Unknown policies are passed into the callback as nil values.
|
||||
//
|
||||
// NOTE: this is part of the graphdb.NodeTraverser interface.
|
||||
func (s *SQLStore) ForEachNodeDirectedChannel(nodePub route.Vertex,
|
||||
cb func(channel *DirectedChannel) error) error {
|
||||
|
||||
var ctx = context.TODO()
|
||||
|
||||
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||
return forEachNodeDirectedChannel(ctx, db, nodePub, cb)
|
||||
}, sqldb.NoOpReset)
|
||||
}
|
||||
|
||||
// ForEachNodeCacheable iterates through all the stored vertices/nodes in the
|
||||
// graph, executing the passed callback with each node encountered. If the
|
||||
// callback returns an error, then the transaction is aborted and the iteration
|
||||
// stops early.
|
||||
//
|
||||
// NOTE: This is a part of the V1Store interface.
|
||||
func (s *SQLStore) ForEachNodeCacheable(cb func(route.Vertex,
|
||||
*lnwire.FeatureVector) error) error {
|
||||
|
||||
ctx := context.TODO()
|
||||
|
||||
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||
return forEachNodeCacheable(ctx, db, func(nodeID int64,
|
||||
nodePub route.Vertex) error {
|
||||
|
||||
features, err := getNodeFeatures(ctx, db, nodeID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch node "+
|
||||
"features: %w", err)
|
||||
}
|
||||
|
||||
return cb(nodePub, features)
|
||||
})
|
||||
}, sqldb.NoOpReset)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch nodes: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forEachNodeDirectedChannel iterates through all channels of a given
|
||||
// node, executing the passed callback on the directed edge representing the
|
||||
// channel and its incoming policy. If the node is not found, no error is
|
||||
// returned.
|
||||
func forEachNodeDirectedChannel(ctx context.Context, db SQLQueries,
|
||||
nodePub route.Vertex, cb func(channel *DirectedChannel) error) error {
|
||||
|
||||
toNodeCallback := func() route.Vertex {
|
||||
return nodePub
|
||||
}
|
||||
|
||||
dbID, err := db.GetNodeIDByPubKey(
|
||||
ctx, sqlc.GetNodeIDByPubKeyParams{
|
||||
Version: int16(ProtocolV1),
|
||||
PubKey: nodePub[:],
|
||||
},
|
||||
)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("unable to fetch node: %w", err)
|
||||
}
|
||||
|
||||
rows, err := db.ListChannelsByNodeID(
|
||||
ctx, sqlc.ListChannelsByNodeIDParams{
|
||||
Version: int16(ProtocolV1),
|
||||
NodeID1: dbID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch channels: %w", err)
|
||||
}
|
||||
|
||||
// Exit early if there are no channels for this node so we don't
|
||||
// do the unnecessary feature fetching.
|
||||
if len(rows) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
features, err := getNodeFeatures(ctx, db, dbID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch node features: %w", err)
|
||||
}
|
||||
|
||||
for _, row := range rows {
|
||||
node1, node2, err := buildNodeVertices(
|
||||
row.Node1Pubkey, row.Node2Pubkey,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to build node vertices: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
edge, err := buildCacheableChannelInfo(
|
||||
row.Channel, node1, node2,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbPol1, dbPol2, err := extractChannelPolicies(row)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var p1, p2 *models.CachedEdgePolicy
|
||||
if dbPol1 != nil {
|
||||
policy1, err := buildChanPolicy(
|
||||
*dbPol1, edge.ChannelID, nil, node2, true,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p1 = models.NewCachedPolicy(policy1)
|
||||
}
|
||||
if dbPol2 != nil {
|
||||
policy2, err := buildChanPolicy(
|
||||
*dbPol2, edge.ChannelID, nil, node1, false,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p2 = models.NewCachedPolicy(policy2)
|
||||
}
|
||||
|
||||
// Determine the outgoing and incoming policy for this
|
||||
// channel and node combo.
|
||||
outPolicy, inPolicy := p1, p2
|
||||
if p1 != nil && node2 == nodePub {
|
||||
outPolicy, inPolicy = p2, p1
|
||||
} else if p2 != nil && node1 != nodePub {
|
||||
outPolicy, inPolicy = p2, p1
|
||||
}
|
||||
|
||||
var cachedInPolicy *models.CachedEdgePolicy
|
||||
if inPolicy != nil {
|
||||
cachedInPolicy = inPolicy
|
||||
cachedInPolicy.ToNodePubKey = toNodeCallback
|
||||
cachedInPolicy.ToNodeFeatures = features
|
||||
}
|
||||
|
||||
directedChannel := &DirectedChannel{
|
||||
ChannelID: edge.ChannelID,
|
||||
IsNode1: nodePub == edge.NodeKey1Bytes,
|
||||
OtherNode: edge.NodeKey2Bytes,
|
||||
Capacity: edge.Capacity,
|
||||
OutPolicySet: outPolicy != nil,
|
||||
InPolicy: cachedInPolicy,
|
||||
}
|
||||
if outPolicy != nil {
|
||||
outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
|
||||
directedChannel.InboundFee = fee
|
||||
})
|
||||
}
|
||||
|
||||
if nodePub == edge.NodeKey2Bytes {
|
||||
directedChannel.OtherNode = edge.NodeKey1Bytes
|
||||
}
|
||||
|
||||
if err := cb(directedChannel); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forEachNodeCacheable fetches all V1 node IDs and pub keys from the database,
|
||||
// and executes the provided callback for each node.
|
||||
func forEachNodeCacheable(ctx context.Context, db SQLQueries,
|
||||
cb func(nodeID int64, nodePub route.Vertex) error) error {
|
||||
|
||||
var lastID int64
|
||||
|
||||
for {
|
||||
nodes, err := db.ListNodeIDsAndPubKeys(
|
||||
ctx, sqlc.ListNodeIDsAndPubKeysParams{
|
||||
Version: int16(ProtocolV1),
|
||||
ID: lastID,
|
||||
Limit: pageSize,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch nodes: %w", err)
|
||||
}
|
||||
|
||||
if len(nodes) == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
for _, node := range nodes {
|
||||
var pub route.Vertex
|
||||
copy(pub[:], node.PubKey)
|
||||
|
||||
if err := cb(node.ID, pub); err != nil {
|
||||
return fmt.Errorf("forEachNodeCacheable "+
|
||||
"callback failed for node(id=%d): %w",
|
||||
node.ID, err)
|
||||
}
|
||||
|
||||
lastID = node.ID
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// forEachNodeChannel iterates through all channels of a node, executing
|
||||
// the passed callback on each. The call-back is provided with the channel's
|
||||
// edge information, the outgoing policy and the incoming policy for the
|
||||
@@ -1033,6 +1253,20 @@ func getNodeByPubKey(ctx context.Context, db SQLQueries,
|
||||
return dbNode.ID, node, nil
|
||||
}
|
||||
|
||||
// buildCacheableChannelInfo builds a models.CachedEdgeInfo instance from the
|
||||
// provided database channel row and the public keys of the two nodes
|
||||
// involved in the channel.
|
||||
func buildCacheableChannelInfo(dbChan sqlc.Channel, node1Pub,
|
||||
node2Pub route.Vertex) (*models.CachedEdgeInfo, error) {
|
||||
|
||||
return &models.CachedEdgeInfo{
|
||||
ChannelID: byteOrder.Uint64(dbChan.Scid),
|
||||
NodeKey1Bytes: node1Pub,
|
||||
NodeKey2Bytes: node2Pub,
|
||||
Capacity: btcutil.Amount(dbChan.Capacity.Int64),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// buildNode constructs a LightningNode instance from the given database node
|
||||
// record. The node's features, addresses and extra signed fields are also
|
||||
// fetched from the database and set on the node.
|
||||
@@ -1589,7 +1823,7 @@ func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
|
||||
// pass it into the P2P decoding variant.
|
||||
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
|
||||
}
|
||||
if len(parsedTypes) == 0 {
|
||||
return nil, nil
|
||||
|
||||
@@ -533,6 +533,25 @@ func (q *Queries) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeatur
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getNodeIDByPubKey = `-- name: GetNodeIDByPubKey :one
|
||||
SELECT id
|
||||
FROM nodes
|
||||
WHERE pub_key = $1
|
||||
AND version = $2
|
||||
`
|
||||
|
||||
type GetNodeIDByPubKeyParams struct {
|
||||
PubKey []byte
|
||||
Version int16
|
||||
}
|
||||
|
||||
func (q *Queries) GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, getNodeIDByPubKey, arg.PubKey, arg.Version)
|
||||
var id int64
|
||||
err := row.Scan(&id)
|
||||
return id, err
|
||||
}
|
||||
|
||||
const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many
|
||||
SELECT id, version, pub_key, alias, last_update, color, signature
|
||||
FROM nodes
|
||||
@@ -879,6 +898,48 @@ func (q *Queries) ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNo
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listNodeIDsAndPubKeys = `-- name: ListNodeIDsAndPubKeys :many
|
||||
SELECT id, pub_key
|
||||
FROM nodes
|
||||
WHERE version = $1 AND id > $2
|
||||
ORDER BY id
|
||||
LIMIT $3
|
||||
`
|
||||
|
||||
type ListNodeIDsAndPubKeysParams struct {
|
||||
Version int16
|
||||
ID int64
|
||||
Limit int32
|
||||
}
|
||||
|
||||
type ListNodeIDsAndPubKeysRow struct {
|
||||
ID int64
|
||||
PubKey []byte
|
||||
}
|
||||
|
||||
func (q *Queries) ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listNodeIDsAndPubKeys, arg.Version, arg.ID, arg.Limit)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListNodeIDsAndPubKeysRow
|
||||
for rows.Next() {
|
||||
var i ListNodeIDsAndPubKeysRow
|
||||
if err := rows.Scan(&i.ID, &i.PubKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listNodesPaginated = `-- name: ListNodesPaginated :many
|
||||
SELECT id, version, pub_key, alias, last_update, color, signature
|
||||
FROM nodes
|
||||
|
||||
@@ -48,6 +48,7 @@ type Querier interface {
|
||||
GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error)
|
||||
GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error)
|
||||
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
|
||||
GetNodeIDByPubKey(ctx context.Context, arg GetNodeIDByPubKeyParams) (int64, error)
|
||||
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
|
||||
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
|
||||
HighestSCID(ctx context.Context, version int16) ([]byte, error)
|
||||
@@ -64,6 +65,7 @@ type Querier interface {
|
||||
InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error
|
||||
InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error
|
||||
ListChannelsByNodeID(ctx context.Context, arg ListChannelsByNodeIDParams) ([]ListChannelsByNodeIDRow, error)
|
||||
ListNodeIDsAndPubKeys(ctx context.Context, arg ListNodeIDsAndPubKeysParams) ([]ListNodeIDsAndPubKeysRow, error)
|
||||
ListNodesPaginated(ctx context.Context, arg ListNodesPaginatedParams) ([]Node, error)
|
||||
NextInvoiceSettleIndex(ctx context.Context) (int64, error)
|
||||
OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error
|
||||
|
||||
@@ -27,6 +27,12 @@ FROM nodes
|
||||
WHERE pub_key = $1
|
||||
AND version = $2;
|
||||
|
||||
-- name: GetNodeIDByPubKey :one
|
||||
SELECT id
|
||||
FROM nodes
|
||||
WHERE pub_key = $1
|
||||
AND version = $2;
|
||||
|
||||
-- name: ListNodesPaginated :many
|
||||
SELECT *
|
||||
FROM nodes
|
||||
@@ -34,6 +40,13 @@ WHERE version = $1 AND id > $2
|
||||
ORDER BY id
|
||||
LIMIT $3;
|
||||
|
||||
-- name: ListNodeIDsAndPubKeys :many
|
||||
SELECT id, pub_key
|
||||
FROM nodes
|
||||
WHERE version = $1 AND id > $2
|
||||
ORDER BY id
|
||||
LIMIT $3;
|
||||
|
||||
-- name: DeleteNodeByPubKey :execresult
|
||||
DELETE FROM nodes
|
||||
WHERE pub_key = $1
|
||||
|
||||
Reference in New Issue
Block a user