blockcache: fix datarace in blockcache_test

mockChainBackend simulates a chain backend and
tracks the number of calls via a private property
called chainCallCount.  This commit uses a write
mutex instead of read mutex in the call to GetBlock,
adds a getter method to access chainCallCount
with a read mutex and uses a write mutex in
resetChainCallCount.
This commit is contained in:
Michael Street
2022-10-12 20:49:51 -04:00
parent b387e2c718
commit 46444d9154

View File

@@ -20,16 +20,18 @@ type mockChainBackend struct {
sync.RWMutex sync.RWMutex
} }
func (m *mockChainBackend) addBlock(block *wire.MsgBlock, nonce uint32) { func newMockChain() *mockChainBackend {
return &mockChainBackend{
blocks: make(map[chainhash.Hash]*wire.MsgBlock),
}
}
// GetBlock is a mock implementation of block fetching that tracks the number
// of backend calls and returns the block found for the given hash or an error.
func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
m.Lock() m.Lock()
defer m.Unlock() defer m.Unlock()
block.Header.Nonce = nonce
hash := block.Header.BlockHash()
m.blocks[hash] = block
}
func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock, error) {
m.RLock()
defer m.RUnlock()
m.chainCallCount++ m.chainCallCount++
block, ok := m.blocks[*blockHash] block, ok := m.blocks[*blockHash]
@@ -40,15 +42,25 @@ func (m *mockChainBackend) GetBlock(blockHash *chainhash.Hash) (*wire.MsgBlock,
return block, nil return block, nil
} }
func newMockChain() *mockChainBackend { func (m *mockChainBackend) getChainCallCount() int {
return &mockChainBackend{ m.RLock()
blocks: make(map[chainhash.Hash]*wire.MsgBlock), defer m.RUnlock()
}
return m.chainCallCount
}
func (m *mockChainBackend) addBlock(block *wire.MsgBlock, nonce uint32) {
m.Lock()
defer m.Unlock()
block.Header.Nonce = nonce
hash := block.Header.BlockHash()
m.blocks[hash] = block
} }
func (m *mockChainBackend) resetChainCallCount() { func (m *mockChainBackend) resetChainCallCount() {
m.RLock() m.Lock()
defer m.RUnlock() defer m.Unlock()
m.chainCallCount = 0 m.chainCallCount = 0
} }
@@ -90,7 +102,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
_, err := bc.GetBlock(&blockhash1, getBlockImpl) _, err := bc.GetBlock(&blockhash1, getBlockImpl)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 1, bc.Cache.Len()) require.Equal(t, 1, bc.Cache.Len())
require.Equal(t, 1, mc.chainCallCount) require.Equal(t, 1, mc.getChainCallCount())
mc.resetChainCallCount() mc.resetChainCallCount()
_, err = bc.Cache.Get(*inv1) _, err = bc.Cache.Get(*inv1)
@@ -102,7 +114,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
_, err = bc.GetBlock(&blockhash2, getBlockImpl) _, err = bc.GetBlock(&blockhash2, getBlockImpl)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, bc.Cache.Len()) require.Equal(t, 2, bc.Cache.Len())
require.Equal(t, 1, mc.chainCallCount) require.Equal(t, 1, mc.getChainCallCount())
mc.resetChainCallCount() mc.resetChainCallCount()
_, err = bc.Cache.Get(*inv1) _, err = bc.Cache.Get(*inv1)
@@ -117,7 +129,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
_, err = bc.GetBlock(&blockhash1, getBlockImpl) _, err = bc.GetBlock(&blockhash1, getBlockImpl)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, bc.Cache.Len()) require.Equal(t, 2, bc.Cache.Len())
require.Equal(t, 0, mc.chainCallCount) require.Equal(t, 0, mc.getChainCallCount())
mc.resetChainCallCount() mc.resetChainCallCount()
// Since the Cache is now at its max capacity, it is expected that when // Since the Cache is now at its max capacity, it is expected that when
@@ -128,7 +140,7 @@ func TestBlockCacheGetBlock(t *testing.T) {
_, err = bc.GetBlock(&blockhash3, getBlockImpl) _, err = bc.GetBlock(&blockhash3, getBlockImpl)
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 2, bc.Cache.Len()) require.Equal(t, 2, bc.Cache.Len())
require.Equal(t, 1, mc.chainCallCount) require.Equal(t, 1, mc.getChainCallCount())
mc.resetChainCallCount() mc.resetChainCallCount()
_, err = bc.Cache.Get(*inv1) _, err = bc.Cache.Get(*inv1)
@@ -183,5 +195,5 @@ func TestBlockCacheMutexes(t *testing.T) {
} }
wg.Wait() wg.Wait()
require.Equal(t, 2, mc.chainCallCount) require.Equal(t, 2, mc.getChainCallCount())
} }