sweep: add handleInitialBroadcast to handle initial broadcast

This commit adds a new method `handleInitialBroadcast` to handle the
initial broadcast. Previously we'd broadcast immediately inside
`Broadcast`, which soon will not work after the `blockbeat` is
implemented as the action to publish is now always triggered by a new
block. Meanwhile, we still keep the option to bypass the block trigger
so users can broadcast immediately by setting `Immediate` to true.
This commit is contained in:
yyforyongyu
2024-10-25 18:31:46 +08:00
parent 2479dc7f2e
commit 77ff2c0585
2 changed files with 370 additions and 164 deletions

View File

@@ -352,13 +352,10 @@ func TestStoreRecord(t *testing.T) {
}
// Call the method under test.
requestID := tp.storeRecord(tx, req, feeFunc, fee, utxoIndex)
// Check the request ID is as expected.
require.Equal(t, initialCounter+1, requestID)
tp.storeRecord(initialCounter, tx, req, feeFunc, fee, utxoIndex)
// Read the saved record and compare.
record, ok := tp.records.Load(requestID)
record, ok := tp.records.Load(initialCounter)
require.True(t, ok)
require.Equal(t, tx, record.tx)
require.Equal(t, feeFunc, record.feeFunction)
@@ -655,23 +652,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
},
}
var requestCounter atomic.Uint64
for _, tc := range testCases {
tc := tc
rid := requestCounter.Add(1)
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()
// Call the method under test.
id, err := tp.createRBFCompliantTx(req, m.feeFunc)
err := tp.createRBFCompliantTx(rid, req, m.feeFunc)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
// If there's an error, expect the requestID to be
// empty.
if tc.expectedErr != nil {
require.Zero(t, id)
}
})
}
}
@@ -704,7 +697,8 @@ func TestTxPublisherBroadcast(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Quickly check when the requestID cannot be found, an error is
// returned.
@@ -799,6 +793,9 @@ func TestRemoveResult(t *testing.T) {
op: 0,
}
// Create a test request ID counter.
requestCounter := atomic.Uint64{}
testCases := []struct {
name string
setupRecord func() uint64
@@ -810,12 +807,13 @@ func TestRemoveResult(t *testing.T) {
// removed.
name: "remove on TxConfirmed",
setupRecord: func() uint64 {
id := tp.storeRecord(
tx, req, m.feeFunc, fee, utxoIndex,
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.subscriberChans.Store(id, nil)
tp.subscriberChans.Store(rid, nil)
return id
return rid
},
result: &BumpResult{
Event: TxConfirmed,
@@ -827,12 +825,13 @@ func TestRemoveResult(t *testing.T) {
// When the tx is failed, the records will be removed.
name: "remove on TxFailed",
setupRecord: func() uint64 {
id := tp.storeRecord(
tx, req, m.feeFunc, fee, utxoIndex,
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.subscriberChans.Store(id, nil)
tp.subscriberChans.Store(rid, nil)
return id
return rid
},
result: &BumpResult{
Event: TxFailed,
@@ -845,12 +844,13 @@ func TestRemoveResult(t *testing.T) {
// Noop when the tx is neither confirmed or failed.
name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 {
id := tp.storeRecord(
tx, req, m.feeFunc, fee, utxoIndex,
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.subscriberChans.Store(id, nil)
tp.subscriberChans.Store(rid, nil)
return id
return rid
},
result: &BumpResult{
Event: TxPublished,
@@ -905,7 +905,8 @@ func TestNotifyResult(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
@@ -953,41 +954,17 @@ func TestNotifyResult(t *testing.T) {
}
}
// TestBroadcastSuccess checks the public `Broadcast` method can successfully
// broadcast a tx based on the request.
func TestBroadcastSuccess(t *testing.T) {
// TestBroadcast checks the public `Broadcast` method can successfully register
// a broadcast request.
func TestBroadcast(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
tp, _ := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(nil).Once()
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
@@ -1003,25 +980,23 @@ func TestBroadcastSuccess(t *testing.T) {
// Send the req and expect no error.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxPublished.
require.Equal(t, TxPublished, result.Event)
}
require.NotNil(t, resultChan)
// Validate the record was stored.
require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len())
// Validate the record.
rid := tp.requestCounter.Load()
record, found := tp.records.Load(rid)
require.True(t, found)
require.Equal(t, req, record.req)
}
// TestBroadcastFail checks the public `Broadcast` returns the error or a
// failed result when the broadcast fails.
func TestBroadcastFail(t *testing.T) {
// TestBroadcastImmediate checks the public `Broadcast` method can successfully
// register a broadcast request and publish the tx when `Immediate` flag is
// set.
func TestBroadcastImmediate(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
@@ -1040,64 +1015,28 @@ func TestBroadcastFail(t *testing.T) {
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
Immediate: true,
}
// Mock the fee estimator to return the testing fee rate.
// Mock the fee estimator to return an error.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
// NOTE: We are not testing `handleInitialBroadcast` here, but only
// interested in checking that this method is indeed called when
// `Immediate` is true. Thus we mock the method to return an error to
// quickly abort. As long as this mocked method is called, we know the
// `Immediate` flag works.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Twice()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
chainfee.SatPerKWeight(0), errDummy).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to return an error.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
// Send the req and expect an error returned.
// Send the req and expect no error.
resultChan, err := tp.Broadcast(req)
require.ErrorIs(t, err, errDummy)
require.Nil(t, resultChan)
// Validate the record was NOT stored.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Mock the testmempoolaccept again, this time it passes.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to fail on publish.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(errDummy).Once()
// Send the req and expect no error returned.
resultChan, err = tp.Broadcast(req)
require.NoError(t, err)
require.NotNil(t, resultChan)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the result to be TxFailed and the error is set in
// the result.
require.Equal(t, TxFailed, result.Event)
require.ErrorIs(t, result.Err, errDummy)
}
// Validate the record was removed.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Validate the record was removed due to an error returned in initial
// broadcast.
require.Empty(t, tp.records.Len())
require.Empty(t, tp.subscriberChans.Len())
}
// TestCreateAnPublishFail checks all the error cases are handled properly in
@@ -1270,7 +1209,8 @@ func TestHandleTxConfirmed(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
record, ok := tp.records.Load(requestID)
require.True(t, ok)
@@ -1350,7 +1290,8 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := tp.storeRecord(tx, req, m.feeFunc, fee, utxoIndex)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
@@ -1551,3 +1492,186 @@ func TestProcessRecords(t *testing.T) {
require.Equal(t, requestID2, result.requestID)
}
}
// TestHandleInitialBroadcastSuccess checks `handleInitialBroadcast` method can
// successfully broadcast a tx based on the request.
func TestHandleInitialBroadcastSuccess(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Once()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Once()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to pass.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to publish successfully.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(nil).Once()
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
}
// Register the testing record use `Broadcast`.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid := tp.requestCounter.Load()
rec, ok := tp.records.Load(rid)
require.True(t, ok)
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxPublished.
require.Equal(t, TxPublished, result.Event)
}
// Validate the record was stored.
require.Equal(t, 1, tp.records.Len())
require.Equal(t, 1, tp.subscriberChans.Len())
}
// TestHandleInitialBroadcastFail checks `handleInitialBroadcast` returns the
// error or a failed result when the broadcast fails.
func TestHandleInitialBroadcastFail(t *testing.T) {
t.Parallel()
// Create a publisher using the mocks.
tp, m := createTestPublisher(t)
// Create a test feerate.
feerate := chainfee.SatPerKWeight(1000)
// Create a test request.
inp := createTestInput(1000, input.WitnessKeyHash)
// Create a testing bump request.
req := &BumpRequest{
DeliveryAddress: changePkScript,
Inputs: []input.Input{&inp},
Budget: btcutil.Amount(1000),
MaxFeeRate: feerate * 10,
DeadlineHeight: 10,
}
// Mock the fee estimator to return the testing fee rate.
//
// We are not testing `NewLinearFeeFunction` here, so the actual params
// used are irrelevant.
m.estimator.On("EstimateFeePerKW", mock.Anything).Return(
feerate, nil).Twice()
m.estimator.On("RelayFeePerKW").Return(chainfee.FeePerKwFloor).Twice()
// Mock the signer to always return a valid script.
//
// NOTE: we are not testing the utility of creating valid txes here, so
// this is fine to be mocked. This behaves essentially as skipping the
// Signer check and alaways assume the tx has a valid sig.
script := &input.Script{}
m.signer.On("ComputeInputScript", mock.Anything,
mock.Anything).Return(script, nil)
// Mock the testmempoolaccept to return an error.
m.wallet.On("CheckMempoolAcceptance",
mock.Anything).Return(errDummy).Once()
// Register the testing record use `Broadcast`.
resultChan, err := tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid := tp.requestCounter.Load()
rec, ok := tp.records.Load(rid)
require.True(t, ok)
// Call the method under test and expect an error returned.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the first result to be TxFatal.
require.Equal(t, TxFatal, result.Event)
}
// Validate the record was NOT stored.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
// Mock the testmempoolaccept again, this time it passes.
m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once()
// Mock the wallet to fail on publish.
m.wallet.On("PublishTransaction",
mock.Anything, mock.Anything).Return(errDummy).Once()
// Register the testing record use `Broadcast`.
resultChan, err = tp.Broadcast(req)
require.NoError(t, err)
// Grab the monitor record from the map.
rid = tp.requestCounter.Load()
rec, ok = tp.records.Load(rid)
require.True(t, ok)
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
// Check the result is sent back.
select {
case <-time.After(time.Second):
t.Fatal("timeout waiting for subscriber to receive result")
case result := <-resultChan:
// We expect the result to be TxFailed and the error is set in
// the result.
require.Equal(t, TxFailed, result.Event)
require.ErrorIs(t, result.Err, errDummy)
}
// Validate the record was removed.
require.Equal(t, 0, tp.records.Len())
require.Equal(t, 0, tp.subscriberChans.Len())
}