mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-03 23:03:37 +02:00
graph/db: refactor buildNode to use batch fetching
Here, we add a new `buildNodeWithBatchData` helper method that can be used to construct a `models.LightningNode` object using pre-fetched batch data. The existing `buildNode` method is then adjusted to use this new helper.
This commit is contained in:
@@ -3352,8 +3352,29 @@ func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
|
|||||||
// buildNode constructs a LightningNode instance from the given database node
|
// buildNode constructs a LightningNode instance from the given database node
|
||||||
// record. The node's features, addresses and extra signed fields are also
|
// record. The node's features, addresses and extra signed fields are also
|
||||||
// fetched from the database and set on the node.
|
// fetched from the database and set on the node.
|
||||||
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
|
func buildNode(ctx context.Context, db SQLQueries,
|
||||||
*models.LightningNode, error) {
|
dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
|
||||||
|
|
||||||
|
// NOTE: buildNode is only used to load the data for a single node, and
|
||||||
|
// so no paged queries will be performed. This means that it's ok to
|
||||||
|
// used pass in default config values here.
|
||||||
|
cfg := sqldb.DefaultPagedQueryConfig()
|
||||||
|
|
||||||
|
data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to batch load node data: %w",
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return buildNodeWithBatchData(dbNode, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildNodeWithBatchData builds a models.LightningNode instance
|
||||||
|
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
|
||||||
|
// features/addresses/extra fields, then the corresponding fields are expected
|
||||||
|
// to be present in the batchNodeData.
|
||||||
|
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
|
||||||
|
batchData *batchNodeData) (*models.LightningNode, error) {
|
||||||
|
|
||||||
if dbNode.Version != int16(ProtocolV1) {
|
if dbNode.Version != int16(ProtocolV1) {
|
||||||
return nil, fmt.Errorf("unsupported node version: %d",
|
return nil, fmt.Errorf("unsupported node version: %d",
|
||||||
@@ -3387,36 +3408,36 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the node's features.
|
// Use preloaded features.
|
||||||
node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
|
if features, exists := batchData.features[dbNode.ID]; exists {
|
||||||
if err != nil {
|
fv := lnwire.EmptyFeatureVector()
|
||||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
for _, bit := range features {
|
||||||
"features: %w", dbNode.ID, err)
|
fv.Set(lnwire.FeatureBit(bit))
|
||||||
|
}
|
||||||
|
node.Features = fv
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the node's addresses.
|
// Use preloaded addresses.
|
||||||
node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID)
|
addresses, exists := batchData.addresses[dbNode.ID]
|
||||||
|
if exists && len(addresses) > 0 {
|
||||||
|
node.Addresses, err = buildNodeAddresses(addresses)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
return nil, fmt.Errorf("unable to build addresses "+
|
||||||
"addresses: %w", dbNode.ID, err)
|
"for node(%d): %w", dbNode.ID, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fetch the node's extra signed fields.
|
// Use preloaded extra fields.
|
||||||
extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
|
if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
|
||||||
|
recs, err := lnwire.CustomRecords(extraFields).Serialize()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
return nil, fmt.Errorf("unable to serialize extra "+
|
||||||
"extra signed fields: %w", dbNode.ID, err)
|
"signed fields: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to serialize extra signed "+
|
|
||||||
"fields: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(recs) != 0 {
|
if len(recs) != 0 {
|
||||||
node.ExtraOpaqueData = recs
|
node.ExtraOpaqueData = recs
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return node, nil
|
return node, nil
|
||||||
}
|
}
|
||||||
@@ -3440,25 +3461,6 @@ func getNodeFeatures(ctx context.Context, db SQLQueries,
|
|||||||
return features, nil
|
return features, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
|
|
||||||
// given DB ID.
|
|
||||||
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
|
|
||||||
nodeID int64) (map[uint64][]byte, error) {
|
|
||||||
|
|
||||||
fields, err := db.GetExtraNodeTypes(ctx, nodeID)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to get node(%d) extra "+
|
|
||||||
"signed fields: %w", nodeID, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
extraFields := make(map[uint64][]byte)
|
|
||||||
for _, field := range fields {
|
|
||||||
extraFields[uint64(field.Type)] = field.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
return extraFields, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// upsertNode upserts the node record into the database. If the node already
|
// upsertNode upserts the node record into the database. If the node already
|
||||||
// exists, then the node's information is updated. If the node doesn't exist,
|
// exists, then the node's information is updated. If the node doesn't exist,
|
||||||
// then a new node is created. The node's features, addresses and extra TLV
|
// then a new node is created. The node's features, addresses and extra TLV
|
||||||
@@ -3714,55 +3716,13 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
|
|||||||
for _, row := range rows {
|
for _, row := range rows {
|
||||||
address := row.Address
|
address := row.Address
|
||||||
|
|
||||||
switch dbAddressType(row.Type) {
|
addr, err := parseAddress(dbAddressType(row.Type), address)
|
||||||
case addressTypeIPv4:
|
|
||||||
tcp, err := net.ResolveTCPAddr("tcp4", address)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("unable to parse address "+
|
||||||
}
|
"for node(%d): %v: %w", id, address, err)
|
||||||
tcp.IP = tcp.IP.To4()
|
|
||||||
|
|
||||||
addresses = append(addresses, tcp)
|
|
||||||
|
|
||||||
case addressTypeIPv6:
|
|
||||||
tcp, err := net.ResolveTCPAddr("tcp6", address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
addresses = append(addresses, tcp)
|
|
||||||
|
|
||||||
case addressTypeTorV3, addressTypeTorV2:
|
|
||||||
service, portStr, err := net.SplitHostPort(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to "+
|
|
||||||
"split tor v3 address: %v", address)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
port, err := strconv.Atoi(portStr)
|
addresses = append(addresses, addr)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
addresses = append(addresses, &tor.OnionAddr{
|
|
||||||
OnionService: service,
|
|
||||||
Port: port,
|
|
||||||
})
|
|
||||||
|
|
||||||
case addressTypeOpaque:
|
|
||||||
opaque, err := hex.DecodeString(address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("unable to "+
|
|
||||||
"decode opaque address: %v", address)
|
|
||||||
}
|
|
||||||
|
|
||||||
addresses = append(addresses, &lnwire.OpaqueAddrs{
|
|
||||||
Payload: opaque,
|
|
||||||
})
|
|
||||||
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("unknown address type: %v",
|
|
||||||
row.Type)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we have no addresses, then we'll return nil instead of an
|
// If we have no addresses, then we'll return nil instead of an
|
||||||
@@ -4676,6 +4636,89 @@ func channelIDToBytes(channelID uint64) []byte {
|
|||||||
return chanIDB[:]
|
return chanIDB[:]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
|
||||||
|
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
|
||||||
|
if len(addresses) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
result := make([]net.Addr, 0, len(addresses))
|
||||||
|
for _, addr := range addresses {
|
||||||
|
netAddr, err := parseAddress(addr.addrType, addr.address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to parse address %s "+
|
||||||
|
"of type %d: %w", addr.address, addr.addrType,
|
||||||
|
err)
|
||||||
|
}
|
||||||
|
if netAddr != nil {
|
||||||
|
result = append(result, netAddr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we have no valid addresses, return nil instead of empty slice.
|
||||||
|
if len(result) == 0 {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseAddress parses the given address string based on the address type
|
||||||
|
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
|
||||||
|
// and opaque addresses.
|
||||||
|
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
|
||||||
|
switch addrType {
|
||||||
|
case addressTypeIPv4:
|
||||||
|
tcp, err := net.ResolveTCPAddr("tcp4", address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
tcp.IP = tcp.IP.To4()
|
||||||
|
|
||||||
|
return tcp, nil
|
||||||
|
|
||||||
|
case addressTypeIPv6:
|
||||||
|
tcp, err := net.ResolveTCPAddr("tcp6", address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tcp, nil
|
||||||
|
|
||||||
|
case addressTypeTorV3, addressTypeTorV2:
|
||||||
|
service, portStr, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to split tor "+
|
||||||
|
"address: %v", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &tor.OnionAddr{
|
||||||
|
OnionService: service,
|
||||||
|
Port: port,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
case addressTypeOpaque:
|
||||||
|
opaque, err := hex.DecodeString(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to decode opaque "+
|
||||||
|
"address: %v", address)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &lnwire.OpaqueAddrs{
|
||||||
|
Payload: opaque,
|
||||||
|
}, nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown address type: %v", addrType)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// batchNodeData holds all the related data for a batch of nodes.
|
// batchNodeData holds all the related data for a batch of nodes.
|
||||||
type batchNodeData struct {
|
type batchNodeData struct {
|
||||||
// features is a map from a DB node ID to the feature bits for that
|
// features is a map from a DB node ID to the feature bits for that
|
||||||
|
Reference in New Issue
Block a user