sqldb+graph/db: source nodes table, queries and CRUD

In this commit, we add the `source_nodes` table. It points to entries in
the `nodes` table. This table will store one entry per protocol version
that we are announcing a node_announcement on.

With this commit, we can run the TestSourceNode unit test against our
SQL backends.
This commit is contained in:
Elle Mouton
2025-05-19 12:11:15 +02:00
parent 86d48390ca
commit 0064d33cda
8 changed files with 184 additions and 5 deletions

View File

@@ -355,10 +355,11 @@ func TestAliasLookup(t *testing.T) {
require.ErrorIs(t, err, ErrNodeAliasNotFound) require.ErrorIs(t, err, ErrNodeAliasNotFound)
} }
// TestSourceNode tests the source node functionality of the graph store.
func TestSourceNode(t *testing.T) { func TestSourceNode(t *testing.T) {
t.Parallel() t.Parallel()
graph := MakeTestGraph(t) graph := MakeTestGraphNew(t)
// We'd like to test the setting/getting of the source node, so we // We'd like to test the setting/getting of the source node, so we
// first create a fake node to use within the test. // first create a fake node to use within the test.
@@ -369,11 +370,9 @@ func TestSourceNode(t *testing.T) {
_, err := graph.SourceNode() _, err := graph.SourceNode()
require.ErrorIs(t, err, ErrSourceNodeNotSet) require.ErrorIs(t, err, ErrSourceNodeNotSet)
// Set the source the source node, this should insert the node into the // Set the source node, this should insert the node into the
// database in a special way indicating it's the source node. // database in a special way indicating it's the source node.
if err := graph.SetSourceNode(testNode); err != nil { require.NoError(t, graph.SetSourceNode(testNode))
t.Fatalf("unable to set source node: %v", err)
}
// Retrieve the source node from the database, it should exactly match // Retrieve the source node from the database, it should exactly match
// the one we set above. // the one we set above.

View File

@@ -63,6 +63,12 @@ type SQLQueries interface {
GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error) GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
/*
Source node queries.
*/
AddSourceNode(ctx context.Context, nodeID int64) error
GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
} }
// BatchedSQLQueries is a version of SQLQueries that's capable of batched // BatchedSQLQueries is a version of SQLQueries that's capable of batched
@@ -372,6 +378,73 @@ func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
return alias, nil return alias, nil
} }
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
ctx := context.TODO()
var (
readTx = NewReadTx()
node *models.LightningNode
)
err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error {
_, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
if err != nil {
return fmt.Errorf("unable to fetch V1 source node: %w",
err)
}
_, node, err = getNodeByPubKey(ctx, db, nodePub)
return err
}, func() {})
if err != nil {
return nil, fmt.Errorf("unable to fetch source node: %w", err)
}
return node, nil
}
// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
ctx := context.TODO()
var writeTxOpts TxOptions
return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error {
id, err := upsertNode(ctx, db, node)
if err != nil {
return fmt.Errorf("unable to upsert source node: %w",
err)
}
// Make sure that if a source node for this version is already
// set, then the ID is the same as the one we are about to set.
dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
return fmt.Errorf("unable to fetch source node: %w",
err)
} else if err == nil {
if dbSourceNodeID != id {
return fmt.Errorf("v1 source node already "+
"set to a different node: %d vs %d",
dbSourceNodeID, id)
}
return nil
}
return db.AddSourceNode(ctx, id)
}, func() {})
}
// NodeUpdatesInHorizon returns all the known lightning node which have an // NodeUpdatesInHorizon returns all the known lightning node which have an
// update timestamp within the passed range. This method can be used by two // update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node // nodes to quickly determine if they have the same set of up to date node
@@ -932,6 +1005,31 @@ func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
return nil return nil
} }
// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func getSourceNode(ctx context.Context, db SQLQueries,
version ProtocolVersion) (int64, route.Vertex, error) {
var pubKey route.Vertex
nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
if err != nil {
return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
err)
}
if len(nodes) == 0 {
return 0, pubKey, ErrSourceNodeNotSet
} else if len(nodes) > 1 {
return 0, pubKey, fmt.Errorf("multiple source nodes for "+
"protocol %s found", version)
}
copy(pubKey[:], nodes[0].PubKey)
return nodes[0].NodeID, pubKey, nil
}
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream. // marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
// This then produces a map from TLV type to value. If the input is not a // This then produces a map from TLV type to value. If the input is not a
// valid TLV stream, then an error is returned. // valid TLV stream, then an error is returned.

View File

