mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-27 19:26:23 +02:00
graph/db: refactor buildNode to use batch fetching
Here, we add a new `buildNodeWithBatchData` helper method that can be used to construct a `models.LightningNode` object using pre-fetched batch data. The existing `buildNode` method is then adjusted to use this new helper.
This commit is contained in:
@@ -3352,8 +3352,29 @@ func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
|
||||
// buildNode constructs a LightningNode instance from the given database node
|
||||
// record. The node's features, addresses and extra signed fields are also
|
||||
// fetched from the database and set on the node.
|
||||
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
|
||||
*models.LightningNode, error) {
|
||||
func buildNode(ctx context.Context, db SQLQueries,
|
||||
dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
|
||||
|
||||
// NOTE: buildNode is only used to load the data for a single node, and
|
||||
// so no paged queries will be performed. This means that it's ok to
|
||||
// used pass in default config values here.
|
||||
cfg := sqldb.DefaultPagedQueryConfig()
|
||||
|
||||
data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to batch load node data: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
return buildNodeWithBatchData(dbNode, data)
|
||||
}
|
||||
|
||||
// buildNodeWithBatchData builds a models.LightningNode instance
|
||||
// from the provided sqlc.GraphNode and batchNodeData. If the node does have
|
||||
// features/addresses/extra fields, then the corresponding fields are expected
|
||||
// to be present in the batchNodeData.
|
||||
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
|
||||
batchData *batchNodeData) (*models.LightningNode, error) {
|
||||
|
||||
if dbNode.Version != int16(ProtocolV1) {
|
||||
return nil, fmt.Errorf("unsupported node version: %d",
|
||||
@@ -3387,35 +3408,35 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the node's features.
|
||||
node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
||||
"features: %w", dbNode.ID, err)
|
||||
// Use preloaded features.
|
||||
if features, exists := batchData.features[dbNode.ID]; exists {
|
||||
fv := lnwire.EmptyFeatureVector()
|
||||
for _, bit := range features {
|
||||
fv.Set(lnwire.FeatureBit(bit))
|
||||
}
|
||||
node.Features = fv
|
||||
}
|
||||
|
||||
// Fetch the node's addresses.
|
||||
node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
||||
"addresses: %w", dbNode.ID, err)
|
||||
// Use preloaded addresses.
|
||||
addresses, exists := batchData.addresses[dbNode.ID]
|
||||
if exists && len(addresses) > 0 {
|
||||
node.Addresses, err = buildNodeAddresses(addresses)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build addresses "+
|
||||
"for node(%d): %w", dbNode.ID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Fetch the node's extra signed fields.
|
||||
extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to fetch node(%d) "+
|
||||
"extra signed fields: %w", dbNode.ID, err)
|
||||
}
|
||||
|
||||
recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to serialize extra signed "+
|
||||
"fields: %w", err)
|
||||
}
|
||||
|
||||
if len(recs) != 0 {
|
||||
node.ExtraOpaqueData = recs
|
||||
// Use preloaded extra fields.
|
||||
if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
|
||||
recs, err := lnwire.CustomRecords(extraFields).Serialize()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to serialize extra "+
|
||||
"signed fields: %w", err)
|
||||
}
|
||||
if len(recs) != 0 {
|
||||
node.ExtraOpaqueData = recs
|
||||
}
|
||||
}
|
||||
|
||||
return node, nil
|
||||
@@ -3440,25 +3461,6 @@ func getNodeFeatures(ctx context.Context, db SQLQueries,
|
||||
return features, nil
|
||||
}
|
||||
|
||||
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
|
||||
// given DB ID.
|
||||
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
|
||||
nodeID int64) (map[uint64][]byte, error) {
|
||||
|
||||
fields, err := db.GetExtraNodeTypes(ctx, nodeID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to get node(%d) extra "+
|
||||
"signed fields: %w", nodeID, err)
|
||||
}
|
||||
|
||||
extraFields := make(map[uint64][]byte)
|
||||
for _, field := range fields {
|
||||
extraFields[uint64(field.Type)] = field.Value
|
||||
}
|
||||
|
||||
return extraFields, nil
|
||||
}
|
||||
|
||||
// upsertNode upserts the node record into the database. If the node already
|
||||
// exists, then the node's information is updated. If the node doesn't exist,
|
||||
// then a new node is created. The node's features, addresses and extra TLV
|
||||
@@ -3714,55 +3716,13 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
|
||||
for _, row := range rows {
|
||||
address := row.Address
|
||||
|
||||
switch dbAddressType(row.Type) {
|
||||
case addressTypeIPv4:
|
||||
tcp, err := net.ResolveTCPAddr("tcp4", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tcp.IP = tcp.IP.To4()
|
||||
|
||||
addresses = append(addresses, tcp)
|
||||
|
||||
case addressTypeIPv6:
|
||||
tcp, err := net.ResolveTCPAddr("tcp6", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
addresses = append(addresses, tcp)
|
||||
|
||||
case addressTypeTorV3, addressTypeTorV2:
|
||||
service, portStr, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to "+
|
||||
"split tor v3 address: %v", address)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
addresses = append(addresses, &tor.OnionAddr{
|
||||
OnionService: service,
|
||||
Port: port,
|
||||
})
|
||||
|
||||
case addressTypeOpaque:
|
||||
opaque, err := hex.DecodeString(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to "+
|
||||
"decode opaque address: %v", address)
|
||||
}
|
||||
|
||||
addresses = append(addresses, &lnwire.OpaqueAddrs{
|
||||
Payload: opaque,
|
||||
})
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown address type: %v",
|
||||
row.Type)
|
||||
addr, err := parseAddress(dbAddressType(row.Type), address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse address "+
|
||||
"for node(%d): %v: %w", id, address, err)
|
||||
}
|
||||
|
||||
addresses = append(addresses, addr)
|
||||
}
|
||||
|
||||
// If we have no addresses, then we'll return nil instead of an
|
||||
@@ -4676,6 +4636,89 @@ func channelIDToBytes(channelID uint64) []byte {
|
||||
return chanIDB[:]
|
||||
}
|
||||
|
||||
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
|
||||
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
|
||||
if len(addresses) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
result := make([]net.Addr, 0, len(addresses))
|
||||
for _, addr := range addresses {
|
||||
netAddr, err := parseAddress(addr.addrType, addr.address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to parse address %s "+
|
||||
"of type %d: %w", addr.address, addr.addrType,
|
||||
err)
|
||||
}
|
||||
if netAddr != nil {
|
||||
result = append(result, netAddr)
|
||||
}
|
||||
}
|
||||
|
||||
// If we have no valid addresses, return nil instead of empty slice.
|
||||
if len(result) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// parseAddress parses the given address string based on the address type
|
||||
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
|
||||
// and opaque addresses.
|
||||
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
|
||||
switch addrType {
|
||||
case addressTypeIPv4:
|
||||
tcp, err := net.ResolveTCPAddr("tcp4", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tcp.IP = tcp.IP.To4()
|
||||
|
||||
return tcp, nil
|
||||
|
||||
case addressTypeIPv6:
|
||||
tcp, err := net.ResolveTCPAddr("tcp6", address)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tcp, nil
|
||||
|
||||
case addressTypeTorV3, addressTypeTorV2:
|
||||
service, portStr, err := net.SplitHostPort(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to split tor "+
|
||||
"address: %v", address)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &tor.OnionAddr{
|
||||
OnionService: service,
|
||||
Port: port,
|
||||
}, nil
|
||||
|
||||
case addressTypeOpaque:
|
||||
opaque, err := hex.DecodeString(address)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to decode opaque "+
|
||||
"address: %v", address)
|
||||
}
|
||||
|
||||
return &lnwire.OpaqueAddrs{
|
||||
Payload: opaque,
|
||||
}, nil
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown address type: %v", addrType)
|
||||
}
|
||||
}
|
||||
|
||||
// batchNodeData holds all the related data for a batch of nodes.
|
||||
type batchNodeData struct {
|
||||
// features is a map from a DB node ID to the feature bits for that
|
||||
|
Reference in New Issue
Block a user