diff --git a/channeldb/graph.go b/channeldb/graph.go index 7aa3544b8..5e57e4cd2 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -232,6 +232,71 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, return g, nil } +// channelMapKey is the key structure used for storing channel edge policies. +type channelMapKey struct { + nodeKey route.Vertex + chanID [8]byte +} + +// getChannelMap loads all channel edge policies from the database and stores +// them in a map. +func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( + map[channelMapKey]*ChannelEdgePolicy, error) { + + // Create a map to store all channel edge policies. + channelMap := make(map[channelMapKey]*ChannelEdgePolicy) + + err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error { + // Skip embedded buckets. + if bytes.Equal(k, edgeIndexBucket) || + bytes.Equal(k, edgeUpdateIndexBucket) || + bytes.Equal(k, zombieBucket) || + bytes.Equal(k, disabledEdgePolicyBucket) || + bytes.Equal(k, channelPointBucket) { + + return nil + } + + // Validate key length. + if len(k) != 33+8 { + return fmt.Errorf("invalid edge key %x encountered", k) + } + + var key channelMapKey + copy(key.nodeKey[:], k[:33]) + copy(key.chanID[:], k[33:]) + + // No need to deserialize unknown policy. + if bytes.Equal(edgeBytes, unknownPolicy) { + return nil + } + + edgeReader := bytes.NewReader(edgeBytes) + edge, err := deserializeChanEdgePolicyRaw( + edgeReader, + ) + + switch { + // If the db policy was missing an expected optional field, we + // return nil as if the policy was unknown. + case err == ErrEdgePolicyOptionalFieldNotFound: + return nil + + case err != nil: + return err + } + + channelMap[key] = edge + + return nil + }) + if err != nil { + return nil, err + } + + return channelMap, nil +} + var graphTopLevelBuckets = [][]byte{ nodeBucket, edgeBucket, @@ -332,50 +397,47 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { func (c *ChannelGraph) ForEachChannel(cb func(*ChannelEdgeInfo, *ChannelEdgePolicy, *ChannelEdgePolicy) error) error { - // TODO(roasbeef): ptr map to reduce # of allocs? no duplicates - - return kvdb.View(c.db, func(tx kvdb.RTx) error { - // First, grab the node bucket. This will be used to populate - // the Node pointers in each edge read from disk. - nodes := tx.ReadBucket(nodeBucket) - if nodes == nil { - return ErrGraphNotFound - } - - // Next, grab the edge bucket which stores the edges, and also - // the index itself so we can group the directed edges together - // logically. + return c.db.View(func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound } + + // First, load all edges in memory indexed by node and channel + // id. + channelMap, err := c.getChannelMap(edges) + if err != nil { + return err + } + edgeIndex := edges.NestedReadBucket(edgeIndexBucket) if edgeIndex == nil { return ErrGraphNoEdgesFound } - // For each edge pair within the edge index, we fetch each edge - // itself and also the node information in order to fully - // populated the object. - return edgeIndex.ForEach(func(chanID, edgeInfoBytes []byte) error { - infoReader := bytes.NewReader(edgeInfoBytes) - edgeInfo, err := deserializeChanEdgeInfo(infoReader) - if err != nil { - return err - } - edgeInfo.db = c.db + // Load edge index, recombine each channel with the policies + // loaded above and invoke the callback. + return kvdb.ForAll(edgeIndex, func(k, edgeInfoBytes []byte) error { + var chanID [8]byte + copy(chanID[:], k) - edge1, edge2, err := fetchChanEdgePolicies( - edgeIndex, edges, nodes, chanID, c.db, - ) + edgeInfoReader := bytes.NewReader(edgeInfoBytes) + info, err := deserializeChanEdgeInfo(edgeInfoReader) if err != nil { return err } - // With both edges read, execute the call back. IF this - // function returns an error then the transaction will - // be aborted. - return cb(&edgeInfo, edge1, edge2) + policy1 := channelMap[channelMapKey{ + nodeKey: info.NodeKey1Bytes, + chanID: chanID, + }] + + policy2 := channelMap[channelMapKey{ + nodeKey: info.NodeKey2Bytes, + chanID: chanID, + }] + + return cb(&info, policy1, policy2) }) }, func() {}) }