diff --git a/autopilot/agent_test.go b/autopilot/agent_test.go index ccae20a14..758401868 100644 --- a/autopilot/agent_test.go +++ b/autopilot/agent_test.go @@ -155,7 +155,7 @@ type testContext struct { sync.Mutex } -func setup(t *testing.T, initialChans []LocalChannel) (*testContext, func()) { +func setup(t *testing.T, initialChans []LocalChannel) *testContext { t.Helper() // First, we'll create all the dependencies that we'll need in order to @@ -178,7 +178,7 @@ func setup(t *testing.T, initialChans []LocalChannel) (*testContext, func()) { chanController := &mockChanController{ openChanSignals: make(chan openChanIntent, 10), } - memGraph, _, _ := newMemChanGraph() + memGraph, _ := newMemChanGraph(t) // We'll keep track of the funds available to the agent, to make sure // it correctly uses this value when querying the ChannelBudget. @@ -224,14 +224,14 @@ func setup(t *testing.T, initialChans []LocalChannel) (*testContext, func()) { t.Fatalf("unable to start agent: %v", err) } - cleanup := func() { + t.Cleanup(func() { // We must close quit before agent.Stop(), to make sure // ChannelBudget won't block preventing the agent from exiting. close(quit) agent.Stop() - } + }) - return ctx, cleanup + return ctx } // respondMoreChans consumes the moreChanArgs element and responds to the agent @@ -279,8 +279,7 @@ func respondNodeScores(t *testing.T, testCtx *testContext, func TestAgentChannelOpenSignal(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // We'll send an initial "no" response to advance the agent past its // initial check. @@ -324,8 +323,7 @@ func TestAgentChannelOpenSignal(t *testing.T) { func TestAgentHeuristicUpdateSignal(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) pub, err := testCtx.graph.addRandNode() require.NoError(t, err, "unable to generate key") @@ -386,8 +384,7 @@ var _ ChannelController = (*mockFailingChanController)(nil) func TestAgentChannelFailureSignal(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) testCtx.chanController = &mockFailingChanController{} @@ -436,8 +433,7 @@ func TestAgentChannelCloseSignal(t *testing.T) { }, } - testCtx, cleanup := setup(t, initialChans) - defer cleanup() + testCtx := setup(t, initialChans) // We'll send an initial "no" response to advance the agent past its // initial check. @@ -478,8 +474,7 @@ func TestAgentChannelCloseSignal(t *testing.T) { func TestAgentBalanceUpdate(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // We'll send an initial "no" response to advance the agent past its // initial check. @@ -525,8 +520,7 @@ func TestAgentBalanceUpdate(t *testing.T) { func TestAgentImmediateAttach(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) const numChans = 5 @@ -591,8 +585,7 @@ func TestAgentImmediateAttach(t *testing.T) { func TestAgentPrivateChannels(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // The chanController should be initialized such that all of its open // channel requests are for private channels. @@ -652,8 +645,7 @@ func TestAgentPrivateChannels(t *testing.T) { func TestAgentPendingChannelState(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // We'll only return a single directive for a pre-chosen node. nodeKey, err := testCtx.graph.addRandNode() @@ -764,8 +756,7 @@ func TestAgentPendingChannelState(t *testing.T) { func TestAgentPendingOpenChannel(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // We'll send an initial "no" response to advance the agent past its // initial check. @@ -796,8 +787,7 @@ func TestAgentPendingOpenChannel(t *testing.T) { func TestAgentOnNodeUpdates(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) // We'll send an initial "yes" response to advance the agent past its // initial check. This will cause it to try to get directives from an @@ -844,8 +834,7 @@ func TestAgentOnNodeUpdates(t *testing.T) { func TestAgentSkipPendingConns(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) connect := make(chan chan error) testCtx.agent.cfg.ConnectToPeer = func(*btcec.PublicKey, []net.Addr) (bool, error) { @@ -1025,8 +1014,7 @@ func TestAgentSkipPendingConns(t *testing.T) { func TestAgentQuitWhenPendingConns(t *testing.T) { t.Parallel() - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) connect := make(chan chan error) @@ -1216,8 +1204,7 @@ func TestAgentChannelSizeAllocation(t *testing.T) { // Total number of nodes in our mock graph. const numNodes = 20 - testCtx, cleanup := setup(t, nil) - defer cleanup() + testCtx := setup(t, nil) nodeScores := make(map[NodeID]*NodeScore) for i := 0; i < numNodes; i++ { diff --git a/autopilot/betweenness_centrality_test.go b/autopilot/betweenness_centrality_test.go index 1efd005f2..408becb63 100644 --- a/autopilot/betweenness_centrality_test.go +++ b/autopilot/betweenness_centrality_test.go @@ -37,22 +37,18 @@ func TestBetweennessCentralityEmptyGraph(t *testing.T) { ) for _, chanGraph := range chanGraphs { - graph, cleanup, err := chanGraph.genFunc() success := t.Run(chanGraph.name, func(t1 *testing.T) { - require.NoError(t, err, "unable to create graph") + graph, err := chanGraph.genFunc(t1) + require.NoError(t1, err, "unable to create graph") - if cleanup != nil { - defer cleanup() - } - - err := centralityMetric.Refresh(graph) - require.NoError(t, err) + err = centralityMetric.Refresh(graph) + require.NoError(t1, err) centrality := centralityMetric.GetMetric(false) - require.Equal(t, 0, len(centrality)) + require.Equal(t1, 0, len(centrality)) centrality = centralityMetric.GetMetric(true) - require.Equal(t, 0, len(centrality)) + require.Equal(t1, 0, len(centrality)) }) if !success { break @@ -81,13 +77,9 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { for _, numWorkers := range workers { for _, chanGraph := range chanGraphs { numWorkers := numWorkers - graph, cleanup, err := chanGraph.genFunc() + graph, err := chanGraph.genFunc(t) require.NoError(t, err, "unable to create graph") - if cleanup != nil { - defer cleanup() - } - testName := fmt.Sprintf( "%v %d workers", chanGraph.name, numWorkers, ) @@ -97,7 +89,7 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { numWorkers, ) require.NoError( - t, err, + t1, err, "construction must succeed with "+ "positive number of workers", ) @@ -107,7 +99,7 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { ) err = metric.Refresh(graph) - require.NoError(t, err) + require.NoError(t1, err) for _, expected := range tests { expected := expected @@ -115,7 +107,7 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { expected.normalize, ) - require.Equal(t, + require.Equal(t1, centralityTestGraph.nodes, len(centrality), ) @@ -125,8 +117,8 @@ func TestBetweennessCentralityWithNonEmptyGraph(t *testing.T) { graphNodes[i], ) result, ok := centrality[nodeID] - require.True(t, ok) - require.Equal(t, c, result) + require.True(t1, ok) + require.Equal(t1, c, result) } } }) diff --git a/autopilot/prefattach_test.go b/autopilot/prefattach_test.go index 64a27802b..6803363f5 100644 --- a/autopilot/prefattach_test.go +++ b/autopilot/prefattach_test.go @@ -2,9 +2,7 @@ package autopilot import ( "bytes" - "io/ioutil" prand "math/rand" - "os" "testing" "time" @@ -14,7 +12,7 @@ import ( "github.com/stretchr/testify/require" ) -type genGraphFunc func() (testGraph, func(), error) +type genGraphFunc func(t *testing.T) (testGraph, error) type testGraph interface { ChannelGraph @@ -25,34 +23,25 @@ type testGraph interface { addRandNode() (*btcec.PublicKey, error) } -func newDiskChanGraph() (testGraph, func(), error) { - // First, create a temporary directory to be used for the duration of - // this test. - tempDirName, err := ioutil.TempDir("", "channeldb") - if err != nil { - return nil, nil, err - } - +func newDiskChanGraph(t *testing.T) (testGraph, error) { // Next, create channeldb for the first time. - cdb, err := channeldb.Open(tempDirName) + cdb, err := channeldb.Open(t.TempDir()) if err != nil { - return nil, nil, err - } - - cleanUp := func() { - cdb.Close() - os.RemoveAll(tempDirName) + return nil, err } + t.Cleanup(func() { + require.NoError(t, cdb.Close()) + }) return &databaseChannelGraph{ db: cdb.ChannelGraph(), - }, cleanUp, nil + }, nil } var _ testGraph = (*databaseChannelGraph)(nil) -func newMemChanGraph() (testGraph, func(), error) { - return newMemChannelGraph(), nil, nil +func newMemChanGraph(_ *testing.T) (testGraph, error) { + return newMemChannelGraph(), nil } var _ testGraph = (*memChannelGraph)(nil) @@ -86,13 +75,10 @@ func TestPrefAttachmentSelectEmptyGraph(t *testing.T) { for _, graph := range chanGraphs { success := t.Run(graph.name, func(t1 *testing.T) { - graph, cleanup, err := graph.genFunc() + graph, err := graph.genFunc(t1) if err != nil { t1.Fatalf("unable to create graph: %v", err) } - if cleanup != nil { - defer cleanup() - } // With the necessary state initialized, we'll now // attempt to get the score for this one node. @@ -131,13 +117,10 @@ func TestPrefAttachmentSelectTwoVertexes(t *testing.T) { for _, graph := range chanGraphs { success := t.Run(graph.name, func(t1 *testing.T) { - graph, cleanup, err := graph.genFunc() + graph, err := graph.genFunc(t1) if err != nil { t1.Fatalf("unable to create graph: %v", err) } - if cleanup != nil { - defer cleanup() - } prefAttach := NewPrefAttachment() @@ -231,13 +214,10 @@ func TestPrefAttachmentSelectGreedyAllocation(t *testing.T) { for _, graph := range chanGraphs { success := t.Run(graph.name, func(t1 *testing.T) { - graph, cleanup, err := graph.genFunc() + graph, err := graph.genFunc(t1) if err != nil { t1.Fatalf("unable to create graph: %v", err) } - if cleanup != nil { - defer cleanup() - } prefAttach := NewPrefAttachment() @@ -363,13 +343,10 @@ func TestPrefAttachmentSelectSkipNodes(t *testing.T) { for _, graph := range chanGraphs { success := t.Run(graph.name, func(t1 *testing.T) { - graph, cleanup, err := graph.genFunc() + graph, err := graph.genFunc(t1) if err != nil { t1.Fatalf("unable to create graph: %v", err) } - if cleanup != nil { - defer cleanup() - } prefAttach := NewPrefAttachment() diff --git a/autopilot/top_centrality_test.go b/autopilot/top_centrality_test.go index 2688f4090..426f0adc3 100644 --- a/autopilot/top_centrality_test.go +++ b/autopilot/top_centrality_test.go @@ -85,22 +85,19 @@ func TestTopCentrality(t *testing.T) { for _, chanGraph := range chanGraphs { chanGraph := chanGraph - success := t.Run(chanGraph.name, func(t *testing.T) { - t.Parallel() + success := t.Run(chanGraph.name, func(t1 *testing.T) { + t1.Parallel() - graph, cleanup, err := chanGraph.genFunc() - require.NoError(t, err, "unable to create graph") - if cleanup != nil { - defer cleanup() - } + graph, err := chanGraph.genFunc(t1) + require.NoError(t1, err, "unable to create graph") // Build the test graph. graphNodes := buildTestGraph( - t, graph, centralityTestGraph, + t1, graph, centralityTestGraph, ) for _, chans := range channelsWith { - testTopCentrality(t, graph, graphNodes, chans) + testTopCentrality(t1, graph, graphNodes, chans) } })