diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index 42c88d49c..aae7e5a7e 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -355,10 +355,11 @@ func TestAliasLookup(t *testing.T) { require.ErrorIs(t, err, ErrNodeAliasNotFound) } +// TestSourceNode tests the source node functionality of the graph store. func TestSourceNode(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. @@ -369,11 +370,9 @@ func TestSourceNode(t *testing.T) { _, err := graph.SourceNode() require.ErrorIs(t, err, ErrSourceNodeNotSet) - // Set the source the source node, this should insert the node into the + // Set the source node, this should insert the node into the // database in a special way indicating it's the source node. - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } + require.NoError(t, graph.SetSourceNode(testNode)) // Retrieve the source node from the database, it should exactly match // the one we set above. diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 33fee1ae9..1b4cbc566 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -63,6 +63,12 @@ type SQLQueries interface { GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error) GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error) DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error + + /* + Source node queries. + */ + AddSourceNode(ctx context.Context, nodeID int64) error + GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error) } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -372,6 +378,73 @@ func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) { return alias, nil } +// SourceNode returns the source node of the graph. The source node is treated +// as the center node within a star-graph. This method may be used to kick off +// a path finding algorithm in order to explore the reachability of another +// node based off the source node. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) SourceNode() (*models.LightningNode, error) { + ctx := context.TODO() + + var ( + readTx = NewReadTx() + node *models.LightningNode + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + _, nodePub, err := getSourceNode(ctx, db, ProtocolV1) + if err != nil { + return fmt.Errorf("unable to fetch V1 source node: %w", + err) + } + + _, node, err = getNodeByPubKey(ctx, db, nodePub) + + return err + }, func() {}) + if err != nil { + return nil, fmt.Errorf("unable to fetch source node: %w", err) + } + + return node, nil +} + +// SetSourceNode sets the source node within the graph database. The source +// node is to be used as the center of a star-graph within path finding +// algorithms. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) SetSourceNode(node *models.LightningNode) error { + ctx := context.TODO() + var writeTxOpts TxOptions + + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := upsertNode(ctx, db, node) + if err != nil { + return fmt.Errorf("unable to upsert source node: %w", + err) + } + + // Make sure that if a source node for this version is already + // set, then the ID is the same as the one we are about to set. + dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1) + if err != nil && !errors.Is(err, ErrSourceNodeNotSet) { + return fmt.Errorf("unable to fetch source node: %w", + err) + } else if err == nil { + if dbSourceNodeID != id { + return fmt.Errorf("v1 source node already "+ + "set to a different node: %d vs %d", + dbSourceNodeID, id) + } + + return nil + } + + return db.AddSourceNode(ctx, id) + }, func() {}) +} + // NodeUpdatesInHorizon returns all the known lightning node which have an // update timestamp within the passed range. This method can be used by two // nodes to quickly determine if they have the same set of up to date node @@ -932,6 +1005,31 @@ func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries, return nil } +// getSourceNode returns the DB node ID and pub key of the source node for the +// specified protocol version. +func getSourceNode(ctx context.Context, db SQLQueries, + version ProtocolVersion) (int64, route.Vertex, error) { + + var pubKey route.Vertex + + nodes, err := db.GetSourceNodesByVersion(ctx, int16(version)) + if err != nil { + return 0, pubKey, fmt.Errorf("unable to fetch source node: %w", + err) + } + + if len(nodes) == 0 { + return 0, pubKey, ErrSourceNodeNotSet + } else if len(nodes) > 1 { + return 0, pubKey, fmt.Errorf("multiple source nodes for "+ + "protocol %s found", version) + } + + copy(pubKey[:], nodes[0].PubKey) + + return nodes[0].NodeID, pubKey, nil +} + // marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream. // This then produces a map from TLV type to value. If the input is not a // valid TLV stream, then an error is returned. diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index cde66e77e..dcdd79053 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -10,6 +10,22 @@ import ( "database/sql" ) +const addSourceNode = `-- name: AddSourceNode :exec +/* ───────────────────────────────────────────── + source_nodes table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO source_nodes (node_id) +VALUES ($1) +ON CONFLICT (node_id) DO NOTHING +` + +func (q *Queries) AddSourceNode(ctx context.Context, nodeID int64) error { + _, err := q.db.ExecContext(ctx, addSourceNode, nodeID) + return err +} + const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec DELETE FROM node_extra_types WHERE node_id = $1 @@ -272,6 +288,41 @@ func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByL return items, nil } +const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many +SELECT sn.node_id, n.pub_key +FROM source_nodes sn + JOIN nodes n ON sn.node_id = n.id +WHERE n.version = $1 +` + +type GetSourceNodesByVersionRow struct { + NodeID int64 + PubKey []byte +} + +func (q *Queries) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) { + rows, err := q.db.QueryContext(ctx, getSourceNodesByVersion, version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetSourceNodesByVersionRow + for rows.Next() { + var i GetSourceNodesByVersionRow + if err := rows.Scan(&i.NodeID, &i.PubKey); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const insertNodeAddress = `-- name: InsertNodeAddress :exec /* ───────────────────────────────────────────── node_addresses table queries diff --git a/sqldb/sqlc/migrations/000007_graph.down.sql b/sqldb/sqlc/migrations/000007_graph.down.sql index 29f01750a..79489d7bd 100644 --- a/sqldb/sqlc/migrations/000007_graph.down.sql +++ b/sqldb/sqlc/migrations/000007_graph.down.sql @@ -3,8 +3,10 @@ DROP INDEX IF EXISTS nodes_unique; DROP INDEX IF EXISTS node_extra_types_unique; DROP INDEX IF EXISTS node_features_unique; DROP INDEX IF EXISTS node_addresses_unique; +DROP INDEX IF EXISTS source_nodes_unique; -- Drop tables in order of reverse dependencies. +DROP TABLE IF EXISTS source_nodes; DROP TABLE IF EXISTS node_addresses; DROP TABLE IF EXISTS node_features; DROP TABLE IF EXISTS node_extra_types; diff --git a/sqldb/sqlc/migrations/000007_graph.up.sql b/sqldb/sqlc/migrations/000007_graph.up.sql index 4e3ee6903..0efc750e3 100644 --- a/sqldb/sqlc/migrations/000007_graph.up.sql +++ b/sqldb/sqlc/migrations/000007_graph.up.sql @@ -87,4 +87,11 @@ CREATE TABLE IF NOT EXISTS node_addresses ( ); CREATE UNIQUE INDEX IF NOT EXISTS node_addresses_unique ON node_addresses ( node_id, type, position +); + +CREATE TABLE IF NOT EXISTS source_nodes ( + node_id BIGINT NOT NULL REFERENCES nodes (id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX IF NOT EXISTS source_nodes_unique ON source_nodes ( + node_id ); \ No newline at end of file diff --git a/sqldb/sqlc/models.go b/sqldb/sqlc/models.go index ad256ca13..07269c359 100644 --- a/sqldb/sqlc/models.go +++ b/sqldb/sqlc/models.go @@ -130,3 +130,7 @@ type NodeFeature struct { NodeID int64 FeatureBit int32 } + +type SourceNode struct { + NodeID int64 +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index e26bf05c7..7492a21e0 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -11,6 +11,7 @@ import ( ) type Querier interface { + AddSourceNode(ctx context.Context, nodeID int64) error ClearKVInvoiceHashIndex(ctx context.Context) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error @@ -41,6 +42,7 @@ type Querier interface { GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) + GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 5c8c9b800..a6745227e 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -116,3 +116,19 @@ WHERE node_id = $1; DELETE FROM node_extra_types WHERE node_id = $1 AND type = $2; + +/* ───────────────────────────────────────────── + source_nodes table queries + ───────────────────────────────────────────── +*/ + +-- name: AddSourceNode :exec +INSERT INTO source_nodes (node_id) +VALUES ($1) +ON CONFLICT (node_id) DO NOTHING; + +-- name: GetSourceNodesByVersion :many +SELECT sn.node_id, n.pub_key +FROM source_nodes sn + JOIN nodes n ON sn.node_id = n.id +WHERE n.version = $1;