diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index a78b3919c..087a57624 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -42,6 +42,8 @@ byte blobs at the end of gossip messages are valid TLV streams. * Various [preparations](https://github.com/lightningnetwork/lnd/pull/9692) of the graph code before the SQL implementation is added. + * Add graph schemas, queries and CRUD: + * [1](https://github.com/lightningnetwork/lnd/pull/9866) ## RPC Updates diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index cab857fee..aae7e5a7e 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -101,15 +101,17 @@ func createTestVertex(t testing.TB) *models.LightningNode { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. + timeStamp := int64(1232342) nodeWithAddrs := func(addrs []net.Addr) *models.LightningNode { + timeStamp++ return &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), - LastUpdate: time.Unix(1232342, 0), + LastUpdate: time.Unix(timeStamp, 0), Color: color.RGBA{1, 2, 3, 0}, Alias: "kek", Features: testFeatures, @@ -323,10 +325,11 @@ func TestPartialNode(t *testing.T) { require.ErrorIs(t, err, ErrGraphNodeNotFound) } +// TestAliasLookup tests the alias lookup functionality of the graph store. func TestAliasLookup(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test the alias index within the database, so first // create a new test node. @@ -334,9 +337,7 @@ func TestAliasLookup(t *testing.T) { // Add the node to the graph's database, this should also insert an // entry into the alias index for this node. - if err := graph.AddLightningNode(testNode); err != nil { - t.Fatalf("unable to add node: %v", err) - } + require.NoError(t, graph.AddLightningNode(testNode)) // Next, attempt to lookup the alias. The alias should exactly match // the one which the test node was assigned. @@ -344,10 +345,7 @@ func TestAliasLookup(t *testing.T) { require.NoError(t, err, "unable to generate pubkey") dbAlias, err := graph.LookupAlias(nodePub) require.NoError(t, err, "unable to find alias") - if dbAlias != testNode.Alias { - t.Fatalf("aliases don't match, expected %v got %v", - testNode.Alias, dbAlias) - } + require.Equal(t, testNode.Alias, dbAlias) // Ensure that looking up a non-existent alias results in an error. node := createTestVertex(t) @@ -357,10 +355,11 @@ func TestAliasLookup(t *testing.T) { require.ErrorIs(t, err, ErrNodeAliasNotFound) } +// TestSourceNode tests the source node functionality of the graph store. func TestSourceNode(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test the setting/getting of the source node, so we // first create a fake node to use within the test. @@ -371,11 +370,9 @@ func TestSourceNode(t *testing.T) { _, err := graph.SourceNode() require.ErrorIs(t, err, ErrSourceNodeNotSet) - // Set the source the source node, this should insert the node into the + // Set the source node, this should insert the node into the // database in a special way indicating it's the source node. - if err := graph.SetSourceNode(testNode); err != nil { - t.Fatalf("unable to set source node: %v", err) - } + require.NoError(t, graph.SetSourceNode(testNode)) // Retrieve the source node from the database, it should exactly match // the one we set above. @@ -2082,7 +2079,7 @@ func TestChanUpdatesInHorizon(t *testing.T) { func TestNodeUpdatesInHorizon(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) startTime := time.Unix(1234, 0) endTime := startTime @@ -2093,10 +2090,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { time.Unix(999, 0), time.Unix(9999, 0), ) require.NoError(t, err, "unable to query for node updates") - if len(nodeUpdates) != 0 { - t.Fatalf("expected 0 node updates, instead got %v", - len(nodeUpdates)) - } + require.Len(t, nodeUpdates, 0) // We'll create 10 node announcements, each with an update timestamp 10 // seconds after the other. @@ -2115,9 +2109,7 @@ func TestNodeUpdatesInHorizon(t *testing.T) { nodeAnns = append(nodeAnns, *nodeAnn) - if err := graph.AddLightningNode(nodeAnn); err != nil { - t.Fatalf("unable to add lightning node: %v", err) - } + require.NoError(t, graph.AddLightningNode(nodeAnn)) } queryCases := []struct { @@ -2171,15 +2163,8 @@ func TestNodeUpdatesInHorizon(t *testing.T) { resp, err := graph.NodeUpdatesInHorizon( queryCase.start, queryCase.end, ) - if err != nil { - t.Fatalf("unable to query for nodes: %v", err) - } - - if len(resp) != len(queryCase.resp) { - t.Fatalf("expected %v nodes, got %v nodes", - len(queryCase.resp), len(resp)) - - } + require.NoError(t, err) + require.Len(t, resp, len(queryCase.resp)) for i := 0; i < len(resp); i++ { compareNodes(t, &queryCase.resp[i], &resp[i]) @@ -3384,7 +3369,7 @@ func TestAddChannelEdgeShellNodes(t *testing.T) { func TestNodePruningUpdateIndexDeletion(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'll first populate our graph with a single node that will be // removed shortly. @@ -4315,7 +4300,7 @@ func TestLightningNodePersistence(t *testing.T) { t.Parallel() // Create a new test graph instance. - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) nodeAnnBytes, err := hex.DecodeString(testNodeAnn) require.NoError(t, err) diff --git a/graph/db/notifications.go b/graph/db/notifications.go index 8b69b7d85..d573e0aa5 100644 --- a/graph/db/notifications.go +++ b/graph/db/notifications.go @@ -4,6 +4,7 @@ import ( "fmt" "image/color" "net" + "strconv" "sync" "sync/atomic" @@ -463,3 +464,31 @@ func (c *ChannelGraph) addToTopologyChange(update *TopologyChange, func EncodeHexColor(color color.RGBA) string { return fmt.Sprintf("#%02x%02x%02x", color.R, color.G, color.B) } + +// DecodeHexColor takes a hex color string like "#rrggbb" and returns a +// color.RGBA. +func DecodeHexColor(hex string) (color.RGBA, error) { + r, err := strconv.ParseUint(hex[1:3], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid red component: %w", + err) + } + + g, err := strconv.ParseUint(hex[3:5], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid green component: %w", + err) + } + + b, err := strconv.ParseUint(hex[5:7], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid blue component: %w", + err) + } + + return color.RGBA{ + R: uint8(r), + G: uint8(g), + B: uint8(b), + }, nil +} diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 3947ae0d8..1b4cbc566 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1,16 +1,74 @@ package graphdb import ( + "bytes" + "context" + "database/sql" + "encoding/hex" + "errors" "fmt" + "math" + "net" + "strconv" "sync" + "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" ) +// ProtocolVersion is an enum that defines the gossip protocol version of a +// message. +type ProtocolVersion uint8 + +const ( + // ProtocolV1 is the gossip protocol version defined in BOLT #7. + ProtocolV1 ProtocolVersion = 1 +) + +// String returns a string representation of the protocol version. +func (v ProtocolVersion) String() string { + return fmt.Sprintf("V%d", v) +} + // SQLQueries is a subset of the sqlc.Querier interface that can be used to // execute queries against the SQL graph tables. +// +//nolint:ll,interfacebloat type SQLQueries interface { + /* + Node queries. + */ + UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error) + GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error) + GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error) + DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error) + + GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error) + UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error + DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error + + InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error + GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error) + DeleteNodeAddresses(ctx context.Context, nodeID int64) error + + InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error + GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error) + GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error) + DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error + + /* + Source node queries. + */ + AddSourceNode(ctx context.Context, nodeID int64) error + GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error) } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -80,3 +138,923 @@ func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore, return s, nil } + +// TxOptions defines the set of db txn options the SQLQueries +// understands. +type TxOptions struct { + // readOnly governs if a read only transaction is needed or not. + readOnly bool +} + +// ReadOnly returns true if the transaction should be read only. +// +// NOTE: This implements the TxOptions. +func (a *TxOptions) ReadOnly() bool { + return a.readOnly +} + +// NewReadTx creates a new read transaction option set. +func NewReadTx() *TxOptions { + return &TxOptions{ + readOnly: true, + } +} + +// AddLightningNode adds a vertex/node to the graph database. If the node is not +// in the database from before, this will add a new, unconnected one to the +// graph. If it is present from before, this will update that node's +// information. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) AddLightningNode(node *models.LightningNode, + opts ...batch.SchedulerOption) error { + + ctx := context.TODO() + + r := &batch.Request[SQLQueries]{ + Opts: batch.NewSchedulerOptions(opts...), + Do: func(queries SQLQueries) error { + _, err := upsertNode(ctx, queries, node) + return err + }, + } + + return s.nodeScheduler.Execute(ctx, r) +} + +// FetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then ErrGraphNodeNotFound is +// returned. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) ( + *models.LightningNode, error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + node *models.LightningNode + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + var err error + _, node, err = getNodeByPubKey(ctx, db, pubKey) + + return err + }, func() {}) + if err != nil { + return nil, fmt.Errorf("unable to fetch node: %w", err) + } + + return node, nil +} + +// HasLightningNode determines if the graph has a vertex identified by the +// target node identity public key. If the node exists in the database, a +// timestamp of when the data for the node was lasted updated is returned along +// with a true boolean. Otherwise, an empty time.Time is returned with a false +// boolean. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool, + error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + exists bool + lastUpdate time.Time + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + dbNode, err := db.GetNodeByPubKey( + ctx, sqlc.GetNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil + } else if err != nil { + return fmt.Errorf("unable to fetch node: %w", err) + } + + exists = true + + if dbNode.LastUpdate.Valid { + lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0) + } + + return nil + }, func() {}) + if err != nil { + return time.Time{}, false, + fmt.Errorf("unable to fetch node: %w", err) + } + + return lastUpdate, exists, nil +} + +// AddrsForNode returns all known addresses for the target node public key +// that the graph DB is aware of. The returned boolean indicates if the +// given node is unknown to the graph DB or not. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, + error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + addresses []net.Addr + known bool + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + var err error + known, addresses, err = getNodeAddresses( + ctx, db, nodePub.SerializeCompressed(), + ) + if err != nil { + return fmt.Errorf("unable to fetch node addresses: %w", + err) + } + + return nil + }, func() {}) + if err != nil { + return false, nil, fmt.Errorf("unable to get addresses for "+ + "node(%x): %w", nodePub.SerializeCompressed(), err) + } + + return known, addresses, nil +} + +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error { + ctx := context.TODO() + + var writeTxOpts TxOptions + err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + res, err := db.DeleteNodeByPubKey( + ctx, sqlc.DeleteNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if err != nil { + return err + } + + rows, err := res.RowsAffected() + if err != nil { + return err + } + + if rows == 0 { + return ErrGraphNodeNotFound + } else if rows > 1 { + return fmt.Errorf("deleted %d rows, expected 1", rows) + } + + return err + }, func() {}) + if err != nil { + return fmt.Errorf("unable to delete node: %w", err) + } + + return nil +} + +// FetchNodeFeatures returns the features of the given node. If no features are +// known for the node, an empty feature vector is returned. +// +// NOTE: this is part of the graphdb.NodeTraverser interface. +func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + ctx := context.TODO() + + return fetchNodeFeatures(ctx, s.db, nodePub) +} + +// LookupAlias attempts to return the alias as advertised by the target node. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) { + var ( + ctx = context.TODO() + readTx = NewReadTx() + alias string + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + dbNode, err := db.GetNodeByPubKey( + ctx, sqlc.GetNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pub.SerializeCompressed(), + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return ErrNodeAliasNotFound + } else if err != nil { + return fmt.Errorf("unable to fetch node: %w", err) + } + + if !dbNode.Alias.Valid { + return ErrNodeAliasNotFound + } + + alias = dbNode.Alias.String + + return nil + }, func() {}) + if err != nil { + return "", fmt.Errorf("unable to look up alias: %w", err) + } + + return alias, nil +} + +// SourceNode returns the source node of the graph. The source node is treated +// as the center node within a star-graph. This method may be used to kick off +// a path finding algorithm in order to explore the reachability of another +// node based off the source node. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) SourceNode() (*models.LightningNode, error) { + ctx := context.TODO() + + var ( + readTx = NewReadTx() + node *models.LightningNode + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + _, nodePub, err := getSourceNode(ctx, db, ProtocolV1) + if err != nil { + return fmt.Errorf("unable to fetch V1 source node: %w", + err) + } + + _, node, err = getNodeByPubKey(ctx, db, nodePub) + + return err + }, func() {}) + if err != nil { + return nil, fmt.Errorf("unable to fetch source node: %w", err) + } + + return node, nil +} + +// SetSourceNode sets the source node within the graph database. The source +// node is to be used as the center of a star-graph within path finding +// algorithms. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) SetSourceNode(node *models.LightningNode) error { + ctx := context.TODO() + var writeTxOpts TxOptions + + return s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + id, err := upsertNode(ctx, db, node) + if err != nil { + return fmt.Errorf("unable to upsert source node: %w", + err) + } + + // Make sure that if a source node for this version is already + // set, then the ID is the same as the one we are about to set. + dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1) + if err != nil && !errors.Is(err, ErrSourceNodeNotSet) { + return fmt.Errorf("unable to fetch source node: %w", + err) + } else if err == nil { + if dbSourceNodeID != id { + return fmt.Errorf("v1 source node already "+ + "set to a different node: %d vs %d", + dbSourceNodeID, id) + } + + return nil + } + + return db.AddSourceNode(ctx, id) + }, func() {}) +} + +// NodeUpdatesInHorizon returns all the known lightning node which have an +// update timestamp within the passed range. This method can be used by two +// nodes to quickly determine if they have the same set of up to date node +// announcements. +// +// NOTE: This is part of the V1Store interface. +func (s *SQLStore) NodeUpdatesInHorizon(startTime, + endTime time.Time) ([]models.LightningNode, error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + nodes []models.LightningNode + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + dbNodes, err := db.GetNodesByLastUpdateRange( + ctx, sqlc.GetNodesByLastUpdateRangeParams{ + StartTime: sqldb.SQLInt64(startTime.Unix()), + EndTime: sqldb.SQLInt64(endTime.Unix()), + }, + ) + if err != nil { + return fmt.Errorf("unable to fetch nodes: %w", err) + } + + for _, dbNode := range dbNodes { + node, err := buildNode(ctx, db, &dbNode) + if err != nil { + return fmt.Errorf("unable to build node: %w", + err) + } + + nodes = append(nodes, *node) + } + + return nil + }, func() {}) + if err != nil { + return nil, fmt.Errorf("unable to fetch nodes: %w", err) + } + + return nodes, nil +} + +// getNodeByPubKey attempts to look up a target node by its public key. +func getNodeByPubKey(ctx context.Context, db SQLQueries, + pubKey route.Vertex) (int64, *models.LightningNode, error) { + + dbNode, err := db.GetNodeByPubKey( + ctx, sqlc.GetNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil, ErrGraphNodeNotFound + } else if err != nil { + return 0, nil, fmt.Errorf("unable to fetch node: %w", err) + } + + node, err := buildNode(ctx, db, &dbNode) + if err != nil { + return 0, nil, fmt.Errorf("unable to build node: %w", err) + } + + return dbNode.ID, node, nil +} + +// buildNode constructs a LightningNode instance from the given database node +// record. The node's features, addresses and extra signed fields are also +// fetched from the database and set on the node. +func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) ( + *models.LightningNode, error) { + + if dbNode.Version != int16(ProtocolV1) { + return nil, fmt.Errorf("unsupported node version: %d", + dbNode.Version) + } + + var pub [33]byte + copy(pub[:], dbNode.PubKey) + + node := &models.LightningNode{ + PubKeyBytes: pub, + Features: lnwire.EmptyFeatureVector(), + LastUpdate: time.Unix(0, 0), + ExtraOpaqueData: make([]byte, 0), + } + + if len(dbNode.Signature) == 0 { + return node, nil + } + + node.HaveNodeAnnouncement = true + node.AuthSigBytes = dbNode.Signature + node.Alias = dbNode.Alias.String + node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0) + + var err error + node.Color, err = DecodeHexColor(dbNode.Color.String) + if err != nil { + return nil, fmt.Errorf("unable to decode color: %w", err) + } + + // Fetch the node's features. + node.Features, err = getNodeFeatures(ctx, db, dbNode.ID) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "features: %w", dbNode.ID, err) + } + + // Fetch the node's addresses. + _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:]) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "addresses: %w", dbNode.ID, err) + } + + // Fetch the node's extra signed fields. + extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "extra signed fields: %w", dbNode.ID, err) + } + + recs, err := lnwire.CustomRecords(extraTLVMap).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + + if len(recs) != 0 { + node.ExtraOpaqueData = recs + } + + return node, nil +} + +// getNodeFeatures fetches the feature bits and constructs the feature vector +// for a node with the given DB ID. +func getNodeFeatures(ctx context.Context, db SQLQueries, + nodeID int64) (*lnwire.FeatureVector, error) { + + rows, err := db.GetNodeFeatures(ctx, nodeID) + if err != nil { + return nil, fmt.Errorf("unable to get node(%d) features: %w", + nodeID, err) + } + + features := lnwire.EmptyFeatureVector() + for _, feature := range rows { + features.Set(lnwire.FeatureBit(feature.FeatureBit)) + } + + return features, nil +} + +// getNodeExtraSignedFields fetches the extra signed fields for a node with the +// given DB ID. +func getNodeExtraSignedFields(ctx context.Context, db SQLQueries, + nodeID int64) (map[uint64][]byte, error) { + + fields, err := db.GetExtraNodeTypes(ctx, nodeID) + if err != nil { + return nil, fmt.Errorf("unable to get node(%d) extra "+ + "signed fields: %w", nodeID, err) + } + + extraFields := make(map[uint64][]byte) + for _, field := range fields { + extraFields[uint64(field.Type)] = field.Value + } + + return extraFields, nil +} + +// upsertNode upserts the node record into the database. If the node already +// exists, then the node's information is updated. If the node doesn't exist, +// then a new node is created. The node's features, addresses and extra TLV +// types are also updated. The node's DB ID is returned. +func upsertNode(ctx context.Context, db SQLQueries, + node *models.LightningNode) (int64, error) { + + params := sqlc.UpsertNodeParams{ + Version: int16(ProtocolV1), + PubKey: node.PubKeyBytes[:], + } + + if node.HaveNodeAnnouncement { + params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix()) + params.Color = sqldb.SQLStr(EncodeHexColor(node.Color)) + params.Alias = sqldb.SQLStr(node.Alias) + params.Signature = node.AuthSigBytes + } + + nodeID, err := db.UpsertNode(ctx, params) + if err != nil { + return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes, + err) + } + + // We can exit here if we don't have the announcement yet. + if !node.HaveNodeAnnouncement { + return nodeID, nil + } + + // Update the node's features. + err = upsertNodeFeatures(ctx, db, nodeID, node.Features) + if err != nil { + return 0, fmt.Errorf("inserting node features: %w", err) + } + + // Update the node's addresses. + err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses) + if err != nil { + return 0, fmt.Errorf("inserting node addresses: %w", err) + } + + // Convert the flat extra opaque data into a map of TLV types to + // values. + extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData) + if err != nil { + return 0, fmt.Errorf("unable to marshal extra opaque data: %w", + err) + } + + // Update the node's extra signed fields. + err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra) + if err != nil { + return 0, fmt.Errorf("inserting node extra TLVs: %w", err) + } + + return nodeID, nil +} + +// upsertNodeFeatures updates the node's features node_features table. This +// includes deleting any feature bits no longer present and inserting any new +// feature bits. If the feature bit does not yet exist in the features table, +// then an entry is created in that table first. +func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64, + features *lnwire.FeatureVector) error { + + // Get any existing features for the node. + existingFeatures, err := db.GetNodeFeatures(ctx, nodeID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + // Copy the nodes latest set of feature bits. + newFeatures := make(map[int32]struct{}) + if features != nil { + for feature := range features.Features() { + newFeatures[int32(feature)] = struct{}{} + } + } + + // For any current feature that already exists in the DB, remove it from + // the in-memory map. For any existing feature that does not exist in + // the in-memory map, delete it from the database. + for _, feature := range existingFeatures { + // The feature is still present, so there are no updates to be + // made. + if _, ok := newFeatures[feature.FeatureBit]; ok { + delete(newFeatures, feature.FeatureBit) + continue + } + + // The feature is no longer present, so we remove it from the + // database. + err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{ + NodeID: nodeID, + FeatureBit: feature.FeatureBit, + }) + if err != nil { + return fmt.Errorf("unable to delete node(%d) "+ + "feature(%v): %w", nodeID, feature.FeatureBit, + err) + } + } + + // Any remaining entries in newFeatures are new features that need to be + // added to the database for the first time. + for feature := range newFeatures { + err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{ + NodeID: nodeID, + FeatureBit: feature, + }) + if err != nil { + return fmt.Errorf("unable to insert node(%d) "+ + "feature(%v): %w", nodeID, feature, err) + } + } + + return nil +} + +// fetchNodeFeatures fetches the features for a node with the given public key. +func fetchNodeFeatures(ctx context.Context, queries SQLQueries, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { + + rows, err := queries.GetNodeFeaturesByPubKey( + ctx, sqlc.GetNodeFeaturesByPubKeyParams{ + PubKey: nodePub[:], + Version: int16(ProtocolV1), + }, + ) + if err != nil { + return nil, fmt.Errorf("unable to get node(%s) features: %w", + nodePub, err) + } + + features := lnwire.EmptyFeatureVector() + for _, bit := range rows { + features.Set(lnwire.FeatureBit(bit)) + } + + return features, nil +} + +// dbAddressType is an enum type that represents the different address types +// that we store in the node_addresses table. The address type determines how +// the address is to be serialised/deserialize. +type dbAddressType uint8 + +const ( + addressTypeIPv4 dbAddressType = 1 + addressTypeIPv6 dbAddressType = 2 + addressTypeTorV2 dbAddressType = 3 + addressTypeTorV3 dbAddressType = 4 + addressTypeOpaque dbAddressType = math.MaxInt8 +) + +// upsertNodeAddresses updates the node's addresses in the database. This +// includes deleting any existing addresses and inserting the new set of +// addresses. The deletion is necessary since the ordering of the addresses may +// change, and we need to ensure that the database reflects the latest set of +// addresses so that at the time of reconstructing the node announcement, the +// order is preserved and the signature over the message remains valid. +func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64, + addresses []net.Addr) error { + + // Delete any existing addresses for the node. This is required since + // even if the new set of addresses is the same, the ordering may have + // changed for a given address type. + err := db.DeleteNodeAddresses(ctx, nodeID) + if err != nil { + return fmt.Errorf("unable to delete node(%d) addresses: %w", + nodeID, err) + } + + // Copy the nodes latest set of addresses. + newAddresses := map[dbAddressType][]string{ + addressTypeIPv4: {}, + addressTypeIPv6: {}, + addressTypeTorV2: {}, + addressTypeTorV3: {}, + addressTypeOpaque: {}, + } + addAddr := func(t dbAddressType, addr net.Addr) { + newAddresses[t] = append(newAddresses[t], addr.String()) + } + + for _, address := range addresses { + switch addr := address.(type) { + case *net.TCPAddr: + if ip4 := addr.IP.To4(); ip4 != nil { + addAddr(addressTypeIPv4, addr) + } else if ip6 := addr.IP.To16(); ip6 != nil { + addAddr(addressTypeIPv6, addr) + } else { + return fmt.Errorf("unhandled IP address: %v", + addr) + } + + case *tor.OnionAddr: + switch len(addr.OnionService) { + case tor.V2Len: + addAddr(addressTypeTorV2, addr) + case tor.V3Len: + addAddr(addressTypeTorV3, addr) + default: + return fmt.Errorf("invalid length for a tor " + + "address") + } + + case *lnwire.OpaqueAddrs: + addAddr(addressTypeOpaque, addr) + + default: + return fmt.Errorf("unhandled address type: %T", addr) + } + } + + // Any remaining entries in newAddresses are new addresses that need to + // be added to the database for the first time. + for addrType, addrList := range newAddresses { + for position, addr := range addrList { + err := db.InsertNodeAddress( + ctx, sqlc.InsertNodeAddressParams{ + NodeID: nodeID, + Type: int16(addrType), + Address: addr, + Position: int32(position), + }, + ) + if err != nil { + return fmt.Errorf("unable to insert "+ + "node(%d) address(%v): %w", nodeID, + addr, err) + } + } + } + + return nil +} + +// getNodeAddresses fetches the addresses for a node with the given public key. +func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool, + []net.Addr, error) { + + // GetNodeAddressesByPubKey ensures that the addresses for a given type + // are returned in the same order as they were inserted. + rows, err := db.GetNodeAddressesByPubKey( + ctx, sqlc.GetNodeAddressesByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: nodePub, + }, + ) + if err != nil { + return false, nil, err + } + + // GetNodeAddressesByPubKey uses a left join so there should always be + // at least one row returned if the node exists even if it has no + // addresses. + if len(rows) == 0 { + return false, nil, nil + } + + addresses := make([]net.Addr, 0, len(rows)) + for _, addr := range rows { + if !(addr.Type.Valid && addr.Address.Valid) { + continue + } + + address := addr.Address.String + + switch dbAddressType(addr.Type.Int16) { + case addressTypeIPv4: + tcp, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + return false, nil, nil + } + tcp.IP = tcp.IP.To4() + + addresses = append(addresses, tcp) + + case addressTypeIPv6: + tcp, err := net.ResolveTCPAddr("tcp6", address) + if err != nil { + return false, nil, nil + } + addresses = append(addresses, tcp) + + case addressTypeTorV3, addressTypeTorV2: + service, portStr, err := net.SplitHostPort(address) + if err != nil { + return false, nil, fmt.Errorf("unable to "+ + "split tor v3 address: %v", + addr.Address) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return false, nil, err + } + + addresses = append(addresses, &tor.OnionAddr{ + OnionService: service, + Port: port, + }) + + case addressTypeOpaque: + opaque, err := hex.DecodeString(address) + if err != nil { + return false, nil, fmt.Errorf("unable to "+ + "decode opaque address: %v", addr) + } + + addresses = append(addresses, &lnwire.OpaqueAddrs{ + Payload: opaque, + }) + + default: + return false, nil, fmt.Errorf("unknown address "+ + "type: %v", addr.Type) + } + } + + return true, addresses, nil +} + +// upsertNodeExtraSignedFields updates the node's extra signed fields in the +// database. This includes updating any existing types, inserting any new types, +// and deleting any types that are no longer present. +func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries, + nodeID int64, extraFields map[uint64][]byte) error { + + // Get any existing extra signed fields for the node. + existingFields, err := db.GetExtraNodeTypes(ctx, nodeID) + if err != nil { + return err + } + + // Make a lookup map of the existing field types so that we can use it + // to keep track of any fields we should delete. + m := make(map[uint64]bool) + for _, field := range existingFields { + m[uint64(field.Type)] = true + } + + // For all the new fields, we'll upsert them and remove them from the + // map of existing fields. + for tlvType, value := range extraFields { + err = db.UpsertNodeExtraType( + ctx, sqlc.UpsertNodeExtraTypeParams{ + NodeID: nodeID, + Type: int64(tlvType), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("unable to upsert node(%d) extra "+ + "signed field(%v): %w", nodeID, tlvType, err) + } + + // Remove the field from the map of existing fields if it was + // present. + delete(m, tlvType) + } + + // For all the fields that are left in the map of existing fields, we'll + // delete them as they are no longer present in the new set of fields. + for tlvType := range m { + err = db.DeleteExtraNodeType( + ctx, sqlc.DeleteExtraNodeTypeParams{ + NodeID: nodeID, + Type: int64(tlvType), + }, + ) + if err != nil { + return fmt.Errorf("unable to delete node(%d) extra "+ + "signed field(%v): %w", nodeID, tlvType, err) + } + } + + return nil +} + +// getSourceNode returns the DB node ID and pub key of the source node for the +// specified protocol version. +func getSourceNode(ctx context.Context, db SQLQueries, + version ProtocolVersion) (int64, route.Vertex, error) { + + var pubKey route.Vertex + + nodes, err := db.GetSourceNodesByVersion(ctx, int16(version)) + if err != nil { + return 0, pubKey, fmt.Errorf("unable to fetch source node: %w", + err) + } + + if len(nodes) == 0 { + return 0, pubKey, ErrSourceNodeNotSet + } else if len(nodes) > 1 { + return 0, pubKey, fmt.Errorf("multiple source nodes for "+ + "protocol %s found", version) + } + + copy(pubKey[:], nodes[0].PubKey) + + return nodes[0].NodeID, pubKey, nil +} + +// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream. +// This then produces a map from TLV type to value. If the input is not a +// valid TLV stream, then an error is returned. +func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) { + r := bytes.NewReader(data) + + tlvStream, err := tlv.NewStream() + if err != nil { + return nil, err + } + + // Since ExtraOpaqueData is provided by a potentially malicious peer, + // pass it into the P2P decoding variant. + parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r) + if err != nil { + return nil, err + } + if len(parsedTypes) == 0 { + return nil, nil + } + + records := make(map[uint64][]byte) + for k, v := range parsedTypes { + records[uint64(k)] = v + } + + return records, nil +} diff --git a/sqldb/migrations_dev.go b/sqldb/migrations_dev.go index 61323dbff..112e68dd0 100644 --- a/sqldb/migrations_dev.go +++ b/sqldb/migrations_dev.go @@ -2,4 +2,10 @@ package sqldb -var migrationAdditions = []MigrationConfig{} +var migrationAdditions = []MigrationConfig{ + { + Name: "000007_graph", + Version: 8, + SchemaVersion: 7, + }, +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go new file mode 100644 index 000000000..dcdd79053 --- /dev/null +++ b/sqldb/sqlc/graph.sql.go @@ -0,0 +1,453 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: graph.sql + +package sqlc + +import ( + "context" + "database/sql" +) + +const addSourceNode = `-- name: AddSourceNode :exec +/* ───────────────────────────────────────────── + source_nodes table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO source_nodes (node_id) +VALUES ($1) +ON CONFLICT (node_id) DO NOTHING +` + +func (q *Queries) AddSourceNode(ctx context.Context, nodeID int64) error { + _, err := q.db.ExecContext(ctx, addSourceNode, nodeID) + return err +} + +const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec +DELETE FROM node_extra_types +WHERE node_id = $1 + AND type = $2 +` + +type DeleteExtraNodeTypeParams struct { + NodeID int64 + Type int64 +} + +func (q *Queries) DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error { + _, err := q.db.ExecContext(ctx, deleteExtraNodeType, arg.NodeID, arg.Type) + return err +} + +const deleteNodeAddresses = `-- name: DeleteNodeAddresses :exec +DELETE FROM node_addresses +WHERE node_id = $1 +` + +func (q *Queries) DeleteNodeAddresses(ctx context.Context, nodeID int64) error { + _, err := q.db.ExecContext(ctx, deleteNodeAddresses, nodeID) + return err +} + +const deleteNodeByPubKey = `-- name: DeleteNodeByPubKey :execresult +DELETE FROM nodes +WHERE pub_key = $1 + AND version = $2 +` + +type DeleteNodeByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) { + return q.db.ExecContext(ctx, deleteNodeByPubKey, arg.PubKey, arg.Version) +} + +const deleteNodeFeature = `-- name: DeleteNodeFeature :exec +DELETE FROM node_features +WHERE node_id = $1 + AND feature_bit = $2 +` + +type DeleteNodeFeatureParams struct { + NodeID int64 + FeatureBit int32 +} + +func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error { + _, err := q.db.ExecContext(ctx, deleteNodeFeature, arg.NodeID, arg.FeatureBit) + return err +} + +const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many +SELECT node_id, type, value +FROM node_extra_types +WHERE node_id = $1 +` + +func (q *Queries) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) { + rows, err := q.db.QueryContext(ctx, getExtraNodeTypes, nodeID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []NodeExtraType + for rows.Next() { + var i NodeExtraType + if err := rows.Scan(&i.NodeID, &i.Type, &i.Value); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeAddressesByPubKey = `-- name: GetNodeAddressesByPubKey :many +SELECT a.type, a.address +FROM nodes n +LEFT JOIN node_addresses a ON a.node_id = n.id +WHERE n.pub_key = $1 AND n.version = $2 +ORDER BY a.type ASC, a.position ASC +` + +type GetNodeAddressesByPubKeyParams struct { + PubKey []byte + Version int16 +} + +type GetNodeAddressesByPubKeyRow struct { + Type sql.NullInt16 + Address sql.NullString +} + +func (q *Queries) GetNodeAddressesByPubKey(ctx context.Context, arg GetNodeAddressesByPubKeyParams) ([]GetNodeAddressesByPubKeyRow, error) { + rows, err := q.db.QueryContext(ctx, getNodeAddressesByPubKey, arg.PubKey, arg.Version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetNodeAddressesByPubKeyRow + for rows.Next() { + var i GetNodeAddressesByPubKeyRow + if err := rows.Scan(&i.Type, &i.Address); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeByPubKey = `-- name: GetNodeByPubKey :one +SELECT id, version, pub_key, alias, last_update, color, signature +FROM nodes +WHERE pub_key = $1 + AND version = $2 +` + +type GetNodeByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error) { + row := q.db.QueryRowContext(ctx, getNodeByPubKey, arg.PubKey, arg.Version) + var i Node + err := row.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + ) + return i, err +} + +const getNodeFeatures = `-- name: GetNodeFeatures :many +SELECT node_id, feature_bit +FROM node_features +WHERE node_id = $1 +` + +func (q *Queries) GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) { + rows, err := q.db.QueryContext(ctx, getNodeFeatures, nodeID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []NodeFeature + for rows.Next() { + var i NodeFeature + if err := rows.Scan(&i.NodeID, &i.FeatureBit); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeFeaturesByPubKey = `-- name: GetNodeFeaturesByPubKey :many +SELECT f.feature_bit +FROM nodes n + JOIN node_features f ON f.node_id = n.id +WHERE n.pub_key = $1 + AND n.version = $2 +` + +type GetNodeFeaturesByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, getNodeFeaturesByPubKey, arg.PubKey, arg.Version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var feature_bit int32 + if err := rows.Scan(&feature_bit); err != nil { + return nil, err + } + items = append(items, feature_bit) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodesByLastUpdateRange = `-- name: GetNodesByLastUpdateRange :many +SELECT id, version, pub_key, alias, last_update, color, signature +FROM nodes +WHERE last_update >= $1 + AND last_update < $2 +` + +type GetNodesByLastUpdateRangeParams struct { + StartTime sql.NullInt64 + EndTime sql.NullInt64 +} + +func (q *Queries) GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) { + rows, err := q.db.QueryContext(ctx, getNodesByLastUpdateRange, arg.StartTime, arg.EndTime) + if err != nil { + return nil, err + } + defer rows.Close() + var items []Node + for rows.Next() { + var i Node + if err := rows.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getSourceNodesByVersion = `-- name: GetSourceNodesByVersion :many +SELECT sn.node_id, n.pub_key +FROM source_nodes sn + JOIN nodes n ON sn.node_id = n.id +WHERE n.version = $1 +` + +type GetSourceNodesByVersionRow struct { + NodeID int64 + PubKey []byte +} + +func (q *Queries) GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) { + rows, err := q.db.QueryContext(ctx, getSourceNodesByVersion, version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetSourceNodesByVersionRow + for rows.Next() { + var i GetSourceNodesByVersionRow + if err := rows.Scan(&i.NodeID, &i.PubKey); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertNodeAddress = `-- name: InsertNodeAddress :exec +/* ───────────────────────────────────────────── + node_addresses table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_addresses ( + node_id, + type, + address, + position +) VALUES ( + $1, $2, $3, $4 + ) +` + +type InsertNodeAddressParams struct { + NodeID int64 + Type int16 + Address string + Position int32 +} + +func (q *Queries) InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error { + _, err := q.db.ExecContext(ctx, insertNodeAddress, + arg.NodeID, + arg.Type, + arg.Address, + arg.Position, + ) + return err +} + +const insertNodeFeature = `-- name: InsertNodeFeature :exec +/* ───────────────────────────────────────────── + node_features table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_features ( + node_id, feature_bit +) VALUES ( + $1, $2 +) +` + +type InsertNodeFeatureParams struct { + NodeID int64 + FeatureBit int32 +} + +func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error { + _, err := q.db.ExecContext(ctx, insertNodeFeature, arg.NodeID, arg.FeatureBit) + return err +} + +const upsertNode = `-- name: UpsertNode :one +/* ───────────────────────────────────────────── + nodes table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO nodes ( + version, pub_key, alias, last_update, color, signature +) VALUES ( + $1, $2, $3, $4, $5, $6 +) +ON CONFLICT (pub_key, version) + -- Update the following fields if a conflict occurs on pub_key + -- and version. + DO UPDATE SET + alias = EXCLUDED.alias, + last_update = EXCLUDED.last_update, + color = EXCLUDED.color, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > nodes.last_update +RETURNING id +` + +type UpsertNodeParams struct { + Version int16 + PubKey []byte + Alias sql.NullString + LastUpdate sql.NullInt64 + Color sql.NullString + Signature []byte +} + +func (q *Queries) UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, upsertNode, + arg.Version, + arg.PubKey, + arg.Alias, + arg.LastUpdate, + arg.Color, + arg.Signature, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertNodeExtraType = `-- name: UpsertNodeExtraType :exec +/* ───────────────────────────────────────────── + node_extra_types table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_extra_types ( + node_id, type, value +) +VALUES ($1, $2, $3) +ON CONFLICT (type, node_id) + -- Update the value if a conflict occurs on type + -- and node_id. + DO UPDATE SET value = EXCLUDED.value +` + +type UpsertNodeExtraTypeParams struct { + NodeID int64 + Type int64 + Value []byte +} + +func (q *Queries) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error { + _, err := q.db.ExecContext(ctx, upsertNodeExtraType, arg.NodeID, arg.Type, arg.Value) + return err +} diff --git a/sqldb/sqlc/migrations/000007_graph.down.sql b/sqldb/sqlc/migrations/000007_graph.down.sql new file mode 100644 index 000000000..79489d7bd --- /dev/null +++ b/sqldb/sqlc/migrations/000007_graph.down.sql @@ -0,0 +1,13 @@ +-- Drop indexes. +DROP INDEX IF EXISTS nodes_unique; +DROP INDEX IF EXISTS node_extra_types_unique; +DROP INDEX IF EXISTS node_features_unique; +DROP INDEX IF EXISTS node_addresses_unique; +DROP INDEX IF EXISTS source_nodes_unique; + +-- Drop tables in order of reverse dependencies. +DROP TABLE IF EXISTS source_nodes; +DROP TABLE IF EXISTS node_addresses; +DROP TABLE IF EXISTS node_features; +DROP TABLE IF EXISTS node_extra_types; +DROP TABLE IF EXISTS nodes; \ No newline at end of file diff --git a/sqldb/sqlc/migrations/000007_graph.up.sql b/sqldb/sqlc/migrations/000007_graph.up.sql new file mode 100644 index 000000000..0efc750e3 --- /dev/null +++ b/sqldb/sqlc/migrations/000007_graph.up.sql @@ -0,0 +1,97 @@ +/* ───────────────────────────────────────────── + node data tables + ───────────────────────────────────────────── +*/ + +-- nodes stores all the nodes that we are aware of in the LN graph. +CREATE TABLE IF NOT EXISTS nodes ( + -- The db ID of the node. This will only be used DB level + -- relations. + id INTEGER PRIMARY KEY, + + -- The protocol version that this node was gossiped on. + version SMALLINT NOT NULL, + + -- The public key (serialised compressed) of the node. + pub_key BLOB NOT NULL, + + -- The alias of the node. + alias TEXT, + + -- The unix timestamp of the last time the node was updated. + last_update BIGINT, + + -- The color of the node. + color VARCHAR, + + -- The signature of the node announcement. If this is null, then + -- the node announcement has not been received yet and this record + -- is a shell node. This can be the case if we receive a channel + -- announcement for a channel that is connected to a node that we + -- have not yet received a node announcement for. + signature BLOB +); + +-- A node (identified by a public key) can only have one active node +-- announcement per protocol. +CREATE UNIQUE INDEX IF NOT EXISTS nodes_unique ON nodes ( + pub_key, version +); + +-- node_extra_types stores any extra TLV fields covered by a node announcement that +-- we do not have an explicit column for in the nodes table. +CREATE TABLE IF NOT EXISTS node_extra_types ( + -- The node id this TLV field belongs to. + node_id BIGINT NOT NULL REFERENCES nodes(id) ON DELETE CASCADE, + + -- The Type field. + type BIGINT NOT NULL, + + -- The value field. + value BLOB +); +CREATE UNIQUE INDEX IF NOT EXISTS node_extra_types_unique ON node_extra_types ( + type, node_id +); + +-- node_features contains the feature bits of a node. +CREATE TABLE IF NOT EXISTS node_features ( + -- The node id this feature belongs to. + node_id BIGINT NOT NULL REFERENCES nodes(id) ON DELETE CASCADE, + + -- The feature bit value. + feature_bit INTEGER NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS node_features_unique ON node_features ( + node_id, feature_bit +); + +-- node_addresses contains the advertised addresses of nodes. +CREATE TABLE IF NOT EXISTS node_addresses ( + -- The node id this feature belongs to. + node_id BIGINT NOT NULL REFERENCES nodes(id) ON DELETE CASCADE, + + -- An enum that represents the type of address. This will + -- dictate how the address column should be parsed. + type SMALLINT NOT NULL, + + -- position is position of this address in the list of addresses + -- under the given type as it appeared in the node announcement. + -- We need to store this so that when we reconstruct the node + -- announcement, we preserve the original order of the addresses + -- so that the signature of the announcement remains valid. + position INTEGER NOT NULL, + + -- The advertised address of the node. + address TEXT NOT NULL +); +CREATE UNIQUE INDEX IF NOT EXISTS node_addresses_unique ON node_addresses ( + node_id, type, position +); + +CREATE TABLE IF NOT EXISTS source_nodes ( + node_id BIGINT NOT NULL REFERENCES nodes (id) ON DELETE CASCADE +); +CREATE UNIQUE INDEX IF NOT EXISTS source_nodes_unique ON source_nodes ( + node_id +); \ No newline at end of file diff --git a/sqldb/sqlc/models.go b/sqldb/sqlc/models.go index b96cc0e94..07269c359 100644 --- a/sqldb/sqlc/models.go +++ b/sqldb/sqlc/models.go @@ -102,3 +102,35 @@ type MigrationTracker struct { Version int32 MigrationTime time.Time } + +type Node struct { + ID int64 + Version int16 + PubKey []byte + Alias sql.NullString + LastUpdate sql.NullInt64 + Color sql.NullString + Signature []byte +} + +type NodeAddress struct { + NodeID int64 + Type int16 + Position int32 + Address string +} + +type NodeExtraType struct { + NodeID int64 + Type int64 + Value []byte +} + +type NodeFeature struct { + NodeID int64 + FeatureBit int32 +} + +type SourceNode struct { + NodeID int64 +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 5da90a6a3..7492a21e0 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -11,15 +11,21 @@ import ( ) type Querier interface { + AddSourceNode(ctx context.Context, nodeID int64) error ClearKVInvoiceHashIndex(ctx context.Context) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) + DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) + DeleteNodeAddresses(ctx context.Context, nodeID int64) error + DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) + DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetDatabaseVersion(ctx context.Context) (int32, error) + GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. @@ -31,6 +37,12 @@ type Querier interface { GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) GetMigration(ctx context.Context, version int32) (time.Time, error) + GetNodeAddressesByPubKey(ctx context.Context, arg GetNodeAddressesByPubKeyParams) ([]GetNodeAddressesByPubKeyRow, error) + GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error) + GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) + GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) + GetNodesByLastUpdateRange(ctx context.Context, arg GetNodesByLastUpdateRangeParams) ([]Node, error) + GetSourceNodesByVersion(ctx context.Context, version int16) ([]GetSourceNodesByVersionRow, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) @@ -39,6 +51,8 @@ type Querier interface { InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) + InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error + InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error @@ -55,6 +69,8 @@ type Querier interface { UpdateInvoiceHTLCs(ctx context.Context, arg UpdateInvoiceHTLCsParams) error UpdateInvoiceState(ctx context.Context, arg UpdateInvoiceStateParams) (sql.Result, error) UpsertAMPSubInvoice(ctx context.Context, arg UpsertAMPSubInvoiceParams) (sql.Result, error) + UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) + UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error } var _ Querier = (*Queries)(nil) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql new file mode 100644 index 000000000..a6745227e --- /dev/null +++ b/sqldb/sqlc/queries/graph.sql @@ -0,0 +1,134 @@ +/* ───────────────────────────────────────────── + nodes table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertNode :one +INSERT INTO nodes ( + version, pub_key, alias, last_update, color, signature +) VALUES ( + $1, $2, $3, $4, $5, $6 +) +ON CONFLICT (pub_key, version) + -- Update the following fields if a conflict occurs on pub_key + -- and version. + DO UPDATE SET + alias = EXCLUDED.alias, + last_update = EXCLUDED.last_update, + color = EXCLUDED.color, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > nodes.last_update +RETURNING id; + +-- name: GetNodeByPubKey :one +SELECT * +FROM nodes +WHERE pub_key = $1 + AND version = $2; + +-- name: DeleteNodeByPubKey :execresult +DELETE FROM nodes +WHERE pub_key = $1 + AND version = $2; + +/* ───────────────────────────────────────────── + node_features table queries + ───────────────────────────────────────────── +*/ + +-- name: InsertNodeFeature :exec +INSERT INTO node_features ( + node_id, feature_bit +) VALUES ( + $1, $2 +); + +-- name: GetNodeFeatures :many +SELECT * +FROM node_features +WHERE node_id = $1; + +-- name: GetNodeFeaturesByPubKey :many +SELECT f.feature_bit +FROM nodes n + JOIN node_features f ON f.node_id = n.id +WHERE n.pub_key = $1 + AND n.version = $2; + +-- name: DeleteNodeFeature :exec +DELETE FROM node_features +WHERE node_id = $1 + AND feature_bit = $2; + +/* ───────────────────────────────────────────── + node_addresses table queries + ───────────────────────────────────────────── +*/ + +-- name: InsertNodeAddress :exec +INSERT INTO node_addresses ( + node_id, + type, + address, + position +) VALUES ( + $1, $2, $3, $4 + ); + +-- name: GetNodeAddressesByPubKey :many +SELECT a.type, a.address +FROM nodes n +LEFT JOIN node_addresses a ON a.node_id = n.id +WHERE n.pub_key = $1 AND n.version = $2 +ORDER BY a.type ASC, a.position ASC; + +-- name: GetNodesByLastUpdateRange :many +SELECT * +FROM nodes +WHERE last_update >= @start_time + AND last_update < @end_time; + +-- name: DeleteNodeAddresses :exec +DELETE FROM node_addresses +WHERE node_id = $1; + +/* ───────────────────────────────────────────── + node_extra_types table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertNodeExtraType :exec +INSERT INTO node_extra_types ( + node_id, type, value +) +VALUES ($1, $2, $3) +ON CONFLICT (type, node_id) + -- Update the value if a conflict occurs on type + -- and node_id. + DO UPDATE SET value = EXCLUDED.value; + +-- name: GetExtraNodeTypes :many +SELECT * +FROM node_extra_types +WHERE node_id = $1; + +-- name: DeleteExtraNodeType :exec +DELETE FROM node_extra_types +WHERE node_id = $1 + AND type = $2; + +/* ───────────────────────────────────────────── + source_nodes table queries + ───────────────────────────────────────────── +*/ + +-- name: AddSourceNode :exec +INSERT INTO source_nodes (node_id) +VALUES ($1) +ON CONFLICT (node_id) DO NOTHING; + +-- name: GetSourceNodesByVersion :many +SELECT sn.node_id, n.pub_key +FROM source_nodes sn + JOIN nodes n ON sn.node_id = n.id +WHERE n.version = $1;