multi: update FetchChanInfos to take in an optional tx

In this commit, the FetchChanInfos ChannelGraph method is updated to
take in an optional read transaction for the case where it is called
from within another transaction.
This commit is contained in:
Elle Mouton
2024-01-22 09:44:25 +02:00
parent 8cf4044215
commit 6c427a6ba9
4 changed files with 32 additions and 14 deletions

View File

@@ -2357,7 +2357,11 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
// skipped and the result will contain only those edges that exist at the time // skipped and the result will contain only those edges that exist at the time
// of the query. This can be used to respond to peer queries that are seeking to // of the query. This can be used to respond to peer queries that are seeking to
// fill in gaps in their view of the channel graph. // fill in gaps in their view of the channel graph.
func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { //
// NOTE: An optional transaction may be provided. If none is provided, then a
// new one will be created.
func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) (
[]ChannelEdge, error) {
// TODO(roasbeef): sort cids? // TODO(roasbeef): sort cids?
var ( var (
@@ -2365,7 +2369,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
cidBytes [8]byte cidBytes [8]byte
) )
err := kvdb.View(c.db, func(tx kvdb.RTx) error { fetchChanInfos := func(tx kvdb.RTx) error {
edges := tx.ReadBucket(edgeBucket) edges := tx.ReadBucket(edgeBucket)
if edges == nil { if edges == nil {
return ErrGraphNoEdgesFound return ErrGraphNoEdgesFound
@@ -2427,9 +2431,20 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) {
}) })
} }
return nil return nil
}, func() { }
chanEdges = nil
}) if tx == nil {
err := kvdb.View(c.db, fetchChanInfos, func() {
chanEdges = nil
})
if err != nil {
return nil, err
}
return chanEdges, nil
}
err := fetchChanInfos(tx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -3673,7 +3688,7 @@ func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error {
// We need to add the channel back into our graph cache, otherwise we // We need to add the channel back into our graph cache, otherwise we
// won't use it for path finding. // won't use it for path finding.
if c.graphCache != nil { if c.graphCache != nil {
edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID})
if err != nil { if err != nil {
return err return err
} }

View File

@@ -2685,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) {
// We'll now attempt to query for the range of channel ID's we just // We'll now attempt to query for the range of channel ID's we just
// inserted into the database. We should get the exact same set of // inserted into the database. We should get the exact same set of
// edges back. // edges back.
resp, err := graph.FetchChanInfos(edgeQuery) resp, err := graph.FetchChanInfos(nil, edgeQuery)
require.NoError(t, err, "unable to fetch chan edges") require.NoError(t, err, "unable to fetch chan edges")
if len(resp) != len(edges) { if len(resp) != len(edges) {
t.Fatalf("expected %v edges, instead got %v", len(edges), t.Fatalf("expected %v edges, instead got %v", len(edges),

View File

@@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash,
chanIDs = append(chanIDs, chanID.ToUint64()) chanIDs = append(chanIDs, chanID.ToUint64())
} }
channels, err := c.graph.FetchChanInfos(chanIDs) channels, err := c.graph.FetchChanInfos(nil, chanIDs)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1008,20 +1008,23 @@ func (r *ChannelRouter) pruneZombieChans() error {
if r.cfg.AssumeChannelValid { if r.cfg.AssumeChannelValid {
disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs() disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs()
if err != nil { if err != nil {
return fmt.Errorf("unable to get disabled channels ids "+ return fmt.Errorf("unable to get disabled channels "+
"chans: %v", err) "ids chans: %v", err)
} }
disabledEdges, err := r.cfg.Graph.FetchChanInfos(disabledChanIDs) disabledEdges, err := r.cfg.Graph.FetchChanInfos(
nil, disabledChanIDs,
)
if err != nil { if err != nil {
return fmt.Errorf("unable to fetch disabled channels edges "+ return fmt.Errorf("unable to fetch disabled channels "+
"chans: %v", err) "edges chans: %v", err)
} }
// Ensuring we won't prune our own channel from the graph. // Ensuring we won't prune our own channel from the graph.
for _, disabledEdge := range disabledEdges { for _, disabledEdge := range disabledEdges {
if !isSelfChannelEdge(disabledEdge.Info) { if !isSelfChannelEdge(disabledEdge.Info) {
chansToPrune[disabledEdge.Info.ChannelID] = struct{}{} chansToPrune[disabledEdge.Info.ChannelID] =
struct{}{}
} }
} }
} }