diff --git a/channeldb/graph.go b/channeldb/graph.go index 1ab2897bb..bb7774895 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -176,6 +176,9 @@ const ( type ChannelGraph struct { db kvdb.Backend + // cacheMu guards all caches (rejectCache, chanCache, graphCache). If + // this mutex will be acquired at the same time as the DB mutex then + // the cacheMu MUST be acquired first to prevent deadlock. cacheMu sync.RWMutex rejectCache *rejectCache chanCache *channelCache @@ -1331,8 +1334,8 @@ func (c *ChannelGraph) PruneGraph(spentOutputs []*wire.OutPoint, // will be returned if that outpoint isn't known to be // a channel. If no error is returned, then a channel // was successfully pruned. - err = c.delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, + err = c.delChannelEdgeUnsafe( + edges, edgeIndex, chanIndex, zombieIndex, chanID, false, false, ) if err != nil && err != ErrEdgeNotFound { @@ -1562,10 +1565,6 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( if err != nil { return err } - nodes, err := tx.CreateTopLevelBucket(nodeBucket) - if err != nil { - return err - } // Scan from chanIDStart to chanIDEnd, deleting every // found edge. @@ -1590,8 +1589,8 @@ func (c *ChannelGraph) DisconnectBlockAtHeight(height uint32) ( } for _, k := range keys { - err = c.delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, + err = c.delChannelEdgeUnsafe( + edges, edgeIndex, chanIndex, zombieIndex, k, false, false, ) if err != nil && err != ErrEdgeNotFound { @@ -1734,8 +1733,8 @@ func (c *ChannelGraph) DeleteChannelEdges(strictZombiePruning, markZombie bool, var rawChanID [8]byte for _, chanID := range chanIDs { byteOrder.PutUint64(rawChanID[:], chanID) - err := c.delChannelEdge( - edges, edgeIndex, chanIndex, zombieIndex, nodes, + err := c.delChannelEdgeUnsafe( + edges, edgeIndex, chanIndex, zombieIndex, rawChanID[:], markZombie, strictZombiePruning, ) if err != nil { @@ -2091,6 +2090,9 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, var newChanIDs []uint64 + c.cacheMu.Lock() + defer c.cacheMu.Unlock() + err := kvdb.Update(c.db, func(tx kvdb.RwTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { @@ -2143,7 +2145,7 @@ func (c *ChannelGraph) FilterKnownChanIDs(chansInfo []ChannelUpdateInfo, // and we let it be added to the set of IDs to // query our peer for. case isZombie && !isStillZombie: - err := c.markEdgeLive(tx, scid) + err := c.markEdgeLiveUnsafe(tx, scid) if err != nil { return err } @@ -2355,7 +2357,11 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, // skipped and the result will contain only those edges that exist at the time // of the query. This can be used to respond to peer queries that are seeking to // fill in gaps in their view of the channel graph. -func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { +// +// NOTE: An optional transaction may be provided. If none is provided, then a +// new one will be created. +func (c *ChannelGraph) FetchChanInfos(tx kvdb.RTx, chanIDs []uint64) ( + []ChannelEdge, error) { // TODO(roasbeef): sort cids? 
var ( @@ -2363,7 +2369,7 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { cidBytes [8]byte ) - err := kvdb.View(c.db, func(tx kvdb.RTx) error { + fetchChanInfos := func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) if edges == nil { return ErrGraphNoEdgesFound @@ -2425,9 +2431,20 @@ func (c *ChannelGraph) FetchChanInfos(chanIDs []uint64) ([]ChannelEdge, error) { }) } return nil - }, func() { - chanEdges = nil - }) + } + + if tx == nil { + err := kvdb.View(c.db, fetchChanInfos, func() { + chanEdges = nil + }) + if err != nil { + return nil, err + } + + return chanEdges, nil + } + + err := fetchChanInfos(tx) if err != nil { return nil, err } @@ -2473,8 +2490,16 @@ func delEdgeUpdateIndexEntry(edgesBucket kvdb.RwBucket, chanID uint64, return nil } -func (c *ChannelGraph) delChannelEdge(edges, edgeIndex, chanIndex, zombieIndex, - nodes kvdb.RwBucket, chanID []byte, isZombie, strictZombie bool) error { +// delChannelEdgeUnsafe deletes the edge with the given chanID from the graph +// cache. It then goes on to delete any policy info and edge info for this +// channel from the DB and finally, if isZombie is true, it will add an entry +// for this channel in the zombie index. +// +// NOTE: this method MUST only be called if the cacheMu has already been +// acquired. +func (c *ChannelGraph) delChannelEdgeUnsafe(edges, edgeIndex, chanIndex, + zombieIndex kvdb.RwBucket, chanID []byte, isZombie, + strictZombie bool) error { edgeInfo, err := fetchChanEdgeInfo(edgeIndex, chanID) if err != nil { @@ -3612,16 +3637,19 @@ func markEdgeZombie(zombieIndex kvdb.RwBucket, chanID uint64, pubKey1, // MarkEdgeLive clears an edge from our zombie index, deeming it as live. func (c *ChannelGraph) MarkEdgeLive(chanID uint64) error { - return c.markEdgeLive(nil, chanID) -} - -// markEdgeLive clears an edge from the zombie index. This method can be called -// with an existing kvdb.RwTx or the argument can be set to nil in which case a -// new transaction will be created. -func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error { c.cacheMu.Lock() defer c.cacheMu.Unlock() + return c.markEdgeLiveUnsafe(nil, chanID) +} + +// markEdgeLiveUnsafe clears an edge from the zombie index. This method can be +// called with an existing kvdb.RwTx or the argument can be set to nil in which +// case a new transaction will be created. +// +// NOTE: this method MUST only be called if the cacheMu has already been +// acquired. +func (c *ChannelGraph) markEdgeLiveUnsafe(tx kvdb.RwTx, chanID uint64) error { dbFn := func(tx kvdb.RwTx) error { edges := tx.ReadWriteBucket(edgeBucket) if edges == nil { @@ -3660,7 +3688,7 @@ func (c *ChannelGraph) markEdgeLive(tx kvdb.RwTx, chanID uint64) error { // We need to add the channel back into our graph cache, otherwise we // won't use it for path finding. if c.graphCache != nil { - edgeInfos, err := c.FetchChanInfos([]uint64{chanID}) + edgeInfos, err := c.FetchChanInfos(tx, []uint64{chanID}) if err != nil { return err } diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index cbfabb3cc..8fe90c545 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2090,6 +2090,296 @@ func TestFilterKnownChanIDs(t *testing.T) { } } +// TestStressTestChannelGraphAPI is a stress test that concurrently calls some +// of the ChannelGraph methods in various orders in order to ensure that no +// deadlock can occur. This test currently focuses on stress testing all the +// methods that acquire the cache mutex along with the DB mutex. 
+func TestStressTestChannelGraphAPI(t *testing.T) {
+	t.Parallel()
+
+	graph, err := MakeTestGraph(t)
+	require.NoError(t, err)
+
+	node1, err := createTestVertex(graph.db)
+	require.NoError(t, err, "unable to create test node")
+	require.NoError(t, graph.AddLightningNode(node1))
+
+	node2, err := createTestVertex(graph.db)
+	require.NoError(t, err, "unable to create test node")
+	require.NoError(t, graph.AddLightningNode(node2))
+
+	err = graph.SetSourceNode(node1)
+	require.NoError(t, err)
+
+	type chanInfo struct {
+		info models.ChannelEdgeInfo
+		id   lnwire.ShortChannelID
+	}
+
+	var (
+		chans []*chanInfo
+		mu    sync.RWMutex
+	)
+
+	// newBlockHeight returns a random block height between 0 and 100.
+	newBlockHeight := func() uint32 {
+		return uint32(rand.Int31n(100))
+	}
+
+	// addNewChan creates and returns a new random channel and adds it to
+	// the set of channels.
+	addNewChan := func() *chanInfo {
+		mu.Lock()
+		defer mu.Unlock()
+
+		channel, chanID := createEdge(
+			newBlockHeight(), rand.Uint32(), uint16(rand.Int()),
+			rand.Uint32(), node1, node2,
+		)
+
+		newChan := &chanInfo{
+			info: channel,
+			id:   chanID,
+		}
+		chans = append(chans, newChan)
+
+		return newChan
+	}
+
+	// getRandChan picks a random channel from the set and returns it.
+	getRandChan := func() *chanInfo {
+		mu.RLock()
+		defer mu.RUnlock()
+
+		if len(chans) == 0 {
+			return nil
+		}
+
+		return chans[rand.Intn(len(chans))]
+	}
+
+	// getRandChanSet returns a random set of channels.
+	getRandChanSet := func() []*chanInfo {
+		mu.RLock()
+		defer mu.RUnlock()
+
+		if len(chans) == 0 {
+			return nil
+		}
+
+		start := rand.Intn(len(chans))
+		end := rand.Intn(len(chans))
+
+		if end < start {
+			start, end = end, start
+		}
+
+		var infoCopy []*chanInfo
+		for i := start; i < end; i++ {
+			infoCopy = append(infoCopy, &chanInfo{
+				info: chans[i].info,
+				id:   chans[i].id,
+			})
+		}
+
+		return infoCopy
+	}
+
+	// delChan deletes the channel with the given ID from the set if it
+	// exists.
+	delChan := func(id lnwire.ShortChannelID) {
+		mu.Lock()
+		defer mu.Unlock()
+
+		index := -1
+		for i, c := range chans {
+			if c.id == id {
+				index = i
+				break
+			}
+		}
+
+		if index == -1 {
+			return
+		}
+
+		chans = append(chans[:index], chans[index+1:]...)
+ } + + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) + + var methodsMu sync.Mutex + methods := []struct { + name string + fn func() error + }{ + { + name: "MarkEdgeZombie", + fn: func() error { + channel := getRandChan() + if channel == nil { + return nil + } + + return graph.MarkEdgeZombie( + channel.id.ToUint64(), + node1.PubKeyBytes, + node2.PubKeyBytes, + ) + }, + }, + { + name: "FilterKnownChanIDs", + fn: func() error { + chanSet := getRandChanSet() + var chanIDs []ChannelUpdateInfo + + for _, c := range chanSet { + chanIDs = append( + chanIDs, + ChannelUpdateInfo{ + ShortChannelID: c.id, + }, + ) + } + + _, err := graph.FilterKnownChanIDs( + chanIDs, + func(t time.Time, t2 time.Time) bool { + return rand.Intn(2) == 0 + }, + ) + + return err + }, + }, + { + name: "HasChannelEdge", + fn: func() error { + channel := getRandChan() + if channel == nil { + return nil + } + + _, _, _, _, err := graph.HasChannelEdge( + channel.id.ToUint64(), + ) + + return err + }, + }, + { + name: "PruneGraph", + fn: func() error { + chanSet := getRandChanSet() + var spentOutpoints []*wire.OutPoint + + for _, c := range chanSet { + spentOutpoints = append( + spentOutpoints, + &c.info.ChannelPoint, + ) + } + + _, err := graph.PruneGraph( + spentOutpoints, &blockHash, 100, + ) + + return err + }, + }, + { + name: "ChanUpdateInHorizon", + fn: func() error { + _, err := graph.ChanUpdatesInHorizon( + time.Now().Add(-time.Hour), time.Now(), + ) + + return err + }, + }, + { + name: "DeleteChannelEdges", + fn: func() error { + var ( + strictPruning = rand.Intn(2) == 0 + markZombie = rand.Intn(2) == 0 + channels = getRandChanSet() + chanIDs []uint64 + ) + + for _, c := range channels { + chanIDs = append( + chanIDs, c.id.ToUint64(), + ) + delChan(c.id) + } + + err := graph.DeleteChannelEdges( + strictPruning, markZombie, chanIDs..., + ) + if err != nil && + !errors.Is(err, ErrEdgeNotFound) { + + return err + } + + return nil + }, + }, + { + name: "DisconnectBlockAtHeight", + fn: func() error { + _, err := graph.DisconnectBlockAtHeight( + newBlockHeight(), + ) + + return err + }, + }, + { + name: "AddChannelEdge", + fn: func() error { + channel := addNewChan() + + return graph.AddChannelEdge(&channel.info) + }, + }, + } + + const ( + // concurrencyLevel is the number of concurrent goroutines that + // will be run simultaneously. + concurrencyLevel = 10 + + // executionCount is the number of methods that will be called + // per goroutine. + executionCount = 100 + ) + + for i := 0; i < concurrencyLevel; i++ { + i := i + + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() + + for j := 0; j < executionCount; j++ { + // Randomly select a method to execute. + methodIndex := rand.Intn(len(methods)) + + methodsMu.Lock() + fn := methods[methodIndex].fn + name := methods[methodIndex].name + methodsMu.Unlock() + + err := fn() + require.NoErrorf(t, err, fmt.Sprintf(name)) + } + }) + } +} + // TestFilterChannelRange tests that we're able to properly retrieve the full // set of short channel ID's for a given block range. func TestFilterChannelRange(t *testing.T) { @@ -2395,7 +2685,7 @@ func TestFetchChanInfos(t *testing.T) { // We'll now attempt to query for the range of channel ID's we just // inserted into the database. We should get the exact same set of // edges back. 
- resp, err := graph.FetchChanInfos(edgeQuery) + resp, err := graph.FetchChanInfos(nil, edgeQuery) require.NoError(t, err, "unable to fetch chan edges") if len(resp) != len(edges) { t.Fatalf("expected %v edges, instead got %v", len(edges), diff --git a/discovery/chan_series.go b/discovery/chan_series.go index 34e6d4a9d..bd6571b87 100644 --- a/discovery/chan_series.go +++ b/discovery/chan_series.go @@ -249,7 +249,7 @@ func (c *ChanSeries) FetchChanAnns(chain chainhash.Hash, chanIDs = append(chanIDs, chanID.ToUint64()) } - channels, err := c.graph.FetchChanInfos(chanIDs) + channels, err := c.graph.FetchChanInfos(nil, chanIDs) if err != nil { return nil, err } diff --git a/docs/release-notes/release-notes-0.17.4.md b/docs/release-notes/release-notes-0.17.4.md index fdd04ae81..c527576b7 100644 --- a/docs/release-notes/release-notes-0.17.4.md +++ b/docs/release-notes/release-notes-0.17.4.md @@ -24,6 +24,10 @@ channel opening was pruned from memory no more channels were able to be created nor accepted. This PR fixes this issue and enhances the test suite for this behavior. + +* [Fix deadlock possibility in + FilterKnownChanIDs](https://github.com/lightningnetwork/lnd/pull/8400) by + ensuring the `cacheMu` mutex is acquired before the main database lock. # New Features ## Functional Enhancements @@ -46,3 +50,5 @@ ## Tooling and Documentation # Contributors (Alphabetical Order) +* Elle Mouton +* ziggie1984 diff --git a/routing/router.go b/routing/router.go index 3465ec0b5..c602573eb 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1008,20 +1008,23 @@ func (r *ChannelRouter) pruneZombieChans() error { if r.cfg.AssumeChannelValid { disabledChanIDs, err := r.cfg.Graph.DisabledChannelIDs() if err != nil { - return fmt.Errorf("unable to get disabled channels ids "+ - "chans: %v", err) + return fmt.Errorf("unable to get disabled channels "+ + "ids chans: %v", err) } - disabledEdges, err := r.cfg.Graph.FetchChanInfos(disabledChanIDs) + disabledEdges, err := r.cfg.Graph.FetchChanInfos( + nil, disabledChanIDs, + ) if err != nil { - return fmt.Errorf("unable to fetch disabled channels edges "+ - "chans: %v", err) + return fmt.Errorf("unable to fetch disabled channels "+ + "edges chans: %v", err) } // Ensuring we won't prune our own channel from the graph. for _, disabledEdge := range disabledEdges { if !isSelfChannelEdge(disabledEdge.Info) { - chansToPrune[disabledEdge.Info.ChannelID] = struct{}{} + chansToPrune[disabledEdge.Info.ChannelID] = + struct{}{} } } }
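Editor's note: the heart of this patch is the lock-ordering rule documented on `cacheMu` (take the cache mutex before the DB mutex, never after). The following is a minimal, self-contained sketch of that rule, not lnd's actual code; the `graph`, `markLive`, and `filterKnown` names and the map-based "cache"/"DB" are purely illustrative stand-ins for `ChannelGraph`, `MarkEdgeLive`, and `FilterKnownChanIDs`.

```go
package main

import (
	"fmt"
	"sync"
)

// graph mirrors the shape of the type touched by this patch: a cache guarded
// by cacheMu and a backing store with its own lock. The rule documented in
// the diff is that cacheMu is always taken before the store lock.
type graph struct {
	cacheMu sync.Mutex
	cache   map[uint64]bool

	dbMu sync.Mutex
	db   map[uint64]bool
}

// markLive follows the required ordering: cacheMu first, then the DB lock.
func (g *graph) markLive(id uint64) {
	g.cacheMu.Lock()
	defer g.cacheMu.Unlock()

	g.dbMu.Lock()
	defer g.dbMu.Unlock()

	delete(g.db, id)
	g.cache[id] = true
}

// filterKnown also takes cacheMu before touching the DB, matching the fix to
// FilterKnownChanIDs. If it locked dbMu first while markLive locked cacheMu
// first, two concurrent callers could each hold one lock and wait forever on
// the other, which is the deadlock this PR removes.
func (g *graph) filterKnown(ids []uint64) []uint64 {
	g.cacheMu.Lock()
	defer g.cacheMu.Unlock()

	g.dbMu.Lock()
	defer g.dbMu.Unlock()

	var unknown []uint64
	for _, id := range ids {
		if !g.cache[id] && !g.db[id] {
			unknown = append(unknown, id)
		}
	}

	return unknown
}

func main() {
	g := &graph{cache: map[uint64]bool{}, db: map[uint64]bool{1: true}}
	g.markLive(1)
	fmt.Println(g.filterKnown([]uint64{1, 2})) // [2]
}
```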
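The other recurring pattern in the diff is the optional read transaction added to `FetchChanInfos(tx kvdb.RTx, ...)`: callers that already hold a transaction (such as `markEdgeLiveUnsafe`) pass it in, while everyone else passes nil and a fresh transaction is opened internally. Below is a rough sketch of that pattern under assumed, simplified types; `tx`, `store`, `view`, and `fetchChanInfos` are placeholders and do not reflect the real kvdb API.

```go
package main

import (
	"errors"
	"fmt"
)

// tx stands in for a read transaction handle (kvdb.RTx in lnd).
type tx struct {
	data map[uint64]string
}

// store stands in for the database backend.
type store struct {
	data map[uint64]string
}

// view runs fn inside a fresh read transaction, mimicking kvdb.View.
func (s *store) view(fn func(*tx) error) error {
	return fn(&tx{data: s.data})
}

// fetchChanInfos illustrates the optional-transaction pattern: reuse the
// caller's transaction if one is supplied, otherwise open a new one that
// lives only for the duration of this call.
func (s *store) fetchChanInfos(t *tx, ids []uint64) ([]string, error) {
	fetch := func(t *tx) ([]string, error) {
		var out []string
		for _, id := range ids {
			info, ok := t.data[id]
			if !ok {
				return nil, errors.New("channel not found")
			}
			out = append(out, info)
		}
		return out, nil
	}

	// No transaction supplied: create one just for this call.
	if t == nil {
		var out []string
		err := s.view(func(t *tx) error {
			var err error
			out, err = fetch(t)
			return err
		})
		return out, err
	}

	// Reuse the transaction the caller is already inside of.
	return fetch(t)
}

func main() {
	s := &store{data: map[uint64]string{7: "chan-7"}}

	// Most callers (the router, the gossip syncer) pass nil.
	infos, err := s.fetchChanInfos(nil, []uint64{7})
	fmt.Println(infos, err) // [chan-7] <nil>
}
```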