diff --git a/channeldb/graph.go b/channeldb/graph.go index 452d86ecf..bb7774895 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -2357,7 +2357,11 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // skipped and the result will contain only those edges that exist at the time // of the query. This can be used to respond to peer queries that are seeking to // fill in gaps in their view of the channel graph. -func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { +// +// NOTE: An optional transaction may be provided. If none is provided, then a +// new one will be created. +func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( + []ChannelEdge, error) { // TODO(roasbeef): sort cids? var ( @@ -2365,7 +2369,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { cidBytes [8]byte ) - err := kvdb.View(c.db, func(tx kvdb.RTx) error { + fetchChanInfos := func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2427,9 +2431,20 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { }) } return nil - }, func() { - chanEdges = nil - }) + } + + if tx == nil { + err := kvdb.View(c.db, fetchChanInfos, func() { + chanEdges = nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil + } + + err := fetchChanInfos(tx) if err != nil { return nil, err } @@ -3673,7 +3688,7 @@ func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error { // We need to add the channel back into our graph cache, otherwise we // won't use it for path finding. if c.graphCache != nil { - edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID}) if err != nil { return err } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index a37b84a01..8fe90c545 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2685,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) { // We'll now attempt to query for the range of channel ID's we just // inserted into the database. We should get the exact same set of // edges back. - resp, err := graph.FetchChanInfos(edgeQuery) + resp, err := graph.FetchChanInfos(nil, edgeQuery) require.NoError(t, err, "unable to fetch chan edges") if len(resp) != len(edges) { t.Fatalf("expected %v edges, instead got %v", len(edges), diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 34e6d4a9d..bd6571b87 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, chanIDs = append(chanIDs, chanID.ToUint64()) } - channels, err := c.graph.FetchChanInfos(chanIDs) + channels, err := c.graph.FetchChanInfos(nil, chanIDs) if err != nil { return nil, err } diff --git a/routing/router.go b/routing/router.go index 3465ec0b5..c602573eb 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1008,20 +1008,23 @@ func (r *ChannelRouter) pruneZombieChans() error { if r.cfg.AssumeChannelValid { disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs() if err != nil { - return fmt.Errorf("unable to get disabled channels ids "+ - "chans: %v", err) + return fmt.Errorf("unable to get disabled channels "+ + "ids chans: %v", err) } - disabledEdges, err := r.cfg.Graph.FetchChanInfos(disabledChanIDs) + disabledEdges, err := r.cfg.Graph.FetchChanInfos( + nil, disabledChanIDs, + ) if err != nil { - return fmt.Errorf("unable to fetch disabled channels edges "+ - "chans: %v", err) + return fmt.Errorf("unable to fetch disabled channels "+ + "edges chans: %v", err) } // Ensuring we won't prune our own channel from the graph. for _, disabledEdge := range disabledEdges { if !isSelfChannelEdge(disabledEdge.Info) { - chansToPrune[disabledEdge.Info.ChannelID] = struct{}{} + chansToPrune[disabledEdge.Info.ChannelID] = + struct{}{} } } }