@@ -10,6 +10,22 @@ import (
"database/sql" "database/sql"
) )
const addSourceNode = `-- name: AddSourceNode :exec
/* ─────────────────────────────────────────────
source_nodes table queries
─────────────────────────────────────────────
*/
INSERT INTO source_nodes (node_id)
VALUES ($1)
ON CONFLICT (node_id) DO NOTHING
`
func (q *Queries) AddSourceNode(ctx context.Context, nodeID int64) error {
_, err := q.db.ExecContext(ctx, addSourceNode, nodeID)
return err
}
const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec
DELETE FROM node_extra_types DELETE FROM node_extra_types
WHERE node_id = $1 WHERE node_id = $1
@@ -272,6 +288,41 @@ func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByL
return items, nil return items, nil
} }
const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many
SELECT sn.node_id, n.pub_key
FROM source_nodes sn
JOIN nodes n ON sn.node_id = n.id
WHERE n.version = $1
`
type GetSourceNodesByVersionRow struct {
NodeID int64
PubKey []byte
}
func (q *Queries) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) {
rows, err := q.db.QueryContext(ctx, getSourceNodesByVersion, version)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetSourceNodesByVersionRow
for rows.Next() {
var i GetSourceNodesByVersionRow
if err := rows.Scan(&i.NodeID, &i.PubKey); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertNodeAddress = `-- name: InsertNodeAddress :exec const insertNodeAddress = `-- name: InsertNodeAddress :exec
/* ───────────────────────────────────────────── /* ─────────────────────────────────────────────
node_addresses table queries node_addresses table queries

View File

@@ -3,8 +3,10 @@ DROP INDEX IF EXISTS nodes_unique;
DROP INDEX IF EXISTS node_extra_types_unique; DROP INDEX IF EXISTS node_extra_types_unique;
DROP INDEX IF EXISTS node_features_unique; DROP INDEX IF EXISTS node_features_unique;
DROP INDEX IF EXISTS node_addresses_unique; DROP INDEX IF EXISTS node_addresses_unique;
DROP INDEX IF EXISTS source_nodes_unique;
-- Drop tables in order of reverse dependencies. -- Drop tables in order of reverse dependencies.
DROP TABLE IF EXISTS source_nodes;
DROP TABLE IF EXISTS node_addresses; DROP TABLE IF EXISTS node_addresses;
DROP TABLE IF EXISTS node_features; DROP TABLE IF EXISTS node_features;
DROP TABLE IF EXISTS node_extra_types; DROP TABLE IF EXISTS node_extra_types;

View File

@@ -87,4 +87,11 @@ CREATE TABLE IF NOT EXISTS node_addresses (
); );
CREATE UNIQUE INDEX IF NOT EXISTS node_addresses_unique ON node_addresses ( CREATE UNIQUE INDEX IF NOT EXISTS node_addresses_unique ON node_addresses (
node_id, type, position node_id, type, position
);
CREATE TABLE IF NOT EXISTS source_nodes (
node_id BIGINT NOT NULL REFERENCES nodes (id) ON DELETE CASCADE
);
CREATE UNIQUE INDEX IF NOT EXISTS source_nodes_unique ON source_nodes (
node_id
); );

View File

@@ -130,3 +130,7 @@ type NodeFeature struct {
NodeID int64 NodeID int64
FeatureBit int32 FeatureBit int32
} }
type SourceNode struct {
NodeID int64
}

View File

@@ -11,6 +11,7 @@ import (
) )
type Querier interface { type Querier interface {
AddSourceNode(ctx context.Context, nodeID int64) error
ClearKVInvoiceHashIndex(ctx context.Context) error ClearKVInvoiceHashIndex(ctx context.Context) error
DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteCanceledInvoices(ctx context.Context) (sql.Result, error)
DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error
@@ -41,6 +42,7 @@ type Querier interface {
GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error)
GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error)
GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error)
GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error)
InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error
InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error
InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error)

View File

@@ -116,3 +116,19 @@ WHERE node_id = $1;
DELETE FROM node_extra_types DELETE FROM node_extra_types
WHERE node_id = $1 WHERE node_id = $1
AND type = $2; AND type = $2;
/* ─────────────────────────────────────────────
source_nodes table queries
─────────────────────────────────────────────
*/
-- name: AddSourceNode :exec
INSERT INTO source_nodes (node_id)
VALUES ($1)
ON CONFLICT (node_id) DO NOTHING;
-- name: GetSourceNodesByVersion :many
SELECT sn.node_id, n.pub_key
FROM source_nodes sn
JOIN nodes n ON sn.node_id = n.id
WHERE n.version = $1;