mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-25 13:12:11 +02:00
graph/db: batch fetch channels in ForEachNodeCached
Previously, ForEachNodeCached would batch fetch node _feature_ data but would still fetch the channel set of each node in a node-by-node fashion which is not ideal. So this commit updates this method to make use of the new sqldb.ExecuteCollectAndBatchWithSharedDataQuery helper. It lets us batch load channel data for a range of node IDs. This _greatly_ improves the performance of the method.
This commit is contained in:
@@ -100,6 +100,7 @@ type SQLQueries interface {
|
||||
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
|
||||
HighestSCID(ctx context.Context, version int16) ([]byte, error)
|
||||
ListChannelsByNodeID(ctx context.Context, arg sqlc.ListChannelsByNodeIDParams) ([]sqlc.ListChannelsByNodeIDRow, error)
|
||||
ListChannelsForNodeIDs(ctx context.Context, arg sqlc.ListChannelsForNodeIDsParams) ([]sqlc.ListChannelsForNodeIDsRow, error)
|
||||
ListChannelsWithPoliciesPaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesPaginatedParams) ([]sqlc.ListChannelsWithPoliciesPaginatedRow, error)
|
||||
ListChannelsWithPoliciesForCachePaginated(ctx context.Context, arg sqlc.ListChannelsWithPoliciesForCachePaginatedParams) ([]sqlc.ListChannelsWithPoliciesForCachePaginatedRow, error)
|
||||
ListChannelsPaginated(ctx context.Context, arg sqlc.ListChannelsPaginatedParams) ([]sqlc.ListChannelsPaginatedRow, error)
|
||||
@@ -1085,112 +1086,185 @@ func (s *SQLStore) ForEachNodeCached(ctx context.Context,
|
||||
cb func(node route.Vertex, chans map[uint64]*DirectedChannel) error,
|
||||
reset func()) error {
|
||||
|
||||
handleNode := func(db SQLQueries, nodeID int64,
|
||||
nodePub route.Vertex, features *lnwire.FeatureVector) error {
|
||||
|
||||
toNodeCallback := func() route.Vertex {
|
||||
return nodePub
|
||||
}
|
||||
|
||||
rows, err := db.ListChannelsByNodeID(
|
||||
ctx, sqlc.ListChannelsByNodeIDParams{
|
||||
Version: int16(ProtocolV1),
|
||||
NodeID1: nodeID,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to fetch channels of "+
|
||||
"node(id=%d): %w", nodeID, err)
|
||||
}
|
||||
|
||||
channels := make(map[uint64]*DirectedChannel, len(rows))
|
||||
for _, row := range rows {
|
||||
node1, node2, err := buildNodeVertices(
|
||||
row.Node1Pubkey, row.Node2Pubkey,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e, err := getAndBuildEdgeInfo(
|
||||
ctx, db, s.cfg.ChainHash, row.GraphChannel,
|
||||
node1, node2,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to build channel "+
|
||||
"info: %w", err)
|
||||
}
|
||||
|
||||
dbPol1, dbPol2, err := extractChannelPolicies(row)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to extract channel "+
|
||||
"policies: %w", err)
|
||||
}
|
||||
|
||||
p1, p2, err := getAndBuildChanPolicies(
|
||||
ctx, db, dbPol1, dbPol2, e.ChannelID, node1,
|
||||
node2,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to build channel "+
|
||||
"policies: %w", err)
|
||||
}
|
||||
|
||||
// Determine the outgoing and incoming policy
|
||||
// for this channel and node combo.
|
||||
outPolicy, inPolicy := p1, p2
|
||||
if p1 != nil && p1.ToNode == nodePub {
|
||||
outPolicy, inPolicy = p2, p1
|
||||
} else if p2 != nil && p2.ToNode != nodePub {
|
||||
outPolicy, inPolicy = p2, p1
|
||||
}
|
||||
|
||||
var cachedInPolicy *models.CachedEdgePolicy
|
||||
if inPolicy != nil {
|
||||
cachedInPolicy = models.NewCachedPolicy(
|
||||
inPolicy,
|
||||
)
|
||||
cachedInPolicy.ToNodePubKey = toNodeCallback
|
||||
cachedInPolicy.ToNodeFeatures = features
|
||||
}
|
||||
|
||||
var inboundFee lnwire.Fee
|
||||
if outPolicy != nil {
|
||||
outPolicy.InboundFee.WhenSome(
|
||||
func(fee lnwire.Fee) {
|
||||
inboundFee = fee
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
directedChannel := &DirectedChannel{
|
||||
ChannelID: e.ChannelID,
|
||||
IsNode1: nodePub == e.NodeKey1Bytes,
|
||||
OtherNode: e.NodeKey2Bytes,
|
||||
Capacity: e.Capacity,
|
||||
OutPolicySet: outPolicy != nil,
|
||||
InPolicy: cachedInPolicy,
|
||||
InboundFee: inboundFee,
|
||||
}
|
||||
|
||||
if nodePub == e.NodeKey2Bytes {
|
||||
directedChannel.OtherNode = e.NodeKey1Bytes
|
||||
}
|
||||
|
||||
channels[e.ChannelID] = directedChannel
|
||||
}
|
||||
|
||||
return cb(nodePub, channels)
|
||||
type nodeCachedBatchData struct {
|
||||
features map[int64][]int
|
||||
chanBatchData *batchChannelData
|
||||
chanMap map[int64][]sqlc.ListChannelsForNodeIDsRow
|
||||
}
|
||||
|
||||
return s.db.ExecTx(ctx, sqldb.ReadTxOpt(), func(db SQLQueries) error {
|
||||
return forEachNodeCacheable(
|
||||
ctx, s.cfg.QueryCfg, db,
|
||||
func(nodeID int64, nodePub route.Vertex,
|
||||
features *lnwire.FeatureVector) error {
|
||||
// pageQueryFunc is used to query the next page of nodes.
|
||||
pageQueryFunc := func(ctx context.Context, lastID int64,
|
||||
limit int32) ([]sqlc.ListNodeIDsAndPubKeysRow, error) {
|
||||
|
||||
return handleNode(db, nodeID, nodePub, features)
|
||||
return db.ListNodeIDsAndPubKeys(
|
||||
ctx, sqlc.ListNodeIDsAndPubKeysParams{
|
||||
Version: int16(ProtocolV1),
|
||||
ID: lastID,
|
||||
Limit: limit,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
// batchDataFunc is then used to batch load the data required
|
||||
// for each page of nodes.
|
||||
batchDataFunc := func(ctx context.Context,
|
||||
nodeIDs []int64) (*nodeCachedBatchData, error) {
|
||||
|
||||
// Batch load node features.
|
||||
nodeFeatures, err := batchLoadNodeFeaturesHelper(
|
||||
ctx, s.cfg.QueryCfg, db, nodeIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to batch load "+
|
||||
"node features: %w", err)
|
||||
}
|
||||
|
||||
// Batch load ALL unique channels for ALL nodes in this
|
||||
// page.
|
||||
allChannels, err := db.ListChannelsForNodeIDs(
|
||||
ctx, sqlc.ListChannelsForNodeIDsParams{
|
||||
Version: int16(ProtocolV1),
|
||||
Node1Ids: nodeIDs,
|
||||
Node2Ids: nodeIDs,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to batch "+
|
||||
"fetch channels for nodes: %w", err)
|
||||
}
|
||||
|
||||
// Deduplicate channels and collect IDs.
|
||||
var (
|
||||
allChannelIDs []int64
|
||||
allPolicyIDs []int64
|
||||
)
|
||||
uniqueChannels := make(
|
||||
map[int64]sqlc.ListChannelsForNodeIDsRow,
|
||||
)
|
||||
|
||||
for _, channel := range allChannels {
|
||||
channelID := channel.GraphChannel.ID
|
||||
|
||||
// Only process each unique channel once.
|
||||
_, exists := uniqueChannels[channelID]
|
||||
if exists {
|
||||
continue
|
||||
}
|
||||
|
||||
uniqueChannels[channelID] = channel
|
||||
allChannelIDs = append(allChannelIDs, channelID)
|
||||
|
||||
if channel.Policy1ID.Valid {
|
||||
allPolicyIDs = append(
|
||||
allPolicyIDs,
|
||||
channel.Policy1ID.Int64,
|
||||
)
|
||||
}
|
||||
if channel.Policy2ID.Valid {
|
||||
allPolicyIDs = append(
|
||||
allPolicyIDs,
|
||||
channel.Policy2ID.Int64,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Batch load channel data for all unique channels.
|
||||
channelBatchData, err := batchLoadChannelData(
|
||||
ctx, s.cfg.QueryCfg, db, allChannelIDs,
|
||||
allPolicyIDs,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to batch "+
|
||||
"load channel data: %w", err)
|
||||
}
|
||||
|
||||
// Create map of node ID to channels that involve this
|
||||
// node.
|
||||
nodeIDSet := make(map[int64]bool)
|
||||
for _, nodeID := range nodeIDs {
|
||||
nodeIDSet[nodeID] = true
|
||||
}
|
||||
|
||||
nodeChannelMap := make(
|
||||
map[int64][]sqlc.ListChannelsForNodeIDsRow,
|
||||
)
|
||||
for _, channel := range uniqueChannels {
|
||||
// Add channel to both nodes if they're in our
|
||||
// current page.
|
||||
node1 := channel.GraphChannel.NodeID1
|
||||
if nodeIDSet[node1] {
|
||||
nodeChannelMap[node1] = append(
|
||||
nodeChannelMap[node1], channel,
|
||||
)
|
||||
}
|
||||
node2 := channel.GraphChannel.NodeID2
|
||||
if nodeIDSet[node2] {
|
||||
nodeChannelMap[node2] = append(
|
||||
nodeChannelMap[node2], channel,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return &nodeCachedBatchData{
|
||||
features: nodeFeatures,
|
||||
chanBatchData: channelBatchData,
|
||||
chanMap: nodeChannelMap,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// processItem is used to process each node in the current page.
|
||||
processItem := func(ctx context.Context,
|
||||
nodeData sqlc.ListNodeIDsAndPubKeysRow,
|
||||
batchData *nodeCachedBatchData) error {
|
||||
|
||||
// Build feature vector for this node.
|
||||
fv := lnwire.EmptyFeatureVector()
|
||||
features, exists := batchData.features[nodeData.ID]
|
||||
if exists {
|
||||
for _, bit := range features {
|
||||
fv.Set(lnwire.FeatureBit(bit))
|
||||
}
|
||||
}
|
||||
|
||||
var nodePub route.Vertex
|
||||
copy(nodePub[:], nodeData.PubKey)
|
||||
|
||||
nodeChannels := batchData.chanMap[nodeData.ID]
|
||||
|
||||
toNodeCallback := func() route.Vertex {
|
||||
return nodePub
|
||||
}
|
||||
|
||||
// Build cached channels map for this node.
|
||||
channels := make(map[uint64]*DirectedChannel)
|
||||
for _, channelRow := range nodeChannels {
|
||||
directedChan, err := buildDirectedChannel(
|
||||
s.cfg.ChainHash, nodeData.ID, nodePub,
|
||||
channelRow, batchData.chanBatchData, fv,
|
||||
toNodeCallback,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
channels[directedChan.ChannelID] = directedChan
|
||||
}
|
||||
|
||||
return cb(nodePub, channels)
|
||||
}
|
||||
|
||||
return sqldb.ExecuteCollectAndBatchWithSharedDataQuery(
|
||||
ctx, s.cfg.QueryCfg, int64(-1), pageQueryFunc,
|
||||
func(node sqlc.ListNodeIDsAndPubKeysRow) int64 {
|
||||
return node.ID
|
||||
},
|
||||
func(node sqlc.ListNodeIDsAndPubKeysRow) (int64,
|
||||
error) {
|
||||
|
||||
return node.ID, nil
|
||||
},
|
||||
batchDataFunc, processItem,
|
||||
)
|
||||
}, reset)
|
||||
}
|
||||
@@ -4411,6 +4485,50 @@ func extractChannelPolicies(row any) (*sqlc.GraphChannelPolicy,
|
||||
|
||||
return policy1, policy2, nil
|
||||
|
||||
case sqlc.ListChannelsForNodeIDsRow:
|
||||
if r.Policy1ID.Valid {
|
||||
policy1 = &sqlc.GraphChannelPolicy{
|
||||
ID: r.Policy1ID.Int64,
|
||||
Version: r.Policy1Version.Int16,
|
||||
ChannelID: r.GraphChannel.ID,
|
||||
NodeID: r.Policy1NodeID.Int64,
|
||||
Timelock: r.Policy1Timelock.Int32,
|
||||
FeePpm: r.Policy1FeePpm.Int64,
|
||||
BaseFeeMsat: r.Policy1BaseFeeMsat.Int64,
|
||||
MinHtlcMsat: r.Policy1MinHtlcMsat.Int64,
|
||||
MaxHtlcMsat: r.Policy1MaxHtlcMsat,
|
||||
LastUpdate: r.Policy1LastUpdate,
|
||||
InboundBaseFeeMsat: r.Policy1InboundBaseFeeMsat,
|
||||
InboundFeeRateMilliMsat: r.Policy1InboundFeeRateMilliMsat,
|
||||
Disabled: r.Policy1Disabled,
|
||||
MessageFlags: r.Policy1MessageFlags,
|
||||
ChannelFlags: r.Policy1ChannelFlags,
|
||||
Signature: r.Policy1Signature,
|
||||
}
|
||||
}
|
||||
if r.Policy2ID.Valid {
|
||||
policy2 = &sqlc.GraphChannelPolicy{
|
||||
ID: r.Policy2ID.Int64,
|
||||
Version: r.Policy2Version.Int16,
|
||||
ChannelID: r.GraphChannel.ID,
|
||||
NodeID: r.Policy2NodeID.Int64,
|
||||
Timelock: r.Policy2Timelock.Int32,
|
||||
FeePpm: r.Policy2FeePpm.Int64,
|
||||
BaseFeeMsat: r.Policy2BaseFeeMsat.Int64,
|
||||
MinHtlcMsat: r.Policy2MinHtlcMsat.Int64,
|
||||
MaxHtlcMsat: r.Policy2MaxHtlcMsat,
|
||||
LastUpdate: r.Policy2LastUpdate,
|
||||
InboundBaseFeeMsat: r.Policy2InboundBaseFeeMsat,
|
||||
InboundFeeRateMilliMsat: r.Policy2InboundFeeRateMilliMsat,
|
||||
Disabled: r.Policy2Disabled,
|
||||
MessageFlags: r.Policy2MessageFlags,
|
||||
ChannelFlags: r.Policy2ChannelFlags,
|
||||
Signature: r.Policy2Signature,
|
||||
}
|
||||
}
|
||||
|
||||
return policy1, policy2, nil
|
||||
|
||||
case sqlc.ListChannelsByNodeIDRow:
|
||||
if r.Policy1ID.Valid {
|
||||
policy1 = &sqlc.GraphChannelPolicy{
|
||||
@@ -5118,3 +5236,83 @@ func forEachChannelWithPolicies(ctx context.Context, db SQLQueries,
|
||||
collectFunc, batchDataFunc, processItem,
|
||||
)
|
||||
}
|
||||
|
||||
// buildDirectedChannel builds a DirectedChannel instance from the provided
|
||||
// data.
|
||||
func buildDirectedChannel(chain chainhash.Hash, nodeID int64,
|
||||
nodePub route.Vertex, channelRow sqlc.ListChannelsForNodeIDsRow,
|
||||
channelBatchData *batchChannelData, features *lnwire.FeatureVector,
|
||||
toNodeCallback func() route.Vertex) (*DirectedChannel, error) {
|
||||
|
||||
node1, node2, err := buildNodeVertices(
|
||||
channelRow.Node1Pubkey, channelRow.Node2Pubkey,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build node vertices: %w", err)
|
||||
}
|
||||
|
||||
edge, err := buildEdgeInfoWithBatchData(
|
||||
chain, channelRow.GraphChannel, node1, node2, channelBatchData,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build channel info: %w", err)
|
||||
}
|
||||
|
||||
dbPol1, dbPol2, err := extractChannelPolicies(channelRow)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to extract channel policies: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
p1, p2, err := buildChanPoliciesWithBatchData(
|
||||
dbPol1, dbPol2, edge.ChannelID, node1, node2,
|
||||
channelBatchData,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to build channel policies: %w",
|
||||
err)
|
||||
}
|
||||
|
||||
// Determine outgoing and incoming policy for this specific node.
|
||||
p1ToNode := channelRow.GraphChannel.NodeID2
|
||||
p2ToNode := channelRow.GraphChannel.NodeID1
|
||||
outPolicy, inPolicy := p1, p2
|
||||
if (p1 != nil && p1ToNode == nodeID) ||
|
||||
(p2 != nil && p2ToNode != nodeID) {
|
||||
|
||||
outPolicy, inPolicy = p2, p1
|
||||
}
|
||||
|
||||
// Build cached policy.
|
||||
var cachedInPolicy *models.CachedEdgePolicy
|
||||
if inPolicy != nil {
|
||||
cachedInPolicy = models.NewCachedPolicy(inPolicy)
|
||||
cachedInPolicy.ToNodePubKey = toNodeCallback
|
||||
cachedInPolicy.ToNodeFeatures = features
|
||||
}
|
||||
|
||||
// Extract inbound fee.
|
||||
var inboundFee lnwire.Fee
|
||||
if outPolicy != nil {
|
||||
outPolicy.InboundFee.WhenSome(func(fee lnwire.Fee) {
|
||||
inboundFee = fee
|
||||
})
|
||||
}
|
||||
|
||||
// Build directed channel.
|
||||
directedChannel := &DirectedChannel{
|
||||
ChannelID: edge.ChannelID,
|
||||
IsNode1: nodePub == edge.NodeKey1Bytes,
|
||||
OtherNode: edge.NodeKey2Bytes,
|
||||
Capacity: edge.Capacity,
|
||||
OutPolicySet: outPolicy != nil,
|
||||
InPolicy: cachedInPolicy,
|
||||
InboundFee: inboundFee,
|
||||
}
|
||||
|
||||
if nodePub == edge.NodeKey2Bytes {
|
||||
directedChannel.OtherNode = edge.NodeKey1Bytes
|
||||
}
|
||||
|
||||
return directedChannel, nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user