sweep: shorten storeRecord method signature

This commit shortens the function signature of `storeRecord`, also makes
sure we don't call `t.records.Store` directly but always using
`storeRecord` instead so it's easier to trace the record creation.
This commit is contained in:
yyforyongyu
2025-01-21 00:12:04 +08:00
parent c68b8e8c1e
commit bde5124e1b
2 changed files with 70 additions and 40 deletions

View File

@@ -440,6 +440,20 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
return requestID, record
}
// storeRecord stores the given record in the records map.
func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
req *BumpRequest, f FeeFunction) {
// Register the record.
t.records.Store(requestID, &monitorRecord{
tx: sweepCtx.tx,
req: req,
feeFunction: f,
fee: sweepCtx.fee,
outpointToTxIndex: sweepCtx.outpointToTxIndex,
})
}
// NOTE: part of the `chainio.Consumer` interface.
func (t *TxPublisher) Name() string {
return "TxPublisher"
@@ -508,10 +522,7 @@ func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
switch {
case err == nil:
// The tx is valid, store it.
t.storeRecord(
requestID, sweepCtx.tx, req, f, sweepCtx.fee,
sweepCtx.outpointToTxIndex,
)
t.storeRecord(requestID, sweepCtx, req, f)
log.Infof("Created initial sweep tx=%v for %v inputs: "+
"feerate=%v, fee=%v, inputs:\n%v",
@@ -565,21 +576,6 @@ func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
}
}
// storeRecord stores the given record in the records map.
func (t *TxPublisher) storeRecord(requestID uint64, tx *wire.MsgTx,
req *BumpRequest, f FeeFunction, fee btcutil.Amount,
outpointToTxIndex map[wire.OutPoint]int) {
// Register the record.
t.records.Store(requestID, &monitorRecord{
tx: tx,
req: req,
feeFunction: f,
fee: fee,
outpointToTxIndex: outpointToTxIndex,
})
}
// createAndCheckTx creates a tx based on the given inputs, change output
// script, and the fee rate. In addition, it validates the tx's mempool
// acceptance before returning a tx that can be published directly, along with
@@ -1195,13 +1191,7 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64,
// The tx has been created without any errors, we now register a new
// record by overwriting the same requestID.
t.records.Store(requestID, &monitorRecord{
tx: sweepCtx.tx,
req: r.req,
feeFunction: r.feeFunction,
fee: sweepCtx.fee,
outpointToTxIndex: sweepCtx.outpointToTxIndex,
})
t.storeRecord(requestID, sweepCtx, r.req, r.feeFunction)
// Attempt to broadcast this new tx.
result, err := t.broadcast(requestID)

View File

@@ -351,8 +351,15 @@ func TestStoreRecord(t *testing.T) {
op: 0,
}
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
// Call the method under test.
tp.storeRecord(initialCounter, tx, req, feeFunc, fee, utxoIndex)
tp.storeRecord(initialCounter, sweepCtx, req, feeFunc)
// Read the saved record and compare.
record, ok := tp.records.Load(initialCounter)
@@ -698,7 +705,15 @@ func TestTxPublisherBroadcast(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
// Quickly check when the requestID cannot be found, an error is
// returned.
@@ -796,6 +811,13 @@ func TestRemoveResult(t *testing.T) {
// Create a test request ID counter.
requestCounter := atomic.Uint64{}
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
testCases := []struct {
name string
setupRecord func() uint64
@@ -808,9 +830,7 @@ func TestRemoveResult(t *testing.T) {
name: "remove on TxConfirmed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
tp.subscriberChans.Store(rid, nil)
return rid
@@ -826,9 +846,7 @@ func TestRemoveResult(t *testing.T) {
name: "remove on TxFailed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
tp.subscriberChans.Store(rid, nil)
return rid
@@ -845,9 +863,7 @@ func TestRemoveResult(t *testing.T) {
name: "noop when tx is not confirmed or failed",
setupRecord: func() uint64 {
rid := requestCounter.Add(1)
tp.storeRecord(
rid, tx, req, m.feeFunc, fee, utxoIndex,
)
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
tp.subscriberChans.Store(rid, nil)
return rid
@@ -906,7 +922,15 @@ func TestNotifyResult(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)
@@ -1208,7 +1232,15 @@ func TestHandleTxConfirmed(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
record, ok := tp.records.Load(requestID)
require.True(t, ok)
@@ -1289,7 +1321,15 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := uint64(1)
tp.storeRecord(requestID, tx, req, m.feeFunc, fee, utxoIndex)
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
tx: tx,
fee: fee,
outpointToTxIndex: utxoIndex,
}
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
// Create a subscription to the event.
subscriber := make(chan *BumpResult, 1)