graph/db: refactor buildNode to use batch fetching

Here, we add a new `buildNodeWithBatchData` helper method that can be
used to construct a `models.LightningNode` object using pre-fetched
batch data. The existing `buildNode` method is then adjusted to use this
new helper.
Elle Mouton
2025-07-29 07:52:16 +02:00
parent 0dc0d320f8
commit d05b918c7a

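For context, the point of the new helper is that a caller holding many node rows can hydrate all of them from a single round of batched queries rather than three queries per node. Below is a minimal, hypothetical sketch of that pattern; buildNodesBatch is not part of this commit, while batchLoadNodeData and buildNodeWithBatchData are the helpers introduced in the diff that follows.

// buildNodesBatch is a hypothetical caller-side sketch: it hydrates several
// database node rows using one batched fetch of features, addresses and
// extra signed fields instead of per-node lookups.
func buildNodesBatch(ctx context.Context, db SQLQueries,
	dbNodes []*sqlc.GraphNode) ([]*models.LightningNode, error) {

	ids := make([]int64, 0, len(dbNodes))
	for _, n := range dbNodes {
		ids = append(ids, n.ID)
	}

	// Load the related data for all the nodes in one go.
	cfg := sqldb.DefaultPagedQueryConfig()
	data, err := batchLoadNodeData(ctx, cfg, db, ids)
	if err != nil {
		return nil, fmt.Errorf("unable to batch load node data: %w",
			err)
	}

	// Each node is then assembled purely from the pre-fetched data.
	nodes := make([]*models.LightningNode, 0, len(dbNodes))
	for _, n := range dbNodes {
		node, err := buildNodeWithBatchData(n, data)
		if err != nil {
			return nil, err
		}
		nodes = append(nodes, node)
	}

	return nodes, nil
}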

@@ -3352,8 +3352,29 @@ func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub,
// buildNode constructs a LightningNode instance from the given database node
// record. The node's features, addresses and extra signed fields are also
// fetched from the database and set on the node.
func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
*models.LightningNode, error) {
func buildNode(ctx context.Context, db SQLQueries,
dbNode *sqlc.GraphNode) (*models.LightningNode, error) {
// NOTE: buildNode is only used to load the data for a single node, and
// so no paged queries will be performed. This means that it's ok to
// pass in default config values here.
cfg := sqldb.DefaultPagedQueryConfig()
data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID})
if err != nil {
return nil, fmt.Errorf("unable to batch load node data: %w",
err)
}
return buildNodeWithBatchData(dbNode, data)
}
// buildNodeWithBatchData builds a models.LightningNode instance from the
// provided sqlc.GraphNode and batchNodeData. If the node has
// features/addresses/extra fields, then the corresponding entries are
// expected to be present in the batchNodeData.
func buildNodeWithBatchData(dbNode *sqlc.GraphNode,
batchData *batchNodeData) (*models.LightningNode, error) {
if dbNode.Version != int16(ProtocolV1) {
return nil, fmt.Errorf("unsupported node version: %d",
@@ -3387,35 +3408,35 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) (
}
}
// Fetch the node's features.
node.Features, err = getNodeFeatures(ctx, db, dbNode.ID)
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"features: %w", dbNode.ID, err)
// Use preloaded features.
if features, exists := batchData.features[dbNode.ID]; exists {
fv := lnwire.EmptyFeatureVector()
for _, bit := range features {
fv.Set(lnwire.FeatureBit(bit))
}
node.Features = fv
}
// Fetch the node's addresses.
node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID)
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"addresses: %w", dbNode.ID, err)
// Use preloaded addresses.
addresses, exists := batchData.addresses[dbNode.ID]
if exists && len(addresses) > 0 {
node.Addresses, err = buildNodeAddresses(addresses)
if err != nil {
return nil, fmt.Errorf("unable to build addresses "+
"for node(%d): %w", dbNode.ID, err)
}
}
// Fetch the node's extra signed fields.
extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID)
if err != nil {
return nil, fmt.Errorf("unable to fetch node(%d) "+
"extra signed fields: %w", dbNode.ID, err)
}
recs, err := lnwire.CustomRecords(extraTLVMap).Serialize()
if err != nil {
return nil, fmt.Errorf("unable to serialize extra signed "+
"fields: %w", err)
}
if len(recs) != 0 {
node.ExtraOpaqueData = recs
// Use preloaded extra fields.
if extraFields, exists := batchData.extraFields[dbNode.ID]; exists {
recs, err := lnwire.CustomRecords(extraFields).Serialize()
if err != nil {
return nil, fmt.Errorf("unable to serialize extra "+
"signed fields: %w", err)
}
if len(recs) != 0 {
node.ExtraOpaqueData = recs
}
}
return node, nil
@@ -3440,25 +3461,6 @@ func getNodeFeatures(ctx context.Context, db SQLQueries,
return features, nil
}
// getNodeExtraSignedFields fetches the extra signed fields for a node with the
// given DB ID.
func getNodeExtraSignedFields(ctx context.Context, db SQLQueries,
nodeID int64) (map[uint64][]byte, error) {
fields, err := db.GetExtraNodeTypes(ctx, nodeID)
if err != nil {
return nil, fmt.Errorf("unable to get node(%d) extra "+
"signed fields: %w", nodeID, err)
}
extraFields := make(map[uint64][]byte)
for _, field := range fields {
extraFields[uint64(field.Type)] = field.Value
}
return extraFields, nil
}
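The helper removed above built the extra-field map for a single node. Under the batching scheme, the equivalent grouping presumably happens once for the whole ID set inside batchLoadNodeData, which is not shown in this diff. A rough sketch of that grouping, under stated assumptions, could look like the following; the extraFieldRow type and its NodeID field are invented for illustration, while Type and Value mirror the removed code:

// extraFieldRow is a hypothetical row shape for a batched extra-fields
// query; the real sqlc-generated type is not shown in this diff.
type extraFieldRow struct {
	NodeID int64
	Type   int64
	Value  []byte
}

// groupExtraFieldsByNode groups one batched query result by node ID so
// that buildNodeWithBatchData can do a simple map lookup per node.
func groupExtraFieldsByNode(rows []extraFieldRow) map[int64]map[uint64][]byte {
	extraFields := make(map[int64]map[uint64][]byte)
	for _, row := range rows {
		m, ok := extraFields[row.NodeID]
		if !ok {
			m = make(map[uint64][]byte)
			extraFields[row.NodeID] = m
		}
		m[uint64(row.Type)] = row.Value
	}

	return extraFields
}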
// upsertNode upserts the node record into the database. If the node already
// exists, then the node's information is updated. If the node doesn't exist,
// then a new node is created. The node's features, addresses and extra TLV
@@ -3714,55 +3716,13 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr,
for _, row := range rows {
address := row.Address
switch dbAddressType(row.Type) {
case addressTypeIPv4:
tcp, err := net.ResolveTCPAddr("tcp4", address)
if err != nil {
return nil, err
}
tcp.IP = tcp.IP.To4()
addresses = append(addresses, tcp)
case addressTypeIPv6:
tcp, err := net.ResolveTCPAddr("tcp6", address)
if err != nil {
return nil, err
}
addresses = append(addresses, tcp)
case addressTypeTorV3, addressTypeTorV2:
service, portStr, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("unable to "+
"split tor v3 address: %v", address)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
}
addresses = append(addresses, &tor.OnionAddr{
OnionService: service,
Port: port,
})
case addressTypeOpaque:
opaque, err := hex.DecodeString(address)
if err != nil {
return nil, fmt.Errorf("unable to "+
"decode opaque address: %v", address)
}
addresses = append(addresses, &lnwire.OpaqueAddrs{
Payload: opaque,
})
default:
return nil, fmt.Errorf("unknown address type: %v",
row.Type)
addr, err := parseAddress(dbAddressType(row.Type), address)
if err != nil {
return nil, fmt.Errorf("unable to parse address "+
"for node(%d): %v: %w", id, address, err)
}
addresses = append(addresses, addr)
}
// If we have no addresses, then we'll return nil instead of an
@@ -4676,6 +4636,89 @@ func channelIDToBytes(channelID uint64) []byte {
return chanIDB[:]
}
// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr.
func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) {
if len(addresses) == 0 {
return nil, nil
}
result := make([]net.Addr, 0, len(addresses))
for _, addr := range addresses {
netAddr, err := parseAddress(addr.addrType, addr.address)
if err != nil {
return nil, fmt.Errorf("unable to parse address %s "+
"of type %d: %w", addr.address, addr.addrType,
err)
}
if netAddr != nil {
result = append(result, netAddr)
}
}
// If we have no valid addresses, return nil instead of an empty slice.
if len(result) == 0 {
return nil, nil
}
return result, nil
}
// parseAddress parses the given address string based on the address type
// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3,
// and opaque addresses.
func parseAddress(addrType dbAddressType, address string) (net.Addr, error) {
switch addrType {
case addressTypeIPv4:
tcp, err := net.ResolveTCPAddr("tcp4", address)
if err != nil {
return nil, err
}
tcp.IP = tcp.IP.To4()
return tcp, nil
case addressTypeIPv6:
tcp, err := net.ResolveTCPAddr("tcp6", address)
if err != nil {
return nil, err
}
return tcp, nil
case addressTypeTorV3, addressTypeTorV2:
service, portStr, err := net.SplitHostPort(address)
if err != nil {
return nil, fmt.Errorf("unable to split tor "+
"address: %v", address)
}
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, err
}
return &tor.OnionAddr{
OnionService: service,
Port: port,
}, nil
case addressTypeOpaque:
opaque, err := hex.DecodeString(address)
if err != nil {
return nil, fmt.Errorf("unable to decode opaque "+
"address: %v", address)
}
return &lnwire.OpaqueAddrs{
Payload: opaque,
}, nil
default:
return nil, fmt.Errorf("unknown address type: %v", addrType)
}
}
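To make the switch in parseAddress concrete, here is a small, hypothetical illustration of the kind of input string each address type maps to; the hosts, ports and payloads are invented, only the shapes matter.

// Illustrative inputs only; assumes this snippet sits next to parseAddress
// so the dbAddressType constants are in scope.
examples := map[dbAddressType]string{
	addressTypeIPv4:   "203.0.113.5:9735",            // -> *net.TCPAddr (4-byte IP)
	addressTypeIPv6:   "[2001:db8::1]:9735",          // -> *net.TCPAddr
	addressTypeTorV3:  "examplev3service.onion:9735", // -> *tor.OnionAddr
	addressTypeOpaque: "0a0b0c",                      // hex payload -> *lnwire.OpaqueAddrs
}
for addrType, addrStr := range examples {
	addr, err := parseAddress(addrType, addrStr)
	if err != nil {
		fmt.Printf("unable to parse %q: %v\n", addrStr, err)
		continue
	}
	fmt.Printf("%q parsed as %T\n", addrStr, addr)
}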
// batchNodeData holds all the related data for a batch of nodes.
type batchNodeData struct {
// features is a map from a DB node ID to the feature bits for that