diff --git a/lntest/itest/lnd_channel_force_close.go b/lntest/itest/lnd_channel_force_close.go index 9077a7c48..4af4f7049 100644 --- a/lntest/itest/lnd_channel_force_close.go +++ b/lntest/itest/lnd_channel_force_close.go @@ -65,10 +65,7 @@ func testCommitmentTransactionDeadline(net *lntest.NetworkHarness, // transaction to CPFP our commitment transaction. feeRateLarge := maxPerKw * 2 - ctxt, cancel := context.WithTimeout( - context.Background(), defaultTimeout, - ) - defer cancel() + ctxb := context.Background() // Before we start, set up the default fee rate and we will test the // actual fee rate against it to decide whether we are using the @@ -76,25 +73,28 @@ func testCommitmentTransactionDeadline(net *lntest.NetworkHarness, net.SetFeeEstimate(feeRateDefault) // setupNode creates a new node and sends 1 btc to the node. - setupNode := func(name string) *lntest.HarnessNode { + setupNode := func(ctx context.Context, name string) *lntest.HarnessNode { // Create the node. args := []string{"--hodl.exit-settle"} args = append(args, commitTypeAnchors.Args()...) node := net.NewNode(t.t, name, args) // Send some coins to the node. - net.SendCoins(ctxt, t.t, btcutil.SatoshiPerBitcoin, node) + net.SendCoins(ctx, t.t, btcutil.SatoshiPerBitcoin, node) return node } // calculateSweepFeeRate runs multiple steps to calculate the fee rate // used in sweeping the transactions. calculateSweepFeeRate := func(expectedSweepTxNum int) int64 { + ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout) + defer cancel() + // Create two nodes, Alice and Bob. - alice := setupNode("Alice") + alice := setupNode(ctxt, "Alice") defer shutdownAndAssert(net, t, alice) - bob := setupNode("Bob") + bob := setupNode(ctxt, "Bob") defer shutdownAndAssert(net, t, bob) // Connect Alice to Bob. @@ -102,8 +102,7 @@ func testCommitmentTransactionDeadline(net *lntest.NetworkHarness, // Open a channel between Alice and Bob. chanPoint := openChannelAndAssert( - ctxt, t, net, alice, bob, - lntest.OpenChannelParams{ + ctxt, t, net, alice, bob, lntest.OpenChannelParams{ Amt: 10e6, PushAmt: 5e6, }, @@ -113,8 +112,7 @@ func testCommitmentTransactionDeadline(net *lntest.NetworkHarness, // be used as our deadline later on when Alice force closes the // channel. _, err := alice.RouterClient.SendPaymentV2( - ctxt, - &routerrpc.SendPaymentRequest{ + ctxt, &routerrpc.SendPaymentRequest{ Dest: bob.PubKey[:], Amt: 10e4, PaymentHash: makeFakePayHash(t), @@ -134,6 +132,8 @@ func testCommitmentTransactionDeadline(net *lntest.NetworkHarness, require.NoError(t.t, err, "htlc mismatch") // Alice force closes the channel. + ctxt, cancel = context.WithTimeout(ctxb, defaultTimeout) + defer cancel() _, _, err = net.CloseChannel(ctxt, alice, chanPoint, true) require.NoError(t.t, err, "unable to force close channel") diff --git a/lntest/itest/lnd_hold_invoice_force_test.go b/lntest/itest/lnd_hold_invoice_force_test.go index 7a71ac744..62c2bccd8 100644 --- a/lntest/itest/lnd_hold_invoice_force_test.go +++ b/lntest/itest/lnd_hold_invoice_force_test.go @@ -25,8 +25,11 @@ func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { Amt: 300000, } - ctxt, _ := context.WithTimeout(ctxb, channelOpenTimeout) - chanPoint := openChannelAndAssert(ctxt, t, net, net.Alice, net.Bob, chanReq) + ctxt, cancel := context.WithTimeout(ctxb, channelOpenTimeout) + defer cancel() + chanPoint := openChannelAndAssert( + ctxt, t, net, net.Alice, net.Bob, chanReq, + ) // Create a non-dust hold invoice for bob. var ( @@ -39,7 +42,8 @@ func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { Hash: payHash[:], } - ctxt, _ = context.WithTimeout(ctxb, defaultTimeout) + ctxt, cancel = context.WithTimeout(ctxb, defaultTimeout) + defer cancel() bobInvoice, err := net.Bob.AddHoldInvoice(ctxt, invoiceReq) require.NoError(t.t, err) @@ -72,23 +76,30 @@ func testHoldInvoiceForceClose(net *lntest.NetworkHarness, t *harnessTest) { require.Len(t.t, chans.Channels[0].PendingHtlcs, 1) activeHtlc := chans.Channels[0].PendingHtlcs[0] + require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) + require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) + info, err := net.Alice.GetInfo(ctxb, &lnrpc.GetInfoRequest{}) require.NoError(t.t, err) // Now we will mine blocks until the htlc expires, and wait for each // node to sync to our latest height. Sanity check that we won't // underflow. - require.Greater(t.t, activeHtlc.ExpirationHeight, info.BlockHeight, - "expected expiry after current height") + require.Greater( + t.t, activeHtlc.ExpirationHeight, info.BlockHeight, + "expected expiry after current height", + ) blocksTillExpiry := activeHtlc.ExpirationHeight - info.BlockHeight // Alice will go to chain with some delta, sanity check that we won't // underflow and subtract this from our mined blocks. - require.Greater(t.t, blocksTillExpiry, - uint32(lncfg.DefaultOutgoingBroadcastDelta)) + require.Greater( + t.t, blocksTillExpiry, + uint32(lncfg.DefaultOutgoingBroadcastDelta), + ) blocksTillForce := blocksTillExpiry - lncfg.DefaultOutgoingBroadcastDelta - mineBlocks(t, net, blocksTillForce, 0) + mineBlocksSlow(t, net, blocksTillForce, 0) require.NoError(t.t, net.Alice.WaitForBlockchainSync(ctxb)) require.NoError(t.t, net.Bob.WaitForBlockchainSync(ctxb)) diff --git a/lntest/itest/test_harness.go b/lntest/itest/test_harness.go index 90a3c4ba6..dc4585f28 100644 --- a/lntest/itest/test_harness.go +++ b/lntest/itest/test_harness.go @@ -20,6 +20,7 @@ import ( "github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lntest" "github.com/lightningnetwork/lnd/lntest/wait" + "github.com/stretchr/testify/require" ) var ( @@ -29,6 +30,8 @@ var ( lndExecutable = flag.String( "lndexec", itestLndBinary, "full path to lnd binary", ) + + slowMineDelay = 50 * time.Millisecond ) const ( @@ -238,6 +241,54 @@ func mineBlocks(t *harnessTest, net *lntest.NetworkHarness, return blocks } +// mineBlocksSlow mines 'num' of blocks and checks that blocks are present in +// the mining node's blockchain. numTxs should be set to the number of +// transactions (excluding the coinbase) we expect to be included in the first +// mined block. Between each mined block an artificial delay is introduced to +// give all network participants time to catch up. +func mineBlocksSlow(t *harnessTest, net *lntest.NetworkHarness, + num uint32, numTxs int) []*wire.MsgBlock { + + t.t.Helper() + + // If we expect transactions to be included in the blocks we'll mine, + // we wait here until they are seen in the miner's mempool. + var txids []*chainhash.Hash + var err error + if numTxs > 0 { + txids, err = waitForNTxsInMempool( + net.Miner.Client, numTxs, minerMempoolTimeout, + ) + require.NoError(t.t, err, "unable to find txns in mempool") + } + + blocks := make([]*wire.MsgBlock, num) + blockHashes := make([]*chainhash.Hash, 0, num) + + for i := uint32(0); i < num; i++ { + generatedHashes, err := net.Miner.Client.Generate(1) + require.NoError(t.t, err, "generate blocks") + blockHashes = append(blockHashes, generatedHashes...) + + time.Sleep(slowMineDelay) + } + + for i, blockHash := range blockHashes { + block, err := net.Miner.Client.GetBlock(blockHash) + require.NoError(t.t, err, "get blocks") + + blocks[i] = block + } + + // Finally, assert that all the transactions were included in the first + // block. + for _, txid := range txids { + assertTxInBlock(t, blocks[0], txid) + } + + return blocks +} + func assertTxInBlock(t *harnessTest, block *wire.MsgBlock, txid *chainhash.Hash) { for _, tx := range block.Transactions { sha := tx.TxHash() diff --git a/lntest/node.go b/lntest/node.go index 27ccb15c4..06359b0ed 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -12,6 +12,7 @@ import ( "os" "os/exec" "path/filepath" + "strings" "sync" "sync/atomic" "time" @@ -1125,7 +1126,21 @@ func (hn *HarnessNode) stop() error { // closed before a response is returned. req := lnrpc.StopRequest{} ctx := context.Background() - _, err := hn.LightningClient.StopDaemon(ctx, &req) + + err := wait.NoError(func() error { + _, err := hn.LightningClient.StopDaemon(ctx, &req) + switch { + case err == nil: + return nil + + // Try again if a recovery/rescan is in progress. + case strings.Contains(err.Error(), "recovery in progress"): + return err + + default: + return nil + } + }, DefaultTimeout) if err != nil { return err }