diff --git a/lnwallet/chainfee/estimator_test.go b/lnwallet/chainfee/estimator_test.go index f962df5ec..fc16c9b12 100644 --- a/lnwallet/chainfee/estimator_test.go +++ b/lnwallet/chainfee/estimator_test.go @@ -9,15 +9,6 @@ import ( "github.com/stretchr/testify/require" ) -type mockSparseConfFeeSource struct { - url string - fees map[uint32]uint32 -} - -func (e mockSparseConfFeeSource) GetFeeMap() (map[uint32]uint32, error) { - return e.fees, nil -} - // TestFeeRateTypes checks that converting fee rates between the // different types that represent fee rates and calculating fees // work as expected. @@ -204,10 +195,9 @@ func TestWebAPIFeeEstimator(t *testing.T) { maxTarget: minFeeRate, } - feeSource := mockSparseConfFeeSource{ - url: "https://www.github.com", - fees: feeRateResp, - } + // Create a mock fee source and mock its returned map. + feeSource := &mockFeeSource{} + feeSource.On("GetFeeMap").Return(feeRateResp, nil) estimator := NewWebAPIEstimator(feeSource, false) @@ -238,12 +228,15 @@ func TestWebAPIFeeEstimator(t *testing.T) { exp := SatPerKVByte(tc.expectedFeeRate).FeePerKWeight() require.Equalf(t, exp, est, "target %v failed, fee "+ - "map is %v", tc.target, feeSource.fees) + "map is %v", tc.target, feeRateResp) }) } // Stop the estimator when test ends. require.NoError(t, estimator.Stop(), "unable to stop fee estimator") + + // Assert the mocked fee source is called as expected. + feeSource.AssertExpectations(t) } // TestGetCachedFee checks that the fee caching logic works as expected. diff --git a/lnwallet/chainfee/mocks.go b/lnwallet/chainfee/mocks.go new file mode 100644 index 000000000..03d40e11e --- /dev/null +++ b/lnwallet/chainfee/mocks.go @@ -0,0 +1,17 @@ +package chainfee + +import "github.com/stretchr/testify/mock" + +type mockFeeSource struct { + mock.Mock +} + +// A compile-time assertion to ensure that mockFeeSource implements the +// WebAPIFeeSource interface. +var _ WebAPIFeeSource = (*mockFeeSource)(nil) + +func (m *mockFeeSource) GetFeeMap() (map[uint32]uint32, error) { + args := m.Called() + + return args.Get(0).(map[uint32]uint32), args.Error(1) +}