From 76748bdbf720d49d827fc6278a62e57042980fe3 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 19 Jan 2024 13:41:14 +0200 Subject: [PATCH] channeldb: add ChannelGraph stress test This commit adds a test that calls many of the ChannelGraph methods concurrently and in a random order. This test demonstrates that a deadlock currently exists in the ChannelGraph since the test does not complete. This is fixed in the next commit. --- channeldb/graph_test.go | 290 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 290 insertions(+) diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index cbfabb3cc..a37b84a01 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -2090,6 +2090,296 @@ func TestFilterKnownChanIDs(t *testing.T) { } } +// TestStressTestChannelGraphAPI is a stress test that concurrently calls some +// of the ChannelGraph methods in various orders in order to ensure that no +// deadlock can occur. This test currently focuses on stress testing all the +// methods that acquire the cache mutex along with the DB mutex. +func TestStressTestChannelGraphAPI(t *testing.T) { + t.Parallel() + + graph, err := MakeTestGraph(t) + require.NoError(t, err) + + node1, err := createTestVertex(graph.db) + require.NoError(t, err, "unable to create test node") + require.NoError(t, graph.AddLightningNode(node1)) + + node2, err := createTestVertex(graph.db) + require.NoError(t, err, "unable to create test node") + require.NoError(t, graph.AddLightningNode(node2)) + + err = graph.SetSourceNode(node1) + require.NoError(t, err) + + type chanInfo struct { + info models.ChannelEdgeInfo + id lnwire.ShortChannelID + } + + var ( + chans []*chanInfo + mu sync.RWMutex + ) + + // newBlockHeight returns a random block height between 0 and 100. + newBlockHeight := func() uint32 { + return uint32(rand.Int31n(100)) + } + + // addNewChan is a will create and return a new random channel and will + // add it to the set of channels. + addNewChan := func() *chanInfo { + mu.Lock() + defer mu.Unlock() + + channel, chanID := createEdge( + newBlockHeight(), rand.Uint32(), uint16(rand.Int()), + rand.Uint32(), node1, node2, + ) + + newChan := &chanInfo{ + info: channel, + id: chanID, + } + chans = append(chans, newChan) + + return newChan + } + + // getRandChan picks a random channel from the set and returns it. + getRandChan := func() *chanInfo { + mu.RLock() + defer mu.RUnlock() + + if len(chans) == 0 { + return nil + } + + return chans[rand.Intn(len(chans))] + } + + // getRandChanSet returns a random set of channels. + getRandChanSet := func() []*chanInfo { + mu.RLock() + defer mu.RUnlock() + + if len(chans) == 0 { + return nil + } + + start := rand.Intn(len(chans)) + end := rand.Intn(len(chans)) + + if end < start { + start, end = end, start + } + + var infoCopy []*chanInfo + for i := start; i < end; i++ { + infoCopy = append(infoCopy, &chanInfo{ + info: chans[i].info, + id: chans[i].id, + }) + } + + return infoCopy + } + + // delChan deletes the channel with the given ID from the set if it + // exists. + delChan := func(id lnwire.ShortChannelID) { + mu.Lock() + defer mu.Unlock() + + index := -1 + for i, c := range chans { + if c.id == id { + index = i + break + } + } + + if index == -1 { + return + } + + chans = append(chans[:index], chans[index+1:]...) + } + + var blockHash chainhash.Hash + copy(blockHash[:], bytes.Repeat([]byte{2}, 32)) + + var methodsMu sync.Mutex + methods := []struct { + name string + fn func() error + }{ + { + name: "MarkEdgeZombie", + fn: func() error { + channel := getRandChan() + if channel == nil { + return nil + } + + return graph.MarkEdgeZombie( + channel.id.ToUint64(), + node1.PubKeyBytes, + node2.PubKeyBytes, + ) + }, + }, + { + name: "FilterKnownChanIDs", + fn: func() error { + chanSet := getRandChanSet() + var chanIDs []ChannelUpdateInfo + + for _, c := range chanSet { + chanIDs = append( + chanIDs, + ChannelUpdateInfo{ + ShortChannelID: c.id, + }, + ) + } + + _, err := graph.FilterKnownChanIDs( + chanIDs, + func(t time.Time, t2 time.Time) bool { + return rand.Intn(2) == 0 + }, + ) + + return err + }, + }, + { + name: "HasChannelEdge", + fn: func() error { + channel := getRandChan() + if channel == nil { + return nil + } + + _, _, _, _, err := graph.HasChannelEdge( + channel.id.ToUint64(), + ) + + return err + }, + }, + { + name: "PruneGraph", + fn: func() error { + chanSet := getRandChanSet() + var spentOutpoints []*wire.OutPoint + + for _, c := range chanSet { + spentOutpoints = append( + spentOutpoints, + &c.info.ChannelPoint, + ) + } + + _, err := graph.PruneGraph( + spentOutpoints, &blockHash, 100, + ) + + return err + }, + }, + { + name: "ChanUpdateInHorizon", + fn: func() error { + _, err := graph.ChanUpdatesInHorizon( + time.Now().Add(-time.Hour), time.Now(), + ) + + return err + }, + }, + { + name: "DeleteChannelEdges", + fn: func() error { + var ( + strictPruning = rand.Intn(2) == 0 + markZombie = rand.Intn(2) == 0 + channels = getRandChanSet() + chanIDs []uint64 + ) + + for _, c := range channels { + chanIDs = append( + chanIDs, c.id.ToUint64(), + ) + delChan(c.id) + } + + err := graph.DeleteChannelEdges( + strictPruning, markZombie, chanIDs..., + ) + if err != nil && + !errors.Is(err, ErrEdgeNotFound) { + + return err + } + + return nil + }, + }, + { + name: "DisconnectBlockAtHeight", + fn: func() error { + _, err := graph.DisconnectBlockAtHeight( + newBlockHeight(), + ) + + return err + }, + }, + { + name: "AddChannelEdge", + fn: func() error { + channel := addNewChan() + + return graph.AddChannelEdge(&channel.info) + }, + }, + } + + const ( + // concurrencyLevel is the number of concurrent goroutines that + // will be run simultaneously. + concurrencyLevel = 10 + + // executionCount is the number of methods that will be called + // per goroutine. + executionCount = 100 + ) + + for i := 0; i < concurrencyLevel; i++ { + i := i + + t.Run(fmt.Sprintf("%d", i), func(t *testing.T) { + t.Parallel() + + for j := 0; j < executionCount; j++ { + // Randomly select a method to execute. + methodIndex := rand.Intn(len(methods)) + + methodsMu.Lock() + fn := methods[methodIndex].fn + name := methods[methodIndex].name + methodsMu.Unlock() + + err := fn() + require.NoErrorf(t, err, fmt.Sprintf(name)) + } + }) + } +} + // TestFilterChannelRange tests that we're able to properly retrieve the full // set of short channel ID's for a given block range. func TestFilterChannelRange(t *testing.T) {