diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index cab857fee..0b812426d 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -101,15 +101,17 @@ func createTestVertex(t testing.TB) *models.LightningNode { func TestNodeInsertionAndDeletion(t *testing.T) { t.Parallel() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) // We'd like to test basic insertion/deletion for vertexes from the // graph, so we'll create a test vertex to start with. + timeStamp := int64(1232342) nodeWithAddrs := func(addrs []net.Addr) *models.LightningNode { + timeStamp++ return &models.LightningNode{ HaveNodeAnnouncement: true, AuthSigBytes: testSig.Serialize(), - LastUpdate: time.Unix(1232342, 0), + LastUpdate: time.Unix(timeStamp, 0), Color: color.RGBA{1, 2, 3, 0}, Alias: "kek", Features: testFeatures, @@ -4315,7 +4317,7 @@ func TestLightningNodePersistence(t *testing.T) { t.Parallel() // Create a new test graph instance. - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) nodeAnnBytes, err := hex.DecodeString(testNodeAnn) require.NoError(t, err) diff --git a/graph/db/notifications.go b/graph/db/notifications.go index 8b69b7d85..d573e0aa5 100644 --- a/graph/db/notifications.go +++ b/graph/db/notifications.go @@ -4,6 +4,7 @@ import ( "fmt" "image/color" "net" + "strconv" "sync" "sync/atomic" @@ -463,3 +464,31 @@ func (c *ChannelGraph) addToTopologyChange(update *TopologyChange, func EncodeHexColor(color color.RGBA) string { return fmt.Sprintf("#%02x%02x%02x", color.R, color.G, color.B) } + +// DecodeHexColor takes a hex color string like "#rrggbb" and returns a +// color.RGBA. +func DecodeHexColor(hex string) (color.RGBA, error) { + r, err := strconv.ParseUint(hex[1:3], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid red component: %w", + err) + } + + g, err := strconv.ParseUint(hex[3:5], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid green component: %w", + err) + } + + b, err := strconv.ParseUint(hex[5:7], 16, 8) + if err != nil { + return color.RGBA{}, fmt.Errorf("invalid blue component: %w", + err) + } + + return color.RGBA{ + R: uint8(r), + G: uint8(g), + B: uint8(b), + }, nil +} diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index 3947ae0d8..4e6c876cf 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -1,16 +1,67 @@ package graphdb import ( + "bytes" + "context" + "database/sql" + "encoding/hex" + "errors" "fmt" + "math" + "net" + "strconv" "sync" + "time" + "github.com/btcsuite/btcd/btcec/v2" "github.com/lightningnetwork/lnd/batch" + "github.com/lightningnetwork/lnd/graph/db/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sqldb/sqlc" + "github.com/lightningnetwork/lnd/tlv" + "github.com/lightningnetwork/lnd/tor" ) +// ProtocolVersion is an enum that defines the gossip protocol version of a +// message. +type ProtocolVersion uint8 + +const ( + // ProtocolV1 is the gossip protocol version defined in BOLT #7. + ProtocolV1 ProtocolVersion = 1 +) + +// String returns a string representation of the protocol version. +func (v ProtocolVersion) String() string { + return fmt.Sprintf("V%d", v) +} + // SQLQueries is a subset of the sqlc.Querier interface that can be used to // execute queries against the SQL graph tables. +// +//nolint:ll,interfacebloat type SQLQueries interface { + /* + Node queries. + */ + UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error) + GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error) + DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error) + + GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error) + UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error + DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error + + InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error + GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error) + DeleteNodeAddresses(ctx context.Context, nodeID int64) error + + InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error + GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error) + GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error) + DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -80,3 +131,749 @@ func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore, return s, nil } + +// TxOptions defines the set of db txn options the SQLQueries +// understands. +type TxOptions struct { + // readOnly governs if a read only transaction is needed or not. + readOnly bool +} + +// ReadOnly returns true if the transaction should be read only. +// +// NOTE: This implements the TxOptions. +func (a *TxOptions) ReadOnly() bool { + return a.readOnly +} + +// NewReadTx creates a new read transaction option set. +func NewReadTx() *TxOptions { + return &TxOptions{ + readOnly: true, + } +} + +// AddLightningNode adds a vertex/node to the graph database. If the node is not +// in the database from before, this will add a new, unconnected one to the +// graph. If it is present from before, this will update that node's +// information. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) AddLightningNode(node *models.LightningNode, + opts ...batch.SchedulerOption) error { + + ctx := context.TODO() + + r := &batch.Request[SQLQueries]{ + Opts: batch.NewSchedulerOptions(opts...), + Do: func(queries SQLQueries) error { + _, err := upsertNode(ctx, queries, node) + return err + }, + } + + return s.nodeScheduler.Execute(ctx, r) +} + +// FetchLightningNode attempts to look up a target node by its identity public +// key. If the node isn't found in the database, then ErrGraphNodeNotFound is +// returned. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) ( + *models.LightningNode, error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + node *models.LightningNode + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + var err error + _, node, err = getNodeByPubKey(ctx, db, pubKey) + + return err + }, func() {}) + if err != nil { + return nil, fmt.Errorf("unable to fetch node: %w", err) + } + + return node, nil +} + +// HasLightningNode determines if the graph has a vertex identified by the +// target node identity public key. If the node exists in the database, a +// timestamp of when the data for the node was lasted updated is returned along +// with a true boolean. Otherwise, an empty time.Time is returned with a false +// boolean. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool, + error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + exists bool + lastUpdate time.Time + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + dbNode, err := db.GetNodeByPubKey( + ctx, sqlc.GetNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return nil + } else if err != nil { + return fmt.Errorf("unable to fetch node: %w", err) + } + + exists = true + + if dbNode.LastUpdate.Valid { + lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0) + } + + return nil + }, func() {}) + if err != nil { + return time.Time{}, false, + fmt.Errorf("unable to fetch node: %w", err) + } + + return lastUpdate, exists, nil +} + +// AddrsForNode returns all known addresses for the target node public key +// that the graph DB is aware of. The returned boolean indicates if the +// given node is unknown to the graph DB or not. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr, + error) { + + ctx := context.TODO() + + var ( + readTx = NewReadTx() + addresses []net.Addr + known bool + ) + err := s.db.ExecTx(ctx, readTx, func(db SQLQueries) error { + var err error + known, addresses, err = getNodeAddresses( + ctx, db, nodePub.SerializeCompressed(), + ) + if err != nil { + return fmt.Errorf("unable to fetch node addresses: %w", + err) + } + + return nil + }, func() {}) + if err != nil { + return false, nil, fmt.Errorf("unable to get addresses for "+ + "node(%x): %w", nodePub.SerializeCompressed(), err) + } + + return known, addresses, nil +} + +// DeleteLightningNode starts a new database transaction to remove a vertex/node +// from the database according to the node's public key. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error { + ctx := context.TODO() + + var writeTxOpts TxOptions + err := s.db.ExecTx(ctx, &writeTxOpts, func(db SQLQueries) error { + res, err := db.DeleteNodeByPubKey( + ctx, sqlc.DeleteNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if err != nil { + return err + } + + rows, err := res.RowsAffected() + if err != nil { + return err + } + + if rows == 0 { + return ErrGraphNodeNotFound + } else if rows > 1 { + return fmt.Errorf("deleted %d rows, expected 1", rows) + } + + return err + }, func() {}) + if err != nil { + return fmt.Errorf("unable to delete node: %w", err) + } + + return nil +} + +// FetchNodeFeatures returns the features of the given node. If no features are +// known for the node, an empty feature vector is returned. +// +// NOTE: this is part of the graphdb.NodeTraverser interface. +func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) ( + *lnwire.FeatureVector, error) { + + ctx := context.TODO() + + return fetchNodeFeatures(ctx, s.db, nodePub) +} + +// getNodeByPubKey attempts to look up a target node by its public key. +func getNodeByPubKey(ctx context.Context, db SQLQueries, + pubKey route.Vertex) (int64, *models.LightningNode, error) { + + dbNode, err := db.GetNodeByPubKey( + ctx, sqlc.GetNodeByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: pubKey[:], + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return 0, nil, ErrGraphNodeNotFound + } else if err != nil { + return 0, nil, fmt.Errorf("unable to fetch node: %w", err) + } + + node, err := buildNode(ctx, db, &dbNode) + if err != nil { + return 0, nil, fmt.Errorf("unable to build node: %w", err) + } + + return dbNode.ID, node, nil +} + +// buildNode constructs a LightningNode instance from the given database node +// record. The node's features, addresses and extra signed fields are also +// fetched from the database and set on the node. +func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) ( + *models.LightningNode, error) { + + if dbNode.Version != int16(ProtocolV1) { + return nil, fmt.Errorf("unsupported node version: %d", + dbNode.Version) + } + + var pub [33]byte + copy(pub[:], dbNode.PubKey) + + node := &models.LightningNode{ + PubKeyBytes: pub, + Features: lnwire.EmptyFeatureVector(), + LastUpdate: time.Unix(0, 0), + ExtraOpaqueData: make([]byte, 0), + } + + if len(dbNode.Signature) == 0 { + return node, nil + } + + node.HaveNodeAnnouncement = true + node.AuthSigBytes = dbNode.Signature + node.Alias = dbNode.Alias.String + node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0) + + var err error + node.Color, err = DecodeHexColor(dbNode.Color.String) + if err != nil { + return nil, fmt.Errorf("unable to decode color: %w", err) + } + + // Fetch the node's features. + node.Features, err = getNodeFeatures(ctx, db, dbNode.ID) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "features: %w", dbNode.ID, err) + } + + // Fetch the node's addresses. + _, node.Addresses, err = getNodeAddresses(ctx, db, pub[:]) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "addresses: %w", dbNode.ID, err) + } + + // Fetch the node's extra signed fields. + extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID) + if err != nil { + return nil, fmt.Errorf("unable to fetch node(%d) "+ + "extra signed fields: %w", dbNode.ID, err) + } + + recs, err := lnwire.CustomRecords(extraTLVMap).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra signed "+ + "fields: %w", err) + } + + if len(recs) != 0 { + node.ExtraOpaqueData = recs + } + + return node, nil +} + +// getNodeFeatures fetches the feature bits and constructs the feature vector +// for a node with the given DB ID. +func getNodeFeatures(ctx context.Context, db SQLQueries, + nodeID int64) (*lnwire.FeatureVector, error) { + + rows, err := db.GetNodeFeatures(ctx, nodeID) + if err != nil { + return nil, fmt.Errorf("unable to get node(%d) features: %w", + nodeID, err) + } + + features := lnwire.EmptyFeatureVector() + for _, feature := range rows { + features.Set(lnwire.FeatureBit(feature.FeatureBit)) + } + + return features, nil +} + +// getNodeExtraSignedFields fetches the extra signed fields for a node with the +// given DB ID. +func getNodeExtraSignedFields(ctx context.Context, db SQLQueries, + nodeID int64) (map[uint64][]byte, error) { + + fields, err := db.GetExtraNodeTypes(ctx, nodeID) + if err != nil { + return nil, fmt.Errorf("unable to get node(%d) extra "+ + "signed fields: %w", nodeID, err) + } + + extraFields := make(map[uint64][]byte) + for _, field := range fields { + extraFields[uint64(field.Type)] = field.Value + } + + return extraFields, nil +} + +// upsertNode upserts the node record into the database. If the node already +// exists, then the node's information is updated. If the node doesn't exist, +// then a new node is created. The node's features, addresses and extra TLV +// types are also updated. The node's DB ID is returned. +func upsertNode(ctx context.Context, db SQLQueries, + node *models.LightningNode) (int64, error) { + + params := sqlc.UpsertNodeParams{ + Version: int16(ProtocolV1), + PubKey: node.PubKeyBytes[:], + } + + if node.HaveNodeAnnouncement { + params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix()) + params.Color = sqldb.SQLStr(EncodeHexColor(node.Color)) + params.Alias = sqldb.SQLStr(node.Alias) + params.Signature = node.AuthSigBytes + } + + nodeID, err := db.UpsertNode(ctx, params) + if err != nil { + return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes, + err) + } + + // We can exit here if we don't have the announcement yet. + if !node.HaveNodeAnnouncement { + return nodeID, nil + } + + // Update the node's features. + err = upsertNodeFeatures(ctx, db, nodeID, node.Features) + if err != nil { + return 0, fmt.Errorf("inserting node features: %w", err) + } + + // Update the node's addresses. + err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses) + if err != nil { + return 0, fmt.Errorf("inserting node addresses: %w", err) + } + + // Convert the flat extra opaque data into a map of TLV types to + // values. + extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData) + if err != nil { + return 0, fmt.Errorf("unable to marshal extra opaque data: %w", + err) + } + + // Update the node's extra signed fields. + err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra) + if err != nil { + return 0, fmt.Errorf("inserting node extra TLVs: %w", err) + } + + return nodeID, nil +} + +// upsertNodeFeatures updates the node's features node_features table. This +// includes deleting any feature bits no longer present and inserting any new +// feature bits. If the feature bit does not yet exist in the features table, +// then an entry is created in that table first. +func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64, + features *lnwire.FeatureVector) error { + + // Get any existing features for the node. + existingFeatures, err := db.GetNodeFeatures(ctx, nodeID) + if err != nil && !errors.Is(err, sql.ErrNoRows) { + return err + } + + // Copy the nodes latest set of feature bits. + newFeatures := make(map[int32]struct{}) + if features != nil { + for feature := range features.Features() { + newFeatures[int32(feature)] = struct{}{} + } + } + + // For any current feature that already exists in the DB, remove it from + // the in-memory map. For any existing feature that does not exist in + // the in-memory map, delete it from the database. + for _, feature := range existingFeatures { + // The feature is still present, so there are no updates to be + // made. + if _, ok := newFeatures[feature.FeatureBit]; ok { + delete(newFeatures, feature.FeatureBit) + continue + } + + // The feature is no longer present, so we remove it from the + // database. + err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{ + NodeID: nodeID, + FeatureBit: feature.FeatureBit, + }) + if err != nil { + return fmt.Errorf("unable to delete node(%d) "+ + "feature(%v): %w", nodeID, feature.FeatureBit, + err) + } + } + + // Any remaining entries in newFeatures are new features that need to be + // added to the database for the first time. + for feature := range newFeatures { + err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{ + NodeID: nodeID, + FeatureBit: feature, + }) + if err != nil { + return fmt.Errorf("unable to insert node(%d) "+ + "feature(%v): %w", nodeID, feature, err) + } + } + + return nil +} + +// fetchNodeFeatures fetches the features for a node with the given public key. +func fetchNodeFeatures(ctx context.Context, queries SQLQueries, + nodePub route.Vertex) (*lnwire.FeatureVector, error) { + + rows, err := queries.GetNodeFeaturesByPubKey( + ctx, sqlc.GetNodeFeaturesByPubKeyParams{ + PubKey: nodePub[:], + Version: int16(ProtocolV1), + }, + ) + if err != nil { + return nil, fmt.Errorf("unable to get node(%s) features: %w", + nodePub, err) + } + + features := lnwire.EmptyFeatureVector() + for _, bit := range rows { + features.Set(lnwire.FeatureBit(bit)) + } + + return features, nil +} + +// dbAddressType is an enum type that represents the different address types +// that we store in the node_addresses table. The address type determines how +// the address is to be serialised/deserialize. +type dbAddressType uint8 + +const ( + addressTypeIPv4 dbAddressType = 1 + addressTypeIPv6 dbAddressType = 2 + addressTypeTorV2 dbAddressType = 3 + addressTypeTorV3 dbAddressType = 4 + addressTypeOpaque dbAddressType = math.MaxInt8 +) + +// upsertNodeAddresses updates the node's addresses in the database. This +// includes deleting any existing addresses and inserting the new set of +// addresses. The deletion is necessary since the ordering of the addresses may +// change, and we need to ensure that the database reflects the latest set of +// addresses so that at the time of reconstructing the node announcement, the +// order is preserved and the signature over the message remains valid. +func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64, + addresses []net.Addr) error { + + // Delete any existing addresses for the node. This is required since + // even if the new set of addresses is the same, the ordering may have + // changed for a given address type. + err := db.DeleteNodeAddresses(ctx, nodeID) + if err != nil { + return fmt.Errorf("unable to delete node(%d) addresses: %w", + nodeID, err) + } + + // Copy the nodes latest set of addresses. + newAddresses := map[dbAddressType][]string{ + addressTypeIPv4: {}, + addressTypeIPv6: {}, + addressTypeTorV2: {}, + addressTypeTorV3: {}, + addressTypeOpaque: {}, + } + addAddr := func(t dbAddressType, addr net.Addr) { + newAddresses[t] = append(newAddresses[t], addr.String()) + } + + for _, address := range addresses { + switch addr := address.(type) { + case *net.TCPAddr: + if ip4 := addr.IP.To4(); ip4 != nil { + addAddr(addressTypeIPv4, addr) + } else if ip6 := addr.IP.To16(); ip6 != nil { + addAddr(addressTypeIPv6, addr) + } else { + return fmt.Errorf("unhandled IP address: %v", + addr) + } + + case *tor.OnionAddr: + switch len(addr.OnionService) { + case tor.V2Len: + addAddr(addressTypeTorV2, addr) + case tor.V3Len: + addAddr(addressTypeTorV3, addr) + default: + return fmt.Errorf("invalid length for a tor " + + "address") + } + + case *lnwire.OpaqueAddrs: + addAddr(addressTypeOpaque, addr) + + default: + return fmt.Errorf("unhandled address type: %T", addr) + } + } + + // Any remaining entries in newAddresses are new addresses that need to + // be added to the database for the first time. + for addrType, addrList := range newAddresses { + for position, addr := range addrList { + err := db.InsertNodeAddress( + ctx, sqlc.InsertNodeAddressParams{ + NodeID: nodeID, + Type: int16(addrType), + Address: addr, + Position: int32(position), + }, + ) + if err != nil { + return fmt.Errorf("unable to insert "+ + "node(%d) address(%v): %w", nodeID, + addr, err) + } + } + } + + return nil +} + +// getNodeAddresses fetches the addresses for a node with the given public key. +func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool, + []net.Addr, error) { + + // GetNodeAddressesByPubKey ensures that the addresses for a given type + // are returned in the same order as they were inserted. + rows, err := db.GetNodeAddressesByPubKey( + ctx, sqlc.GetNodeAddressesByPubKeyParams{ + Version: int16(ProtocolV1), + PubKey: nodePub, + }, + ) + if err != nil { + return false, nil, err + } + + // GetNodeAddressesByPubKey uses a left join so there should always be + // at least one row returned if the node exists even if it has no + // addresses. + if len(rows) == 0 { + return false, nil, nil + } + + addresses := make([]net.Addr, 0, len(rows)) + for _, addr := range rows { + if !(addr.Type.Valid && addr.Address.Valid) { + continue + } + + address := addr.Address.String + + switch dbAddressType(addr.Type.Int16) { + case addressTypeIPv4: + tcp, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + return false, nil, nil + } + tcp.IP = tcp.IP.To4() + + addresses = append(addresses, tcp) + + case addressTypeIPv6: + tcp, err := net.ResolveTCPAddr("tcp6", address) + if err != nil { + return false, nil, nil + } + addresses = append(addresses, tcp) + + case addressTypeTorV3, addressTypeTorV2: + service, portStr, err := net.SplitHostPort(address) + if err != nil { + return false, nil, fmt.Errorf("unable to "+ + "split tor v3 address: %v", + addr.Address) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return false, nil, err + } + + addresses = append(addresses, &tor.OnionAddr{ + OnionService: service, + Port: port, + }) + + case addressTypeOpaque: + opaque, err := hex.DecodeString(address) + if err != nil { + return false, nil, fmt.Errorf("unable to "+ + "decode opaque address: %v", addr) + } + + addresses = append(addresses, &lnwire.OpaqueAddrs{ + Payload: opaque, + }) + + default: + return false, nil, fmt.Errorf("unknown address "+ + "type: %v", addr.Type) + } + } + + return true, addresses, nil +} + +// upsertNodeExtraSignedFields updates the node's extra signed fields in the +// database. This includes updating any existing types, inserting any new types, +// and deleting any types that are no longer present. +func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries, + nodeID int64, extraFields map[uint64][]byte) error { + + // Get any existing extra signed fields for the node. + existingFields, err := db.GetExtraNodeTypes(ctx, nodeID) + if err != nil { + return err + } + + // Make a lookup map of the existing field types so that we can use it + // to keep track of any fields we should delete. + m := make(map[uint64]bool) + for _, field := range existingFields { + m[uint64(field.Type)] = true + } + + // For all the new fields, we'll upsert them and remove them from the + // map of existing fields. + for tlvType, value := range extraFields { + err = db.UpsertNodeExtraType( + ctx, sqlc.UpsertNodeExtraTypeParams{ + NodeID: nodeID, + Type: int64(tlvType), + Value: value, + }, + ) + if err != nil { + return fmt.Errorf("unable to upsert node(%d) extra "+ + "signed field(%v): %w", nodeID, tlvType, err) + } + + // Remove the field from the map of existing fields if it was + // present. + delete(m, tlvType) + } + + // For all the fields that are left in the map of existing fields, we'll + // delete them as they are no longer present in the new set of fields. + for tlvType := range m { + err = db.DeleteExtraNodeType( + ctx, sqlc.DeleteExtraNodeTypeParams{ + NodeID: nodeID, + Type: int64(tlvType), + }, + ) + if err != nil { + return fmt.Errorf("unable to delete node(%d) extra "+ + "signed field(%v): %w", nodeID, tlvType, err) + } + } + + return nil +} + +// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream. +// This then produces a map from TLV type to value. If the input is not a +// valid TLV stream, then an error is returned. +func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) { + r := bytes.NewReader(data) + + tlvStream, err := tlv.NewStream() + if err != nil { + return nil, err + } + + // Since ExtraOpaqueData is provided by a potentially malicious peer, + // pass it into the P2P decoding variant. + parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r) + if err != nil { + return nil, err + } + if len(parsedTypes) == 0 { + return nil, nil + } + + records := make(map[uint64][]byte) + for k, v := range parsedTypes { + records[uint64(k)] = v + } + + return records, nil +} diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go new file mode 100644 index 000000000..1fadc7e3f --- /dev/null +++ b/sqldb/sqlc/graph.sql.go @@ -0,0 +1,359 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.29.0 +// source: graph.sql + +package sqlc + +import ( + "context" + "database/sql" +) + +const deleteExtraNodeType = `-- name: DeleteExtraNodeType :exec +DELETE FROM node_extra_types +WHERE node_id = $1 + AND type = $2 +` + +type DeleteExtraNodeTypeParams struct { + NodeID int64 + Type int64 +} + +func (q *Queries) DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error { + _, err := q.db.ExecContext(ctx, deleteExtraNodeType, arg.NodeID, arg.Type) + return err +} + +const deleteNodeAddresses = `-- name: DeleteNodeAddresses :exec +DELETE FROM node_addresses +WHERE node_id = $1 +` + +func (q *Queries) DeleteNodeAddresses(ctx context.Context, nodeID int64) error { + _, err := q.db.ExecContext(ctx, deleteNodeAddresses, nodeID) + return err +} + +const deleteNodeByPubKey = `-- name: DeleteNodeByPubKey :execresult +DELETE FROM nodes +WHERE pub_key = $1 + AND version = $2 +` + +type DeleteNodeByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) { + return q.db.ExecContext(ctx, deleteNodeByPubKey, arg.PubKey, arg.Version) +} + +const deleteNodeFeature = `-- name: DeleteNodeFeature :exec +DELETE FROM node_features +WHERE node_id = $1 + AND feature_bit = $2 +` + +type DeleteNodeFeatureParams struct { + NodeID int64 + FeatureBit int32 +} + +func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error { + _, err := q.db.ExecContext(ctx, deleteNodeFeature, arg.NodeID, arg.FeatureBit) + return err +} + +const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many +SELECT node_id, type, value +FROM node_extra_types +WHERE node_id = $1 +` + +func (q *Queries) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) { + rows, err := q.db.QueryContext(ctx, getExtraNodeTypes, nodeID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []NodeExtraType + for rows.Next() { + var i NodeExtraType + if err := rows.Scan(&i.NodeID, &i.Type, &i.Value); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeAddressesByPubKey = `-- name: GetNodeAddressesByPubKey :many +SELECT a.type, a.address +FROM nodes n +LEFT JOIN node_addresses a ON a.node_id = n.id +WHERE n.pub_key = $1 AND n.version = $2 +ORDER BY a.type ASC, a.position ASC +` + +type GetNodeAddressesByPubKeyParams struct { + PubKey []byte + Version int16 +} + +type GetNodeAddressesByPubKeyRow struct { + Type sql.NullInt16 + Address sql.NullString +} + +func (q *Queries) GetNodeAddressesByPubKey(ctx context.Context, arg GetNodeAddressesByPubKeyParams) ([]GetNodeAddressesByPubKeyRow, error) { + rows, err := q.db.QueryContext(ctx, getNodeAddressesByPubKey, arg.PubKey, arg.Version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetNodeAddressesByPubKeyRow + for rows.Next() { + var i GetNodeAddressesByPubKeyRow + if err := rows.Scan(&i.Type, &i.Address); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeByPubKey = `-- name: GetNodeByPubKey :one +SELECT id, version, pub_key, alias, last_update, color, signature +FROM nodes +WHERE pub_key = $1 + AND version = $2 +` + +type GetNodeByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error) { + row := q.db.QueryRowContext(ctx, getNodeByPubKey, arg.PubKey, arg.Version) + var i Node + err := row.Scan( + &i.ID, + &i.Version, + &i.PubKey, + &i.Alias, + &i.LastUpdate, + &i.Color, + &i.Signature, + ) + return i, err +} + +const getNodeFeatures = `-- name: GetNodeFeatures :many +SELECT node_id, feature_bit +FROM node_features +WHERE node_id = $1 +` + +func (q *Queries) GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) { + rows, err := q.db.QueryContext(ctx, getNodeFeatures, nodeID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []NodeFeature + for rows.Next() { + var i NodeFeature + if err := rows.Scan(&i.NodeID, &i.FeatureBit); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getNodeFeaturesByPubKey = `-- name: GetNodeFeaturesByPubKey :many +SELECT f.feature_bit +FROM nodes n + JOIN node_features f ON f.node_id = n.id +WHERE n.pub_key = $1 + AND n.version = $2 +` + +type GetNodeFeaturesByPubKeyParams struct { + PubKey []byte + Version int16 +} + +func (q *Queries) GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) { + rows, err := q.db.QueryContext(ctx, getNodeFeaturesByPubKey, arg.PubKey, arg.Version) + if err != nil { + return nil, err + } + defer rows.Close() + var items []int32 + for rows.Next() { + var feature_bit int32 + if err := rows.Scan(&feature_bit); err != nil { + return nil, err + } + items = append(items, feature_bit) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const insertNodeAddress = `-- name: InsertNodeAddress :exec +/* ───────────────────────────────────────────── + node_addresses table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_addresses ( + node_id, + type, + address, + position +) VALUES ( + $1, $2, $3, $4 + ) +` + +type InsertNodeAddressParams struct { + NodeID int64 + Type int16 + Address string + Position int32 +} + +func (q *Queries) InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error { + _, err := q.db.ExecContext(ctx, insertNodeAddress, + arg.NodeID, + arg.Type, + arg.Address, + arg.Position, + ) + return err +} + +const insertNodeFeature = `-- name: InsertNodeFeature :exec +/* ───────────────────────────────────────────── + node_features table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_features ( + node_id, feature_bit +) VALUES ( + $1, $2 +) +` + +type InsertNodeFeatureParams struct { + NodeID int64 + FeatureBit int32 +} + +func (q *Queries) InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error { + _, err := q.db.ExecContext(ctx, insertNodeFeature, arg.NodeID, arg.FeatureBit) + return err +} + +const upsertNode = `-- name: UpsertNode :one +/* ───────────────────────────────────────────── + nodes table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO nodes ( + version, pub_key, alias, last_update, color, signature +) VALUES ( + $1, $2, $3, $4, $5, $6 +) +ON CONFLICT (pub_key, version) + -- Update the following fields if a conflict occurs on pub_key + -- and version. + DO UPDATE SET + alias = EXCLUDED.alias, + last_update = EXCLUDED.last_update, + color = EXCLUDED.color, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > nodes.last_update +RETURNING id +` + +type UpsertNodeParams struct { + Version int16 + PubKey []byte + Alias sql.NullString + LastUpdate sql.NullInt64 + Color sql.NullString + Signature []byte +} + +func (q *Queries) UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) { + row := q.db.QueryRowContext(ctx, upsertNode, + arg.Version, + arg.PubKey, + arg.Alias, + arg.LastUpdate, + arg.Color, + arg.Signature, + ) + var id int64 + err := row.Scan(&id) + return id, err +} + +const upsertNodeExtraType = `-- name: UpsertNodeExtraType :exec +/* ───────────────────────────────────────────── + node_extra_types table queries + ───────────────────────────────────────────── +*/ + +INSERT INTO node_extra_types ( + node_id, type, value +) +VALUES ($1, $2, $3) +ON CONFLICT (type, node_id) + -- Update the value if a conflict occurs on type + -- and node_id. + DO UPDATE SET value = EXCLUDED.value +` + +type UpsertNodeExtraTypeParams struct { + NodeID int64 + Type int64 + Value []byte +} + +func (q *Queries) UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error { + _, err := q.db.ExecContext(ctx, upsertNodeExtraType, arg.NodeID, arg.Type, arg.Value) + return err +} diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index 5da90a6a3..557bb1dfe 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -13,13 +13,18 @@ import ( type Querier interface { ClearKVInvoiceHashIndex(ctx context.Context) error DeleteCanceledInvoices(ctx context.Context) (sql.Result, error) + DeleteExtraNodeType(ctx context.Context, arg DeleteExtraNodeTypeParams) error DeleteInvoice(ctx context.Context, arg DeleteInvoiceParams) (sql.Result, error) + DeleteNodeAddresses(ctx context.Context, nodeID int64) error + DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) + DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) FetchSettledAMPSubInvoices(ctx context.Context, arg FetchSettledAMPSubInvoicesParams) ([]FetchSettledAMPSubInvoicesRow, error) FilterInvoices(ctx context.Context, arg FilterInvoicesParams) ([]Invoice, error) GetAMPInvoiceID(ctx context.Context, setID []byte) (int64, error) GetDatabaseVersion(ctx context.Context) (int32, error) + GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields // from different invoices. It is the caller's responsibility to ensure that // we bubble up an error in those cases. @@ -31,6 +36,10 @@ type Querier interface { GetInvoiceHTLCs(ctx context.Context, invoiceID int64) ([]InvoiceHtlc, error) GetKVInvoicePaymentHashByAddIndex(ctx context.Context, addIndex int64) ([]byte, error) GetMigration(ctx context.Context, version int32) (time.Time, error) + GetNodeAddressesByPubKey(ctx context.Context, arg GetNodeAddressesByPubKeyParams) ([]GetNodeAddressesByPubKeyRow, error) + GetNodeByPubKey(ctx context.Context, arg GetNodeByPubKeyParams) (Node, error) + GetNodeFeatures(ctx context.Context, nodeID int64) ([]NodeFeature, error) + GetNodeFeaturesByPubKey(ctx context.Context, arg GetNodeFeaturesByPubKeyParams) ([]int32, error) InsertAMPSubInvoice(ctx context.Context, arg InsertAMPSubInvoiceParams) error InsertAMPSubInvoiceHTLC(ctx context.Context, arg InsertAMPSubInvoiceHTLCParams) error InsertInvoice(ctx context.Context, arg InsertInvoiceParams) (int64, error) @@ -39,6 +48,8 @@ type Querier interface { InsertInvoiceHTLCCustomRecord(ctx context.Context, arg InsertInvoiceHTLCCustomRecordParams) error InsertKVInvoiceKeyAndAddIndex(ctx context.Context, arg InsertKVInvoiceKeyAndAddIndexParams) error InsertMigratedInvoice(ctx context.Context, arg InsertMigratedInvoiceParams) (int64, error) + InsertNodeAddress(ctx context.Context, arg InsertNodeAddressParams) error + InsertNodeFeature(ctx context.Context, arg InsertNodeFeatureParams) error NextInvoiceSettleIndex(ctx context.Context) (int64, error) OnAMPSubInvoiceCanceled(ctx context.Context, arg OnAMPSubInvoiceCanceledParams) error OnAMPSubInvoiceCreated(ctx context.Context, arg OnAMPSubInvoiceCreatedParams) error @@ -55,6 +66,8 @@ type Querier interface { UpdateInvoiceHTLCs(ctx context.Context, arg UpdateInvoiceHTLCsParams) error UpdateInvoiceState(ctx context.Context, arg UpdateInvoiceStateParams) (sql.Result, error) UpsertAMPSubInvoice(ctx context.Context, arg UpsertAMPSubInvoiceParams) (sql.Result, error) + UpsertNode(ctx context.Context, arg UpsertNodeParams) (int64, error) + UpsertNodeExtraType(ctx context.Context, arg UpsertNodeExtraTypeParams) error } var _ Querier = (*Queries)(nil) diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index e69de29bb..390870086 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -0,0 +1,112 @@ +/* ───────────────────────────────────────────── + nodes table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertNode :one +INSERT INTO nodes ( + version, pub_key, alias, last_update, color, signature +) VALUES ( + $1, $2, $3, $4, $5, $6 +) +ON CONFLICT (pub_key, version) + -- Update the following fields if a conflict occurs on pub_key + -- and version. + DO UPDATE SET + alias = EXCLUDED.alias, + last_update = EXCLUDED.last_update, + color = EXCLUDED.color, + signature = EXCLUDED.signature +WHERE EXCLUDED.last_update > nodes.last_update +RETURNING id; + +-- name: GetNodeByPubKey :one +SELECT * +FROM nodes +WHERE pub_key = $1 + AND version = $2; + +-- name: DeleteNodeByPubKey :execresult +DELETE FROM nodes +WHERE pub_key = $1 + AND version = $2; + +/* ───────────────────────────────────────────── + node_features table queries + ───────────────────────────────────────────── +*/ + +-- name: InsertNodeFeature :exec +INSERT INTO node_features ( + node_id, feature_bit +) VALUES ( + $1, $2 +); + +-- name: GetNodeFeatures :many +SELECT * +FROM node_features +WHERE node_id = $1; + +-- name: GetNodeFeaturesByPubKey :many +SELECT f.feature_bit +FROM nodes n + JOIN node_features f ON f.node_id = n.id +WHERE n.pub_key = $1 + AND n.version = $2; + +-- name: DeleteNodeFeature :exec +DELETE FROM node_features +WHERE node_id = $1 + AND feature_bit = $2; + +/* ───────────────────────────────────────────── + node_addresses table queries + ───────────────────────────────────────────── +*/ + +-- name: InsertNodeAddress :exec +INSERT INTO node_addresses ( + node_id, + type, + address, + position +) VALUES ( + $1, $2, $3, $4 + ); + +-- name: GetNodeAddressesByPubKey :many +SELECT a.type, a.address +FROM nodes n +LEFT JOIN node_addresses a ON a.node_id = n.id +WHERE n.pub_key = $1 AND n.version = $2 +ORDER BY a.type ASC, a.position ASC; + +-- name: DeleteNodeAddresses :exec +DELETE FROM node_addresses +WHERE node_id = $1; + +/* ───────────────────────────────────────────── + node_extra_types table queries + ───────────────────────────────────────────── +*/ + +-- name: UpsertNodeExtraType :exec +INSERT INTO node_extra_types ( + node_id, type, value +) +VALUES ($1, $2, $3) +ON CONFLICT (type, node_id) + -- Update the value if a conflict occurs on type + -- and node_id. + DO UPDATE SET value = EXCLUDED.value; + +-- name: GetExtraNodeTypes :many +SELECT * +FROM node_extra_types +WHERE node_id = $1; + +-- name: DeleteExtraNodeType :exec +DELETE FROM node_extra_types +WHERE node_id = $1 + AND type = $2;