diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index cc7e5d4e9..e0192bd39 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -3352,8 +3352,29 @@ func buildCacheableChannelInfo(scid []byte, capacity int64, node1Pub, // buildNode constructs a LightningNode instance from the given database node // record. The node's features, addresses and extra signed fields are also // fetched from the database and set on the node. -func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) ( - *models.LightningNode, error) { +func buildNode(ctx context.Context, db SQLQueries, + dbNode *sqlc.GraphNode) (*models.LightningNode, error) { + + // NOTE: buildNode is only used to load the data for a single node, and + // so no paged queries will be performed. This means that it's ok to + // used pass in default config values here. + cfg := sqldb.DefaultPagedQueryConfig() + + data, err := batchLoadNodeData(ctx, cfg, db, []int64{dbNode.ID}) + if err != nil { + return nil, fmt.Errorf("unable to batch load node data: %w", + err) + } + + return buildNodeWithBatchData(dbNode, data) +} + +// buildNodeWithBatchData builds a models.LightningNode instance +// from the provided sqlc.GraphNode and batchNodeData. If the node does have +// features/addresses/extra fields, then the corresponding fields are expected +// to be present in the batchNodeData. +func buildNodeWithBatchData(dbNode *sqlc.GraphNode, + batchData *batchNodeData) (*models.LightningNode, error) { if dbNode.Version != int16(ProtocolV1) { return nil, fmt.Errorf("unsupported node version: %d", @@ -3387,35 +3408,35 @@ func buildNode(ctx context.Context, db SQLQueries, dbNode *sqlc.GraphNode) ( } } - // Fetch the node's features. - node.Features, err = getNodeFeatures(ctx, db, dbNode.ID) - if err != nil { - return nil, fmt.Errorf("unable to fetch node(%d) "+ - "features: %w", dbNode.ID, err) + // Use preloaded features. + if features, exists := batchData.features[dbNode.ID]; exists { + fv := lnwire.EmptyFeatureVector() + for _, bit := range features { + fv.Set(lnwire.FeatureBit(bit)) + } + node.Features = fv } - // Fetch the node's addresses. - node.Addresses, err = getNodeAddresses(ctx, db, dbNode.ID) - if err != nil { - return nil, fmt.Errorf("unable to fetch node(%d) "+ - "addresses: %w", dbNode.ID, err) + // Use preloaded addresses. + addresses, exists := batchData.addresses[dbNode.ID] + if exists && len(addresses) > 0 { + node.Addresses, err = buildNodeAddresses(addresses) + if err != nil { + return nil, fmt.Errorf("unable to build addresses "+ + "for node(%d): %w", dbNode.ID, err) + } } - // Fetch the node's extra signed fields. - extraTLVMap, err := getNodeExtraSignedFields(ctx, db, dbNode.ID) - if err != nil { - return nil, fmt.Errorf("unable to fetch node(%d) "+ - "extra signed fields: %w", dbNode.ID, err) - } - - recs, err := lnwire.CustomRecords(extraTLVMap).Serialize() - if err != nil { - return nil, fmt.Errorf("unable to serialize extra signed "+ - "fields: %w", err) - } - - if len(recs) != 0 { - node.ExtraOpaqueData = recs + // Use preloaded extra fields. + if extraFields, exists := batchData.extraFields[dbNode.ID]; exists { + recs, err := lnwire.CustomRecords(extraFields).Serialize() + if err != nil { + return nil, fmt.Errorf("unable to serialize extra "+ + "signed fields: %w", err) + } + if len(recs) != 0 { + node.ExtraOpaqueData = recs + } } return node, nil @@ -3440,25 +3461,6 @@ func getNodeFeatures(ctx context.Context, db SQLQueries, return features, nil } -// getNodeExtraSignedFields fetches the extra signed fields for a node with the -// given DB ID. -func getNodeExtraSignedFields(ctx context.Context, db SQLQueries, - nodeID int64) (map[uint64][]byte, error) { - - fields, err := db.GetExtraNodeTypes(ctx, nodeID) - if err != nil { - return nil, fmt.Errorf("unable to get node(%d) extra "+ - "signed fields: %w", nodeID, err) - } - - extraFields := make(map[uint64][]byte) - for _, field := range fields { - extraFields[uint64(field.Type)] = field.Value - } - - return extraFields, nil -} - // upsertNode upserts the node record into the database. If the node already // exists, then the node's information is updated. If the node doesn't exist, // then a new node is created. The node's features, addresses and extra TLV @@ -3714,55 +3716,13 @@ func getNodeAddresses(ctx context.Context, db SQLQueries, id int64) ([]net.Addr, for _, row := range rows { address := row.Address - switch dbAddressType(row.Type) { - case addressTypeIPv4: - tcp, err := net.ResolveTCPAddr("tcp4", address) - if err != nil { - return nil, err - } - tcp.IP = tcp.IP.To4() - - addresses = append(addresses, tcp) - - case addressTypeIPv6: - tcp, err := net.ResolveTCPAddr("tcp6", address) - if err != nil { - return nil, err - } - addresses = append(addresses, tcp) - - case addressTypeTorV3, addressTypeTorV2: - service, portStr, err := net.SplitHostPort(address) - if err != nil { - return nil, fmt.Errorf("unable to "+ - "split tor v3 address: %v", address) - } - - port, err := strconv.Atoi(portStr) - if err != nil { - return nil, err - } - - addresses = append(addresses, &tor.OnionAddr{ - OnionService: service, - Port: port, - }) - - case addressTypeOpaque: - opaque, err := hex.DecodeString(address) - if err != nil { - return nil, fmt.Errorf("unable to "+ - "decode opaque address: %v", address) - } - - addresses = append(addresses, &lnwire.OpaqueAddrs{ - Payload: opaque, - }) - - default: - return nil, fmt.Errorf("unknown address type: %v", - row.Type) + addr, err := parseAddress(dbAddressType(row.Type), address) + if err != nil { + return nil, fmt.Errorf("unable to parse address "+ + "for node(%d): %v: %w", id, address, err) } + + addresses = append(addresses, addr) } // If we have no addresses, then we'll return nil instead of an @@ -4676,6 +4636,89 @@ func channelIDToBytes(channelID uint64) []byte { return chanIDB[:] } +// buildNodeAddresses converts a slice of nodeAddress into a slice of net.Addr. +func buildNodeAddresses(addresses []nodeAddress) ([]net.Addr, error) { + if len(addresses) == 0 { + return nil, nil + } + + result := make([]net.Addr, 0, len(addresses)) + for _, addr := range addresses { + netAddr, err := parseAddress(addr.addrType, addr.address) + if err != nil { + return nil, fmt.Errorf("unable to parse address %s "+ + "of type %d: %w", addr.address, addr.addrType, + err) + } + if netAddr != nil { + result = append(result, netAddr) + } + } + + // If we have no valid addresses, return nil instead of empty slice. + if len(result) == 0 { + return nil, nil + } + + return result, nil +} + +// parseAddress parses the given address string based on the address type +// and returns a net.Addr instance. It supports IPv4, IPv6, Tor v2, Tor v3, +// and opaque addresses. +func parseAddress(addrType dbAddressType, address string) (net.Addr, error) { + switch addrType { + case addressTypeIPv4: + tcp, err := net.ResolveTCPAddr("tcp4", address) + if err != nil { + return nil, err + } + + tcp.IP = tcp.IP.To4() + + return tcp, nil + + case addressTypeIPv6: + tcp, err := net.ResolveTCPAddr("tcp6", address) + if err != nil { + return nil, err + } + + return tcp, nil + + case addressTypeTorV3, addressTypeTorV2: + service, portStr, err := net.SplitHostPort(address) + if err != nil { + return nil, fmt.Errorf("unable to split tor "+ + "address: %v", address) + } + + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, err + } + + return &tor.OnionAddr{ + OnionService: service, + Port: port, + }, nil + + case addressTypeOpaque: + opaque, err := hex.DecodeString(address) + if err != nil { + return nil, fmt.Errorf("unable to decode opaque "+ + "address: %v", address) + } + + return &lnwire.OpaqueAddrs{ + Payload: opaque, + }, nil + + default: + return nil, fmt.Errorf("unknown address type: %v", addrType) + } +} + // batchNodeData holds all the related data for a batch of nodes. type batchNodeData struct { // features is a map from a DB node ID to the feature bits for that