diff --git a/routing/notifications_test.go b/routing/notifications_test.go index 885a58167..c63aa66e4 100644 --- a/routing/notifications_test.go +++ b/routing/notifications_test.go @@ -47,6 +47,8 @@ var ( priv2, _ = btcec.NewPrivateKey(btcec.S256()) bitcoinKey2 = priv2.PubKey() + + timeout = time.Second * 5 ) func createTestNode() (*channeldb.LightningNode, error) { @@ -122,8 +124,9 @@ func createChannelEdge(ctx *testCtx, bitcoinKey1, bitcoinKey2 []byte, } type mockChain struct { - blocks map[chainhash.Hash]*wire.MsgBlock - blockIndex map[uint32]chainhash.Hash + blocks map[chainhash.Hash]*wire.MsgBlock + blockIndex map[uint32]chainhash.Hash + blockHeightIndex map[chainhash.Hash]uint32 utxos map[wire.OutPoint]wire.TxOut @@ -138,10 +141,11 @@ var _ lnwallet.BlockChainIO = (*mockChain)(nil) func newMockChain(currentHeight uint32) *mockChain { return &mockChain{ - bestHeight: int32(currentHeight), - blocks: make(map[chainhash.Hash]*wire.MsgBlock), - utxos: make(map[wire.OutPoint]wire.TxOut), - blockIndex: make(map[uint32]chainhash.Hash), + bestHeight: int32(currentHeight), + blocks: make(map[chainhash.Hash]*wire.MsgBlock), + utxos: make(map[wire.OutPoint]wire.TxOut), + blockIndex: make(map[uint32]chainhash.Hash), + blockHeightIndex: make(map[chainhash.Hash]uint32), } } @@ -209,8 +213,10 @@ func (m *mockChain) addBlock(block *wire.MsgBlock, height uint32, nonce uint32) hash := block.Header.BlockHash() m.blocks[hash] = block m.blockIndex[height] = hash + m.blockHeightIndex[hash] = height m.Unlock() } + func (m *mockChain) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) { m.RLock() defer m.RUnlock() @@ -226,8 +232,10 @@ func (m *mockChain) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) type mockChainView struct { sync.RWMutex - newBlocks chan *chainview.FilteredBlock - staleBlocks chan *chainview.FilteredBlock + newBlocks chan *chainview.FilteredBlock + staleBlocks chan *chainview.FilteredBlock + notifyBlockAck chan struct{} + notifyStaleBlockAck chan struct{} chain lnwallet.BlockChainIO @@ -269,7 +277,7 @@ func (m *mockChainView) UpdateFilter(ops []channeldb.EdgePoint, updateHeight uin } func (m *mockChainView) notifyBlock(hash chainhash.Hash, height uint32, - txns []*wire.MsgTx) { + txns []*wire.MsgTx, t *testing.T) { m.RLock() defer m.RUnlock() @@ -283,10 +291,23 @@ func (m *mockChainView) notifyBlock(hash chainhash.Hash, height uint32, case <-m.quit: return } + + // Do not ack the block if our notify channel is nil. + if m.notifyBlockAck == nil { + return + } + + select { + case m.notifyBlockAck <- struct{}{}: + case <-time.After(timeout): + t.Fatal("expected block to be delivered") + case <-m.quit: + return + } } func (m *mockChainView) notifyStaleBlock(hash chainhash.Hash, height uint32, - txns []*wire.MsgTx) { + txns []*wire.MsgTx, t *testing.T) { m.RLock() defer m.RUnlock() @@ -300,6 +321,19 @@ func (m *mockChainView) notifyStaleBlock(hash chainhash.Hash, height uint32, case <-m.quit: return } + + // Do not ack the block if our notify channel is nil. + if m.notifyStaleBlockAck == nil { + return + } + + select { + case m.notifyStaleBlockAck <- struct{}{}: + case <-time.After(timeout): + t.Fatal("expected stale block to be delivered") + case <-m.quit: + return + } } func (m *mockChainView) FilteredBlocks() <-chan *chainview.FilteredBlock { @@ -317,7 +351,14 @@ func (m *mockChainView) FilterBlock(blockHash *chainhash.Hash) (*chainview.Filte return nil, err } - filteredBlock := &chainview.FilteredBlock{} + chain := m.chain.(*mockChain) + + chain.Lock() + filteredBlock := &chainview.FilteredBlock{ + Hash: *blockHash, + Height: chain.blockHeightIndex[*blockHash], + } + chain.Unlock() for _, tx := range block.Transactions { for _, txIn := range tx.TxIn { prevOp := txIn.PreviousOutPoint @@ -895,7 +936,7 @@ func TestChannelCloseNotification(t *testing.T) { } ctx.chain.addBlock(newBlock, blockHeight, blockHeight) ctx.chainView.notifyBlock(newBlock.Header.BlockHash(), blockHeight, - newBlock.Transactions) + newBlock.Transactions, t) // The notification registered above should be sent, if not we'll time // out and mark the test as failed. diff --git a/routing/router_test.go b/routing/router_test.go index d7c14969a..b4d49d6c1 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -18,9 +18,11 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/htlcswitch" + lnmock "github.com/lightningnetwork/lnd/lntest/mock" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" @@ -44,6 +46,8 @@ type testCtx struct { chain *mockChain chainView *mockChainView + + notifier *lnmock.ChainNotifier } func (c *testCtx) getChannelIDFromAlias(t *testing.T, a, b string) uint64 { @@ -136,11 +140,18 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, MissionControl: mc, } + notifier := &lnmock.ChainNotifier{ + EpochChan: make(chan *chainntnfs.BlockEpoch), + SpendChan: make(chan *chainntnfs.SpendDetail), + ConfChan: make(chan *chainntnfs.TxConfirmation), + } + router, err := New(Config{ Graph: graphInstance.graph, Chain: chain, ChainView: chainView, Payer: &mockPaymentAttemptDispatcherOld{}, + Notifier: notifier, Control: makeMockControlTower(), MissionControl: mc, SessionSource: sessionSource, @@ -171,6 +182,7 @@ func createTestCtxFromGraphInstanceAssumeValid(t *testing.T, channelIDs: graphInstance.channelIDs, chain: chain, chainView: chainView, + notifier: notifier, } cleanUp := func() { @@ -1624,7 +1636,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // Give time to process new blocks @@ -1656,7 +1668,7 @@ func TestWakeUpOnStaleBranch(t *testing.T) { ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // Give time to process new blocks time.Sleep(time.Millisecond * 500) @@ -1833,7 +1845,7 @@ func TestDisconnectedBlocks(t *testing.T) { ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // Give time to process new blocks @@ -1867,7 +1879,7 @@ func TestDisconnectedBlocks(t *testing.T) { ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) ctx.chainView.notifyBlock(block.BlockHash(), height, - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // Give time to process new blocks time.Sleep(time.Millisecond * 500) @@ -1949,12 +1961,18 @@ func TestDisconnectedBlocks(t *testing.T) { // Create a 15 block fork. We first let the chainView notify the router // about stale blocks, before sending the now connected blocks. We do // this because we expect this order from the chainview. + ctx.chainView.notifyStaleBlockAck = make(chan struct{}, 1) for i := len(minorityChain) - 1; i >= 0; i-- { block := minorityChain[i] height := uint32(forkHeight) + uint32(i) + 1 ctx.chainView.notifyStaleBlock(block.BlockHash(), height, - block.Transactions) + block.Transactions, t) + <-ctx.chainView.notifyStaleBlockAck } + + time.Sleep(time.Second * 2) + + ctx.chainView.notifyBlockAck = make(chan struct{}, 1) for i := uint32(1); i <= 15; i++ { block := &wire.MsgBlock{ Transactions: []*wire.MsgTx{}, @@ -1963,10 +1981,10 @@ func TestDisconnectedBlocks(t *testing.T) { ctx.chain.addBlock(block, height, rand.Uint32()) ctx.chain.setBestBlock(int32(height)) ctx.chainView.notifyBlock(block.BlockHash(), height, - block.Transactions) + block.Transactions, t) + <-ctx.chainView.notifyBlockAck } - // Give time to process new blocks time.Sleep(time.Millisecond * 500) // chanID2 should not be in the database anymore, since it is not @@ -2022,7 +2040,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { ctx.chain.addBlock(block102, uint32(nextHeight), rand.Uint32()) ctx.chain.setBestBlock(int32(nextHeight)) ctx.chainView.notifyBlock(block102.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) // We'll now create the edges and nodes within the database required // for the ChannelRouter to properly recognize the channel we added @@ -2075,7 +2093,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32()) ctx.chain.setBestBlock(int32(nextHeight)) ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // At this point, our starting height should be 107. @@ -2117,7 +2135,7 @@ func TestRouterChansClosedOfflinePruneGraph(t *testing.T) { ctx.chain.addBlock(block, uint32(nextHeight), rand.Uint32()) ctx.chain.setBestBlock(int32(nextHeight)) ctx.chainView.notifyBlock(block.BlockHash(), uint32(nextHeight), - []*wire.MsgTx{}) + []*wire.MsgTx{}, t) } // At this point, our starting height should be 112. @@ -4249,3 +4267,72 @@ func TestSendMPPaymentFailedWithShardsInFlight(t *testing.T) { session.AssertExpectations(t) missionControl.AssertExpectations(t) } + +// TestBlockDifferenceFix tests if when the router is behind on blocks, the +// router catches up to the best block head. +func TestBlockDifferenceFix(t *testing.T) { + t.Parallel() + + initialBlockHeight := uint32(0) + // Starting height here is set to 0, which is behind where we want to be. + ctx, cleanup := createTestCtxSingleNode(t, initialBlockHeight) + defer cleanup() + + // Add initial block to our mini blockchain. + block := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + ctx.chain.addBlock(block, initialBlockHeight, rand.Uint32()) + + // Let's generate a new block of height 5, 5 above where our node is at. + newBlock := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + } + newBlockHeight := uint32(5) + + blockDifference := newBlockHeight - initialBlockHeight + + ctx.chainView.notifyBlockAck = make(chan struct{}, 1) + + ctx.chain.addBlock(newBlock, newBlockHeight, rand.Uint32()) + ctx.chain.setBestBlock(int32(newBlockHeight)) + ctx.chainView.notifyBlock(block.BlockHash(), newBlockHeight, + []*wire.MsgTx{}, t) + + <-ctx.chainView.notifyBlockAck + + // At this point, the chain notifier should have noticed that we're + // behind on blocks, and will send the n missing blocks that we + // need to the client's epochs channel. Let's replicate this + // functionality. + for i := 0; i < int(blockDifference); i++ { + currBlockHeight := int32(i + 1) + + nonce := rand.Uint32() + + newBlock := &wire.MsgBlock{ + Transactions: []*wire.MsgTx{}, + Header: wire.BlockHeader{Nonce: nonce}, + } + ctx.chain.addBlock(newBlock, uint32(currBlockHeight), nonce) + currHash := newBlock.Header.BlockHash() + + newEpoch := &chainntnfs.BlockEpoch{ + Height: currBlockHeight, + Hash: &currHash, + } + + ctx.notifier.EpochChan <- newEpoch + + ctx.chainView.notifyBlock(currHash, + uint32(currBlockHeight), block.Transactions, t) + + <-ctx.chainView.notifyBlockAck + } + + // Then router height should be updated to the latest block. + if atomic.LoadUint32(&ctx.router.bestHeight) != newBlockHeight { + t.Fatalf("height should have been updated to %v, instead got "+ + "%v", newBlockHeight, ctx.router.bestHeight) + } +}