mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-11 03:25:21 +02:00
blockcache: fix datarace in blockcache_test
mockChainBackend simulates a chain backend and tracks the number of calls via a private property called chainCallCount. This commit uses a write mutex instead of read mutex in the call to GetBlock, adds a getter method to access chainCallCount with a read mutex and uses a write mutex in resetChainCallCount.
This commit is contained in:
@@ -20,16 +20,18 @@ type mockChainBackend struct {
|
|||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockChainBackend) addBlock(block *wire.MsgBlock, nonce uint32) {
|
func newMockChain() *mockChainBackend {
|
||||||
|
return &mockChainBackend{
|
||||||
|
blocks: make(map[chainhash.Hash]*wire.MsgBlock),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlock is a mock implementation of block fetching that tracks the number
|
||||||
|
// of backend calls and returns the block found for the given hash or an error.
|
||||||
|
func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
|
||||||
m.Lock()
|
m.Lock()
|
||||||
defer m.Unlock()
|
defer m.Unlock()
|
||||||
block.Header.Nonce = nonce
|
|
||||||
hash := block.Header.BlockHash()
|
|
||||||
m.blocks[hash] = block
|
|
||||||
}
|
|
||||||
func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
|
|
||||||
m.RLock()
|
|
||||||
defer m.RUnlock()
|
|
||||||
m.chainCallCount++
|
m.chainCallCount++
|
||||||
|
|
||||||
block, ok := m.blocks[*blockHash]
|
block, ok := m.blocks[*blockHash]
|
||||||
@@ -40,15 +42,25 @@ func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock,
|
|||||||
return block, nil
|
return block, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func newMockChain() *mockChainBackend {
|
func (m *mockChainBackend) getChainCallCount() int {
|
||||||
return &mockChainBackend{
|
m.RLock()
|
||||||
blocks: make(map[chainhash.Hash]*wire.MsgBlock),
|
defer m.RUnlock()
|
||||||
}
|
|
||||||
|
return m.chainCallCount
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockChainBackend) addBlock(block *wire.MsgBlock, nonce uint32) {
|
||||||
|
m.Lock()
|
||||||
|
defer m.Unlock()
|
||||||
|
|
||||||
|
block.Header.Nonce = nonce
|
||||||
|
hash := block.Header.BlockHash()
|
||||||
|
m.blocks[hash] = block
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *mockChainBackend) resetChainCallCount() {
|
func (m *mockChainBackend) resetChainCallCount() {
|
||||||
m.RLock()
|
m.Lock()
|
||||||
defer m.RUnlock()
|
defer m.Unlock()
|
||||||
|
|
||||||
m.chainCallCount = 0
|
m.chainCallCount = 0
|
||||||
}
|
}
|
||||||
@@ -90,7 +102,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
|
|||||||
_, err := bc.GetBlock(&blockhash1, getBlockImpl)
|
_, err := bc.GetBlock(&blockhash1, getBlockImpl)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 1, bc.Cache.Len())
|
require.Equal(t, 1, bc.Cache.Len())
|
||||||
require.Equal(t, 1, mc.chainCallCount)
|
require.Equal(t, 1, mc.getChainCallCount())
|
||||||
mc.resetChainCallCount()
|
mc.resetChainCallCount()
|
||||||
|
|
||||||
_, err = bc.Cache.Get(*inv1)
|
_, err = bc.Cache.Get(*inv1)
|
||||||
@@ -102,7 +114,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
|
|||||||
_, err = bc.GetBlock(&blockhash2, getBlockImpl)
|
_, err = bc.GetBlock(&blockhash2, getBlockImpl)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 2, bc.Cache.Len())
|
require.Equal(t, 2, bc.Cache.Len())
|
||||||
require.Equal(t, 1, mc.chainCallCount)
|
require.Equal(t, 1, mc.getChainCallCount())
|
||||||
mc.resetChainCallCount()
|
mc.resetChainCallCount()
|
||||||
|
|
||||||
_, err = bc.Cache.Get(*inv1)
|
_, err = bc.Cache.Get(*inv1)
|
||||||
@@ -117,7 +129,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
|
|||||||
_, err = bc.GetBlock(&blockhash1, getBlockImpl)
|
_, err = bc.GetBlock(&blockhash1, getBlockImpl)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 2, bc.Cache.Len())
|
require.Equal(t, 2, bc.Cache.Len())
|
||||||
require.Equal(t, 0, mc.chainCallCount)
|
require.Equal(t, 0, mc.getChainCallCount())
|
||||||
mc.resetChainCallCount()
|
mc.resetChainCallCount()
|
||||||
|
|
||||||
// Since the Cache is now at its max capacity, it is expected that when
|
// Since the Cache is now at its max capacity, it is expected that when
|
||||||
@@ -128,7 +140,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
|
|||||||
_, err = bc.GetBlock(&blockhash3, getBlockImpl)
|
_, err = bc.GetBlock(&blockhash3, getBlockImpl)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.Equal(t, 2, bc.Cache.Len())
|
require.Equal(t, 2, bc.Cache.Len())
|
||||||
require.Equal(t, 1, mc.chainCallCount)
|
require.Equal(t, 1, mc.getChainCallCount())
|
||||||
mc.resetChainCallCount()
|
mc.resetChainCallCount()
|
||||||
|
|
||||||
_, err = bc.Cache.Get(*inv1)
|
_, err = bc.Cache.Get(*inv1)
|
||||||
@@ -183,5 +195,5 @@ func TestBlockCacheMutexes(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
require.Equal(t, 2, mc.chainCallCount)
|
require.Equal(t, 2, mc.getChainCallCount())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user