Files
lnd/graph/db/sql_store.go
2025-06-02 14:53:18 +02:00

1260 lines
35 KiB
Go

package graphdb
import (
"bytes"
"context"
"database/sql"
"encoding/hex"
"errors"
"fmt"
"math"
"net"
"strconv"
"sync"
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/lightningnetwork/lnd/sqldb/sqlc"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/tor"
)
// ProtocolVersion is an enum that defines the gossip protocol version of a
// message.
type ProtocolVersion uint8
const (
// ProtocolV1 is the gossip protocol version defined in BOLT #7.
ProtocolV1 ProtocolVersion = 1
)
// String returns a string representation of the protocol version.
func (v ProtocolVersion) String() string {
return fmt.Sprintf("V%d", v)
}
// SQLQueries is a subset of the sqlc.Querier interface that can be used to
// execute queries against the SQL graph tables.
//
//nolint:ll,interfacebloat
type SQLQueries interface {
/*
Node queries.
*/
UpsertNode(ctx context.Context, arg sqlc.UpsertNodeParams) (int64, error)
GetNodeByPubKey(ctx context.Context, arg sqlc.GetNodeByPubKeyParams) (sqlc.Node, error)
GetNodesByLastUpdateRange(ctx context.Context, arg sqlc.GetNodesByLastUpdateRangeParams) ([]sqlc.Node, error)
DeleteNodeByPubKey(ctx context.Context, arg sqlc.DeleteNodeByPubKeyParams) (sql.Result, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]sqlc.NodeExtraType, error)
UpsertNodeExtraType(ctx context.Context, arg sqlc.UpsertNodeExtraTypeParams) error
DeleteExtraNodeType(ctx context.Context, arg sqlc.DeleteExtraNodeTypeParams) error
InsertNodeAddress(ctx context.Context, arg sqlc.InsertNodeAddressParams) error
GetNodeAddressesByPubKey(ctx context.Context, arg sqlc.GetNodeAddressesByPubKeyParams) ([]sqlc.GetNodeAddressesByPubKeyRow, error)
DeleteNodeAddresses(ctx context.Context, nodeID int64) error
InsertNodeFeature(ctx context.Context, arg sqlc.InsertNodeFeatureParams) error
GetNodeFeatures(ctx context.Context, nodeID int64) ([]sqlc.NodeFeature, error)
GetNodeFeaturesByPubKey(ctx context.Context, arg sqlc.GetNodeFeaturesByPubKeyParams) ([]int32, error)
DeleteNodeFeature(ctx context.Context, arg sqlc.DeleteNodeFeatureParams) error
/*
Source node queries.
*/
AddSourceNode(ctx context.Context, nodeID int64) error
GetSourceNodesByVersion(ctx context.Context, version int16) ([]sqlc.GetSourceNodesByVersionRow, error)
/*
Channel queries.
*/
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
HighestSCID(ctx context.Context, version int16) ([]byte, error)
CreateChannelExtraType(ctx context.Context, arg sqlc.CreateChannelExtraTypeParams) error
InsertChannelFeature(ctx context.Context, arg sqlc.InsertChannelFeatureParams) error
}
// BatchedSQLQueries is a version of SQLQueries that's capable of batched
// database operations.
type BatchedSQLQueries interface {
SQLQueries
sqldb.BatchedTx[SQLQueries]
}
// SQLStore is an implementation of the V1Store interface that uses a SQL
// database as the backend.
//
// NOTE: currently, this temporarily embeds the KVStore struct so that we can
// implement the V1Store interface incrementally. For any method not
// implemented, things will fall back to the KVStore. This is ONLY the case
// for the time being while this struct is purely used in unit tests only.
type SQLStore struct {
db BatchedSQLQueries
// cacheMu guards all caches (rejectCache and chanCache). If
// this mutex will be acquired at the same time as the DB mutex then
// the cacheMu MUST be acquired first to prevent deadlock.
cacheMu sync.RWMutex
rejectCache *rejectCache
chanCache *channelCache
chanScheduler batch.Scheduler[SQLQueries]
nodeScheduler batch.Scheduler[SQLQueries]
// Temporary fall-back to the KVStore so that we can implement the
// interface incrementally.
*KVStore
}
// A compile-time assertion to ensure that SQLStore implements the V1Store
// interface.
var _ V1Store = (*SQLStore)(nil)
// NewSQLStore creates a new SQLStore instance given an open BatchedSQLQueries
// storage backend.
func NewSQLStore(db BatchedSQLQueries, kvStore *KVStore,
options ...StoreOptionModifier) (*SQLStore, error) {
opts := DefaultOptions()
for _, o := range options {
o(opts)
}
if opts.NoMigration {
return nil, fmt.Errorf("the NoMigration option is not yet " +
"supported for SQL stores")
}
s := &SQLStore{
db: db,
KVStore: kvStore,
rejectCache: newRejectCache(opts.RejectCacheSize),
chanCache: newChannelCache(opts.ChannelCacheSize),
}
s.chanScheduler = batch.NewTimeScheduler(
db, &s.cacheMu, opts.BatchCommitInterval,
)
s.nodeScheduler = batch.NewTimeScheduler(
db, nil, opts.BatchCommitInterval,
)
return s, nil
}
// AddLightningNode adds a vertex/node to the graph database. If the node is not
// in the database from before, this will add a new, unconnected one to the
// graph. If it is present from before, this will update that node's
// information.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddLightningNode(node *models.LightningNode,
opts ...batch.SchedulerOption) error {
ctx := context.TODO()
r := &batch.Request[SQLQueries]{
Opts: batch.NewSchedulerOptions(opts...),
Do: func(queries SQLQueries) error {
_, err := upsertNode(ctx, queries, node)
return err
},
}
return s.nodeScheduler.Execute(ctx, r)
}
// FetchLightningNode attempts to look up a target node by its identity public
// key. If the node isn't found in the database, then ErrGraphNodeNotFound is
// returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) FetchLightningNode(pubKey route.Vertex) (
*models.LightningNode, error) {
ctx := context.TODO()
var node *models.LightningNode
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var err error
_, node, err = getNodeByPubKey(ctx, db, pubKey)
return err
}, sqldb.NoOpReset)
if err != nil {
return nil, fmt.Errorf("unable to fetch node: %w", err)
}
return node, nil
}
// HasLightningNode determines if the graph has a vertex identified by the
// target node identity public key. If the node exists in the database, a
// timestamp of when the data for the node was lasted updated is returned along
// with a true boolean. Otherwise, an empty time.Time is returned with a false
// boolean.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) HasLightningNode(pubKey [33]byte) (time.Time, bool,
error) {
ctx := context.TODO()
var (
exists bool
lastUpdate time.Time
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbNode, err := db.GetNodeByPubKey(
ctx, sqlc.GetNodeByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: pubKey[:],
},
)
if errors.Is(err, sql.ErrNoRows) {
return nil
} else if err != nil {
return fmt.Errorf("unable to fetch node: %w", err)
}
exists = true
if dbNode.LastUpdate.Valid {
lastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return time.Time{}, false,
fmt.Errorf("unable to fetch node: %w", err)
}
return lastUpdate, exists, nil
}
// AddrsForNode returns all known addresses for the target node public key
// that the graph DB is aware of. The returned boolean indicates if the
// given node is unknown to the graph DB or not.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddrsForNode(nodePub *btcec.PublicKey) (bool, []net.Addr,
error) {
ctx := context.TODO()
var (
addresses []net.Addr
known bool
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
var err error
known, addresses, err = getNodeAddresses(
ctx, db, nodePub.SerializeCompressed(),
)
if err != nil {
return fmt.Errorf("unable to fetch node addresses: %w",
err)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return false, nil, fmt.Errorf("unable to get addresses for "+
"node(%x): %w", nodePub.SerializeCompressed(), err)
}
return known, addresses, nil
}
// DeleteLightningNode starts a new database transaction to remove a vertex/node
// from the database according to the node's public key.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DeleteLightningNode(pubKey route.Vertex) error {
ctx := context.TODO()
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
res, err := db.DeleteNodeByPubKey(
ctx, sqlc.DeleteNodeByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: pubKey[:],
},
)
if err != nil {
return err
}
rows, err := res.RowsAffected()
if err != nil {
return err
}
if rows == 0 {
return ErrGraphNodeNotFound
} else if rows > 1 {
return fmt.Errorf("deleted %d rows, expected 1", rows)
}
return err
}, sqldb.NoOpReset)
if err != nil {
return fmt.Errorf("unable to delete node: %w", err)
}
return nil
}
// FetchNodeFeatures returns the features of the given node. If no features are
// known for the node, an empty feature vector is returned.
//
// NOTE: this is part of the graphdb.NodeTraverser interface.
func (s *SQLStore) FetchNodeFeatures(nodePub route.Vertex) (
*lnwire.FeatureVector, error) {
ctx := context.TODO()
return fetchNodeFeatures(ctx, s.db, nodePub)
}
// LookupAlias attempts to return the alias as advertised by the target node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) LookupAlias(pub *btcec.PublicKey) (string, error) {
var (
ctx = context.TODO()
alias string
)
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbNode, err := db.GetNodeByPubKey(
ctx, sqlc.GetNodeByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: pub.SerializeCompressed(),
},
)
if errors.Is(err, sql.ErrNoRows) {
return ErrNodeAliasNotFound
} else if err != nil {
return fmt.Errorf("unable to fetch node: %w", err)
}
if !dbNode.Alias.Valid {
return ErrNodeAliasNotFound
}
alias = dbNode.Alias.String
return nil
}, sqldb.NoOpReset)
if err != nil {
return "", fmt.Errorf("unable to look up alias: %w", err)
}
return alias, nil
}
// SourceNode returns the source node of the graph. The source node is treated
// as the center node within a star-graph. This method may be used to kick off
// a path finding algorithm in order to explore the reachability of another
// node based off the source node.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SourceNode() (*models.LightningNode, error) {
ctx := context.TODO()
var node *models.LightningNode
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
_, nodePub, err := getSourceNode(ctx, db, ProtocolV1)
if err != nil {
return fmt.Errorf("unable to fetch V1 source node: %w",
err)
}
_, node, err = getNodeByPubKey(ctx, db, nodePub)
return err
}, sqldb.NoOpReset)
if err != nil {
return nil, fmt.Errorf("unable to fetch source node: %w", err)
}
return node, nil
}
// SetSourceNode sets the source node within the graph database. The source
// node is to be used as the center of a star-graph within path finding
// algorithms.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) SetSourceNode(node *models.LightningNode) error {
ctx := context.TODO()
return s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
id, err := upsertNode(ctx, db, node)
if err != nil {
return fmt.Errorf("unable to upsert source node: %w",
err)
}
// Make sure that if a source node for this version is already
// set, then the ID is the same as the one we are about to set.
dbSourceNodeID, _, err := getSourceNode(ctx, db, ProtocolV1)
if err != nil && !errors.Is(err, ErrSourceNodeNotSet) {
return fmt.Errorf("unable to fetch source node: %w",
err)
} else if err == nil {
if dbSourceNodeID != id {
return fmt.Errorf("v1 source node already "+
"set to a different node: %d vs %d",
dbSourceNodeID, id)
}
return nil
}
return db.AddSourceNode(ctx, id)
}, sqldb.NoOpReset)
}
// NodeUpdatesInHorizon returns all the known lightning node which have an
// update timestamp within the passed range. This method can be used by two
// nodes to quickly determine if they have the same set of up to date node
// announcements.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) NodeUpdatesInHorizon(startTime,
endTime time.Time) ([]models.LightningNode, error) {
ctx := context.TODO()
var nodes []models.LightningNode
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
dbNodes, err := db.GetNodesByLastUpdateRange(
ctx, sqlc.GetNodesByLastUpdateRangeParams{
StartTime: sqldb.SQLInt64(startTime.Unix()),
EndTime: sqldb.SQLInt64(endTime.Unix()),
},
)
if err != nil {
return fmt.Errorf("unable to fetch nodes: %w", err)
}
for _, dbNode := range dbNodes {
node, err := buildNode(ctx, db, &dbNode)
if err != nil {
return fmt.Errorf("unable to build node: %w",
err)
}
nodes = append(nodes, *node)
}
return nil
}, sqldb.NoOpReset)
if err != nil {
return nil, fmt.Errorf("unable to fetch nodes: %w", err)
}
return nodes, nil
}
// AddChannelEdge adds a new (undirected, blank) edge to the graph database. An
// undirected edge from the two target nodes are created. The information stored
// denotes the static attributes of the channel, such as the channelID, the keys
// involved in creation of the channel, and the set of features that the channel
// supports. The chanPoint and chanID are used to uniquely identify the edge
// globally within the database.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) AddChannelEdge(edge *models.ChannelEdgeInfo,
opts ...batch.SchedulerOption) error {
ctx := context.TODO()
var alreadyExists bool
r := &batch.Request[SQLQueries]{
Opts: batch.NewSchedulerOptions(opts...),
Reset: func() {
alreadyExists = false
},
Do: func(tx SQLQueries) error {
err := insertChannel(ctx, tx, edge)
// Silence ErrEdgeAlreadyExist so that the batch can
// succeed, but propagate the error via local state.
if errors.Is(err, ErrEdgeAlreadyExist) {
alreadyExists = true
return nil
}
return err
},
OnCommit: func(err error) error {
switch {
case err != nil:
return err
case alreadyExists:
return ErrEdgeAlreadyExist
default:
s.rejectCache.remove(edge.ChannelID)
s.chanCache.remove(edge.ChannelID)
return nil
}
},
}
return s.chanScheduler.Execute(ctx, r)
}
// HighestChanID returns the "highest" known channel ID in the channel graph.
// This represents the "newest" channel from the PoV of the chain. This method
// can be used by peers to quickly determine if their graphs are in sync.
//
// NOTE: This is part of the V1Store interface.
func (s *SQLStore) HighestChanID() (uint64, error) {
ctx := context.TODO()
var highestChanID uint64
err := s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
chanID, err := db.HighestSCID(ctx, int16(ProtocolV1))
if errors.Is(err, sql.ErrNoRows) {
return nil
} else if err != nil {
return fmt.Errorf("unable to fetch highest chan ID: %w",
err)
}
highestChanID = byteOrder.Uint64(chanID)
return nil
}, sqldb.NoOpReset)
if err != nil {
return 0, fmt.Errorf("unable to fetch highest chan ID: %w", err)
}
return highestChanID, nil
}
// getNodeByPubKey attempts to look up a target node by its public key.
func getNodeByPubKey(ctx context.Context, db SQLQueries,
pubKey route.Vertex) (int64, *models.LightningNode, error) {
dbNode, err := db.GetNodeByPubKey(
ctx, sqlc.GetNodeByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: pubKey[:],
},
)
if errors.Is(err, sql.ErrNoRows) {
return 0, nil, ErrGraphNodeNotFound
} else if err != nil {
return 0, nil, fmt.Errorf("unable to fetch node: %w", err)
}
node, err := buildNode(ctx, db, &dbNode)
if err != nil {
return 0, nil, fmt.Errorf("unable to build node: %w", err)
}
return dbNode.ID, node, nil
}
// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.Node) (
*models.LightningNode, error) {
if dbNode.Version != int16(ProtocolV1) {
return nil, fmt.Errorf("unsupported node version: %d",
dbNode.Version)
}
var pub [33]byte
copy(pub[:], dbNode.PubKey)
node := &models.LightningNode{
PubKeyBytes: pub,
Features: lnwire.EmptyFeatureVector(),
LastUpdate: time.Unix(0, 0),
}
if len(dbNode.Signature) == 0 {
return node, nil
}
node.HaveNodeAnnouncement = true
node.AuthSigBytes = dbNode.Signature
node.Alias = dbNode.Alias.String
node.LastUpdate = time.Unix(dbNode.LastUpdate.Int64, 0)
var err error
node.Color, err = DecodeHexColor(dbNode.Color.String)
if err != nil {
return nil, fmt.Errorf("unable to decode color: %w", err)
}
// Fetch the node's features.
node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"features: %w", dbNode.ID, err)
}
// Fetch the node's addresses.
_, node.Addresses, err = getNodeAddresses(ctx, db, pub[:])
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"addresses: %w", dbNode.ID, err)
}
// Fetch the node's extra signed fields.
extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"extra signed fields: %w", dbNode.ID, err)
}
recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
if err != nil {
return nil, fmt.Errorf("unable to serialize extra signed "+
"fields: %w", err)
}
if len(recs) != 0 {
node.ExtraOpaqueData = recs
}
return node, nil
}
// getNodeFeatures fetches the feature bits and constructs the feature vector
// for a node with the given DB ID.
func getNodeFeatures(ctx context.Context, db SQLQueries,
nodeID int64) (*lnwire.FeatureVector, error) {
rows, err := db.GetNodeFeatures(ctx, nodeID)
if err != nil {
return nil, fmt.Errorf("unable to get node(%d) features: %w",
nodeID, err)
}
features := lnwire.EmptyFeatureVector()
for _, feature := range rows {
features.Set(lnwire.FeatureBit(feature.FeatureBit))
}
return features, nil
}
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
// given DB ID.
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
nodeID int64) (map[uint64][]byte, error) {
fields, err := db.GetExtraNodeTypes(ctx, nodeID)
if err != nil {
return nil, fmt.Errorf("unable to get node(%d) extra "+
"signed fields: %w", nodeID, err)
}
extraFields := make(map[uint64][]byte)
for _, field := range fields {
extraFields[uint64(field.Type)] = field.Value
}
return extraFields, nil
}
// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
// types are also updated. The node's DB ID is returned.
func upsertNode(ctx context.Context, db SQLQueries,
node *models.LightningNode) (int64, error) {
params := sqlc.UpsertNodeParams{
Version: int16(ProtocolV1),
PubKey: node.PubKeyBytes[:],
}
if node.HaveNodeAnnouncement {
params.LastUpdate = sqldb.SQLInt64(node.LastUpdate.Unix())
params.Color = sqldb.SQLStr(EncodeHexColor(node.Color))
params.Alias = sqldb.SQLStr(node.Alias)
params.Signature = node.AuthSigBytes
}
nodeID, err := db.UpsertNode(ctx, params)
if err != nil {
return 0, fmt.Errorf("upserting node(%x): %w", node.PubKeyBytes,
err)
}
// We can exit here if we don't have the announcement yet.
if !node.HaveNodeAnnouncement {
return nodeID, nil
}
// Update the node's features.
err = upsertNodeFeatures(ctx, db, nodeID, node.Features)
if err != nil {
return 0, fmt.Errorf("inserting node features: %w", err)
}
// Update the node's addresses.
err = upsertNodeAddresses(ctx, db, nodeID, node.Addresses)
if err != nil {
return 0, fmt.Errorf("inserting node addresses: %w", err)
}
// Convert the flat extra opaque data into a map of TLV types to
// values.
extra, err := marshalExtraOpaqueData(node.ExtraOpaqueData)
if err != nil {
return 0, fmt.Errorf("unable to marshal extra opaque data: %w",
err)
}
// Update the node's extra signed fields.
err = upsertNodeExtraSignedFields(ctx, db, nodeID, extra)
if err != nil {
return 0, fmt.Errorf("inserting node extra TLVs: %w", err)
}
return nodeID, nil
}
// upsertNodeFeatures updates the node's features node_features table. This
// includes deleting any feature bits no longer present and inserting any new
// feature bits. If the feature bit does not yet exist in the features table,
// then an entry is created in that table first.
func upsertNodeFeatures(ctx context.Context, db SQLQueries, nodeID int64,
features *lnwire.FeatureVector) error {
// Get any existing features for the node.
existingFeatures, err := db.GetNodeFeatures(ctx, nodeID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return err
}
// Copy the nodes latest set of feature bits.
newFeatures := make(map[int32]struct{})
if features != nil {
for feature := range features.Features() {
newFeatures[int32(feature)] = struct{}{}
}
}
// For any current feature that already exists in the DB, remove it from
// the in-memory map. For any existing feature that does not exist in
// the in-memory map, delete it from the database.
for _, feature := range existingFeatures {
// The feature is still present, so there are no updates to be
// made.
if _, ok := newFeatures[feature.FeatureBit]; ok {
delete(newFeatures, feature.FeatureBit)
continue
}
// The feature is no longer present, so we remove it from the
// database.
err := db.DeleteNodeFeature(ctx, sqlc.DeleteNodeFeatureParams{
NodeID: nodeID,
FeatureBit: feature.FeatureBit,
})
if err != nil {
return fmt.Errorf("unable to delete node(%d) "+
"feature(%v): %w", nodeID, feature.FeatureBit,
err)
}
}
// Any remaining entries in newFeatures are new features that need to be
// added to the database for the first time.
for feature := range newFeatures {
err = db.InsertNodeFeature(ctx, sqlc.InsertNodeFeatureParams{
NodeID: nodeID,
FeatureBit: feature,
})
if err != nil {
return fmt.Errorf("unable to insert node(%d) "+
"feature(%v): %w", nodeID, feature, err)
}
}
return nil
}
// fetchNodeFeatures fetches the features for a node with the given public key.
func fetchNodeFeatures(ctx context.Context, queries SQLQueries,
nodePub route.Vertex) (*lnwire.FeatureVector, error) {
rows, err := queries.GetNodeFeaturesByPubKey(
ctx, sqlc.GetNodeFeaturesByPubKeyParams{
PubKey: nodePub[:],
Version: int16(ProtocolV1),
},
)
if err != nil {
return nil, fmt.Errorf("unable to get node(%s) features: %w",
nodePub, err)
}
features := lnwire.EmptyFeatureVector()
for _, bit := range rows {
features.Set(lnwire.FeatureBit(bit))
}
return features, nil
}
// dbAddressType is an enum type that represents the different address types
// that we store in the node_addresses table. The address type determines how
// the address is to be serialised/deserialize.
type dbAddressType uint8
const (
addressTypeIPv4 dbAddressType = 1
addressTypeIPv6 dbAddressType = 2
addressTypeTorV2 dbAddressType = 3
addressTypeTorV3 dbAddressType = 4
addressTypeOpaque dbAddressType = math.MaxInt8
)
// upsertNodeAddresses updates the node's addresses in the database. This
// includes deleting any existing addresses and inserting the new set of
// addresses. The deletion is necessary since the ordering of the addresses may
// change, and we need to ensure that the database reflects the latest set of
// addresses so that at the time of reconstructing the node announcement, the
// order is preserved and the signature over the message remains valid.
func upsertNodeAddresses(ctx context.Context, db SQLQueries, nodeID int64,
addresses []net.Addr) error {
// Delete any existing addresses for the node. This is required since
// even if the new set of addresses is the same, the ordering may have
// changed for a given address type.
err := db.DeleteNodeAddresses(ctx, nodeID)
if err != nil {
return fmt.Errorf("unable to delete node(%d) addresses: %w",
nodeID, err)
}
// Copy the nodes latest set of addresses.
newAddresses := map[dbAddressType][]string{
addressTypeIPv4: {},
addressTypeIPv6: {},
addressTypeTorV2: {},
addressTypeTorV3: {},
addressTypeOpaque: {},
}
addAddr := func(t dbAddressType, addr net.Addr) {
newAddresses[t] = append(newAddresses[t], addr.String())
}
for _, address := range addresses {
switch addr := address.(type) {
case *net.TCPAddr:
if ip4 := addr.IP.To4(); ip4 != nil {
addAddr(addressTypeIPv4, addr)
} else if ip6 := addr.IP.To16(); ip6 != nil {
addAddr(addressTypeIPv6, addr)
} else {
return fmt.Errorf("unhandled IP address: %v",
addr)
}
case *tor.OnionAddr:
switch len(addr.OnionService) {
case tor.V2Len:
addAddr(addressTypeTorV2, addr)
case tor.V3Len:
addAddr(addressTypeTorV3, addr)
default:
return fmt.Errorf("invalid length for a tor " +
"address")
}
case *lnwire.OpaqueAddrs:
addAddr(addressTypeOpaque, addr)
default:
return fmt.Errorf("unhandled address type: %T", addr)
}
}
// Any remaining entries in newAddresses are new addresses that need to
// be added to the database for the first time.
for addrType, addrList := range newAddresses {
for position, addr := range addrList {
err := db.InsertNodeAddress(
ctx, sqlc.InsertNodeAddressParams{
NodeID: nodeID,
Type: int16(addrType),
Address: addr,
Position: int32(position),
},
)
if err != nil {
return fmt.Errorf("unable to insert "+
"node(%d) address(%v): %w", nodeID,
addr, err)
}
}
}
return nil
}
// getNodeAddresses fetches the addresses for a node with the given public key.
func getNodeAddresses(ctx context.Context, db SQLQueries, nodePub []byte) (bool,
[]net.Addr, error) {
// GetNodeAddressesByPubKey ensures that the addresses for a given type
// are returned in the same order as they were inserted.
rows, err := db.GetNodeAddressesByPubKey(
ctx, sqlc.GetNodeAddressesByPubKeyParams{
Version: int16(ProtocolV1),
PubKey: nodePub,
},
)
if err != nil {
return false, nil, err
}
// GetNodeAddressesByPubKey uses a left join so there should always be
// at least one row returned if the node exists even if it has no
// addresses.
if len(rows) == 0 {
return false, nil, nil
}
addresses := make([]net.Addr, 0, len(rows))
for _, addr := range rows {
if !(addr.Type.Valid && addr.Address.Valid) {
continue
}
address := addr.Address.String
switch dbAddressType(addr.Type.Int16) {
case addressTypeIPv4:
tcp, err := net.ResolveTCPAddr("tcp4", address)
if err != nil {
return false, nil, nil
}
tcp.IP = tcp.IP.To4()
addresses = append(addresses, tcp)
case addressTypeIPv6:
tcp, err := net.ResolveTCPAddr("tcp6", address)
if err != nil {
return false, nil, nil
}
addresses = append(addresses, tcp)
case addressTypeTorV3, addressTypeTorV2:
service, portStr, err := net.SplitHostPort(address)
if err != nil {
return false, nil, fmt.Errorf("unable to "+
"split tor v3 address: %v",
addr.Address)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return false, nil, err
}
addresses = append(addresses, &tor.OnionAddr{
OnionService: service,
Port: port,
})
case addressTypeOpaque:
opaque, err := hex.DecodeString(address)
if err != nil {
return false, nil, fmt.Errorf("unable to "+
"decode opaque address: %v", addr)
}
addresses = append(addresses, &lnwire.OpaqueAddrs{
Payload: opaque,
})
default:
return false, nil, fmt.Errorf("unknown address "+
"type: %v", addr.Type)
}
}
return true, addresses, nil
}
// upsertNodeExtraSignedFields updates the node's extra signed fields in the
// database. This includes updating any existing types, inserting any new types,
// and deleting any types that are no longer present.
func upsertNodeExtraSignedFields(ctx context.Context, db SQLQueries,
nodeID int64, extraFields map[uint64][]byte) error {
// Get any existing extra signed fields for the node.
existingFields, err := db.GetExtraNodeTypes(ctx, nodeID)
if err != nil {
return err
}
// Make a lookup map of the existing field types so that we can use it
// to keep track of any fields we should delete.
m := make(map[uint64]bool)
for _, field := range existingFields {
m[uint64(field.Type)] = true
}
// For all the new fields, we'll upsert them and remove them from the
// map of existing fields.
for tlvType, value := range extraFields {
err = db.UpsertNodeExtraType(
ctx, sqlc.UpsertNodeExtraTypeParams{
NodeID: nodeID,
Type: int64(tlvType),
Value: value,
},
)
if err != nil {
return fmt.Errorf("unable to upsert node(%d) extra "+
"signed field(%v): %w", nodeID, tlvType, err)
}
// Remove the field from the map of existing fields if it was
// present.
delete(m, tlvType)
}
// For all the fields that are left in the map of existing fields, we'll
// delete them as they are no longer present in the new set of fields.
for tlvType := range m {
err = db.DeleteExtraNodeType(
ctx, sqlc.DeleteExtraNodeTypeParams{
NodeID: nodeID,
Type: int64(tlvType),
},
)
if err != nil {
return fmt.Errorf("unable to delete node(%d) extra "+
"signed field(%v): %w", nodeID, tlvType, err)
}
}
return nil
}
// getSourceNode returns the DB node ID and pub key of the source node for the
// specified protocol version.
func getSourceNode(ctx context.Context, db SQLQueries,
version ProtocolVersion) (int64, route.Vertex, error) {
var pubKey route.Vertex
nodes, err := db.GetSourceNodesByVersion(ctx, int16(version))
if err != nil {
return 0, pubKey, fmt.Errorf("unable to fetch source node: %w",
err)
}
if len(nodes) == 0 {
return 0, pubKey, ErrSourceNodeNotSet
} else if len(nodes) > 1 {
return 0, pubKey, fmt.Errorf("multiple source nodes for "+
"protocol %s found", version)
}
copy(pubKey[:], nodes[0].PubKey)
return nodes[0].NodeID, pubKey, nil
}
// marshalExtraOpaqueData takes a flat byte slice parses it as a TLV stream.
// This then produces a map from TLV type to value. If the input is not a
// valid TLV stream, then an error is returned.
func marshalExtraOpaqueData(data []byte) (map[uint64][]byte, error) {
r := bytes.NewReader(data)
tlvStream, err := tlv.NewStream()
if err != nil {
return nil, err
}
// Since ExtraOpaqueData is provided by a potentially malicious peer,
// pass it into the P2P decoding variant.
parsedTypes, err := tlvStream.DecodeWithParsedTypesP2P(r)
if err != nil {
return nil, err
}
if len(parsedTypes) == 0 {
return nil, nil
}
records := make(map[uint64][]byte)
for k, v := range parsedTypes {
records[uint64(k)] = v
}
return records, nil
}
// insertChannel inserts a new channel record into the database.
func insertChannel(ctx context.Context, db SQLQueries,
edge *models.ChannelEdgeInfo) error {
var chanIDB [8]byte
byteOrder.PutUint64(chanIDB[:], edge.ChannelID)
// Make sure that the channel doesn't already exist. We do this
// explicitly instead of relying on catching a unique constraint error
// because relying on SQL to throw that error would abort the entire
// batch of transactions.
_, err := db.GetChannelBySCID(
ctx, sqlc.GetChannelBySCIDParams{
Scid: chanIDB[:],
Version: int16(ProtocolV1),
},
)
if err == nil {
return ErrEdgeAlreadyExist
} else if !errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("unable to fetch channel: %w", err)
}
// Make sure that at least a "shell" entry for each node is present in
// the nodes table.
node1DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey1Bytes)
if err != nil {
return fmt.Errorf("unable to create shell node: %w", err)
}
node2DBID, err := maybeCreateShellNode(ctx, db, edge.NodeKey2Bytes)
if err != nil {
return fmt.Errorf("unable to create shell node: %w", err)
}
var capacity sql.NullInt64
if edge.Capacity != 0 {
capacity = sqldb.SQLInt64(int64(edge.Capacity))
}
createParams := sqlc.CreateChannelParams{
Version: int16(ProtocolV1),
Scid: chanIDB[:],
NodeID1: node1DBID,
NodeID2: node2DBID,
Outpoint: edge.ChannelPoint.String(),
Capacity: capacity,
BitcoinKey1: edge.BitcoinKey1Bytes[:],
BitcoinKey2: edge.BitcoinKey2Bytes[:],
}
if edge.AuthProof != nil {
proof := edge.AuthProof
createParams.Node1Signature = proof.NodeSig1Bytes
createParams.Node2Signature = proof.NodeSig2Bytes
createParams.Bitcoin1Signature = proof.BitcoinSig1Bytes
createParams.Bitcoin2Signature = proof.BitcoinSig2Bytes
}
// Insert the new channel record.
dbChanID, err := db.CreateChannel(ctx, createParams)
if err != nil {
return err
}
// Insert any channel features.
if len(edge.Features) != 0 {
chanFeatures := lnwire.NewRawFeatureVector()
err := chanFeatures.Decode(bytes.NewReader(edge.Features))
if err != nil {
return err
}
fv := lnwire.NewFeatureVector(chanFeatures, lnwire.Features)
for feature := range fv.Features() {
err = db.InsertChannelFeature(
ctx, sqlc.InsertChannelFeatureParams{
ChannelID: dbChanID,
FeatureBit: int32(feature),
},
)
if err != nil {
return fmt.Errorf("unable to insert "+
"channel(%d) feature(%v): %w", dbChanID,
feature, err)
}
}
}
// Finally, insert any extra TLV fields in the channel announcement.
extra, err := marshalExtraOpaqueData(edge.ExtraOpaqueData)
if err != nil {
return fmt.Errorf("unable to marshal extra opaque data: %w",
err)
}
for tlvType, value := range extra {
err := db.CreateChannelExtraType(
ctx, sqlc.CreateChannelExtraTypeParams{
ChannelID: dbChanID,
Type: int64(tlvType),
Value: value,
},
)
if err != nil {
return fmt.Errorf("unable to upsert channel(%d) extra "+
"signed field(%v): %w", edge.ChannelID,
tlvType, err)
}
}
return nil
}
// maybeCreateShellNode checks if a shell node entry exists for the
// given public key. If it does not exist, then a new shell node entry is
// created. The ID of the node is returned. A shell node only has a protocol
// version and public key persisted.
func maybeCreateShellNode(ctx context.Context, db SQLQueries,
pubKey route.Vertex) (int64, error) {
dbNode, err := db.GetNodeByPubKey(
ctx, sqlc.GetNodeByPubKeyParams{
PubKey: pubKey[:],
Version: int16(ProtocolV1),
},
)
// The node exists. Return the ID.
if err == nil {
return dbNode.ID, nil
} else if !errors.Is(err, sql.ErrNoRows) {
return 0, err
}
// Otherwise, the node does not exist, so we create a shell entry for
// it.
id, err := db.UpsertNode(ctx, sqlc.UpsertNodeParams{
Version: int16(ProtocolV1),
PubKey: pubKey[:],
})
if err != nil {
return 0, fmt.Errorf("unable to create shell node: %w", err)
}
return id, nil
}