channeldb: add ChannelGraph stress test

This commit adds a test that calls many of the ChannelGraph methods
concurrently and in a random order. This test demonstrates that a
deadlock currently exists in the ChannelGraph since the test does not
complete. This is fixed in the next commit.
This commit is contained in:
Elle Mouton 2024-01-19 13:41:14 +02:00
parent ec5b824879
commit 76748bdbf7
No known key found for this signature in database
GPG Key ID: D7D916376026F177

View File

@ -2090,6 +2090,296 @@ func TestFilterKnownChanIDs(t *testing.T) {
}
}
// TestStressTestChannelGraphAPI is a stress test that concurrently calls some
// of the ChannelGraph methods in various orders in order to ensure that no
// deadlock can occur. This test currently focuses on stress testing all the
// methods that acquire the cache mutex along with the DB mutex.
func TestStressTestChannelGraphAPI(t *testing.T) {
t.Parallel()
graph, err := MakeTestGraph(t)
require.NoError(t, err)
node1, err := createTestVertex(graph.db)
require.NoError(t, err, "unable to create test node")
require.NoError(t, graph.AddLightningNode(node1))
node2, err := createTestVertex(graph.db)
require.NoError(t, err, "unable to create test node")
require.NoError(t, graph.AddLightningNode(node2))
err = graph.SetSourceNode(node1)
require.NoError(t, err)
type chanInfo struct {
info models.ChannelEdgeInfo
id lnwire.ShortChannelID
}
var (
chans []*chanInfo
mu sync.RWMutex
)
// newBlockHeight returns a random block height between 0 and 100.
newBlockHeight := func() uint32 {
return uint32(rand.Int31n(100))
}
// addNewChan is a will create and return a new random channel and will
// add it to the set of channels.
addNewChan := func() *chanInfo {
mu.Lock()
defer mu.Unlock()
channel, chanID := createEdge(
newBlockHeight(), rand.Uint32(), uint16(rand.Int()),
rand.Uint32(), node1, node2,
)
newChan := &chanInfo{
info: channel,
id: chanID,
}
chans = append(chans, newChan)
return newChan
}
// getRandChan picks a random channel from the set and returns it.
getRandChan := func() *chanInfo {
mu.RLock()
defer mu.RUnlock()
if len(chans) == 0 {
return nil
}
return chans[rand.Intn(len(chans))]
}
// getRandChanSet returns a random set of channels.
getRandChanSet := func() []*chanInfo {
mu.RLock()
defer mu.RUnlock()
if len(chans) == 0 {
return nil
}
start := rand.Intn(len(chans))
end := rand.Intn(len(chans))
if end < start {
start, end = end, start
}
var infoCopy []*chanInfo
for i := start; i < end; i++ {
infoCopy = append(infoCopy, &chanInfo{
info: chans[i].info,
id: chans[i].id,
})
}
return infoCopy
}
// delChan deletes the channel with the given ID from the set if it
// exists.
delChan := func(id lnwire.ShortChannelID) {
mu.Lock()
defer mu.Unlock()
index := -1
for i, c := range chans {
if c.id == id {
index = i
break
}
}
if index == -1 {
return
}
chans = append(chans[:index], chans[index+1:]...)
}
var blockHash chainhash.Hash
copy(blockHash[:], bytes.Repeat([]byte{2}, 32))
var methodsMu sync.Mutex
methods := []struct {
name string
fn func() error
}{
{
name: "MarkEdgeZombie",
fn: func() error {
channel := getRandChan()
if channel == nil {
return nil
}
return graph.MarkEdgeZombie(
channel.id.ToUint64(),
node1.PubKeyBytes,
node2.PubKeyBytes,
)
},
},
{
name: "FilterKnownChanIDs",
fn: func() error {
chanSet := getRandChanSet()
var chanIDs []ChannelUpdateInfo
for _, c := range chanSet {
chanIDs = append(
chanIDs,
ChannelUpdateInfo{
ShortChannelID: c.id,
},
)
}
_, err := graph.FilterKnownChanIDs(
chanIDs,
func(t time.Time, t2 time.Time) bool {
return rand.Intn(2) == 0
},
)
return err
},
},
{
name: "HasChannelEdge",
fn: func() error {
channel := getRandChan()
if channel == nil {
return nil
}
_, _, _, _, err := graph.HasChannelEdge(
channel.id.ToUint64(),
)
return err
},
},
{
name: "PruneGraph",
fn: func() error {
chanSet := getRandChanSet()
var spentOutpoints []*wire.OutPoint
for _, c := range chanSet {
spentOutpoints = append(
spentOutpoints,
&c.info.ChannelPoint,
)
}
_, err := graph.PruneGraph(
spentOutpoints, &blockHash, 100,
)
return err
},
},
{
name: "ChanUpdateInHorizon",
fn: func() error {
_, err := graph.ChanUpdatesInHorizon(
time.Now().Add(-time.Hour), time.Now(),
)
return err
},
},
{
name: "DeleteChannelEdges",
fn: func() error {
var (
strictPruning = rand.Intn(2) == 0
markZombie = rand.Intn(2) == 0
channels = getRandChanSet()
chanIDs []uint64
)
for _, c := range channels {
chanIDs = append(
chanIDs, c.id.ToUint64(),
)
delChan(c.id)
}
err := graph.DeleteChannelEdges(
strictPruning, markZombie, chanIDs...,
)
if err != nil &&
!errors.Is(err, ErrEdgeNotFound) {
return err
}
return nil
},
},
{
name: "DisconnectBlockAtHeight",
fn: func() error {
_, err := graph.DisconnectBlockAtHeight(
newBlockHeight(),
)
return err
},
},
{
name: "AddChannelEdge",
fn: func() error {
channel := addNewChan()
return graph.AddChannelEdge(&channel.info)
},
},
}
const (
// concurrencyLevel is the number of concurrent goroutines that
// will be run simultaneously.
concurrencyLevel = 10
// executionCount is the number of methods that will be called
// per goroutine.
executionCount = 100
)
for i := 0; i < concurrencyLevel; i++ {
i := i
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
t.Parallel()
for j := 0; j < executionCount; j++ {
// Randomly select a method to execute.
methodIndex := rand.Intn(len(methods))
methodsMu.Lock()
fn := methods[methodIndex].fn
name := methods[methodIndex].name
methodsMu.Unlock()
err := fn()
require.NoErrorf(t, err, fmt.Sprintf(name))
}
})
}
}
// TestFilterChannelRange tests that we're able to properly retrieve the full
// set of short channel ID's for a given block range.
func TestFilterChannelRange(t *testing.T) {