mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 01:33:02 +01:00
sweep: refactor storeRecord
to updateRecord
To make it clear we are only updating fields, which will be handy for the following commit where we start tracking for spending notifications.
This commit is contained in:
parent
7eea7a7e9a
commit
e5f39dd644
@ -441,19 +441,19 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
|
||||
return record
|
||||
}
|
||||
|
||||
// storeRecord stores the given record in the records map.
|
||||
func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
|
||||
req *BumpRequest, f FeeFunction) {
|
||||
// updateRecord updates the given record's tx and fee, and saves it in the
|
||||
// records map.
|
||||
func (t *TxPublisher) updateRecord(r *monitorRecord,
|
||||
sweepCtx *sweepTxCtx) *monitorRecord {
|
||||
|
||||
r.tx = sweepCtx.tx
|
||||
r.fee = sweepCtx.fee
|
||||
r.outpointToTxIndex = sweepCtx.outpointToTxIndex
|
||||
|
||||
// Register the record.
|
||||
t.records.Store(requestID, &monitorRecord{
|
||||
requestID: requestID,
|
||||
tx: sweepCtx.tx,
|
||||
req: req,
|
||||
feeFunction: f,
|
||||
fee: sweepCtx.fee,
|
||||
outpointToTxIndex: sweepCtx.outpointToTxIndex,
|
||||
})
|
||||
t.records.Store(r.requestID, r)
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// NOTE: part of the `chainio.Consumer` interface.
|
||||
@ -463,11 +463,11 @@ func (t *TxPublisher) Name() string {
|
||||
|
||||
// initializeTx initializes a fee function and creates an RBF-compliant tx. If
|
||||
// succeeded, the initial tx is stored in the records map.
|
||||
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
|
||||
func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) {
|
||||
// Create a fee bumping algorithm to be used for future RBF.
|
||||
feeAlgo, err := t.initializeFeeFunction(r.req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("init fee function: %w", err)
|
||||
return nil, fmt.Errorf("init fee function: %w", err)
|
||||
}
|
||||
|
||||
// Attach the newly created fee function.
|
||||
@ -481,12 +481,12 @@ func (t *TxPublisher) initializeTx(r *monitorRecord) error {
|
||||
|
||||
// Create the initial tx to be broadcasted. This tx is guaranteed to
|
||||
// comply with the RBF restrictions.
|
||||
err = t.createRBFCompliantTx(r)
|
||||
record, err := t.createRBFCompliantTx(r)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create RBF-compliant tx: %w", err)
|
||||
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// initializeFeeFunction initializes a fee function to be used for this request
|
||||
@ -522,7 +522,9 @@ func (t *TxPublisher) initializeFeeFunction(
|
||||
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
|
||||
// and redo the process until the tx is valid, or return an error when non-RBF
|
||||
// related errors occur or the budget has been used up.
|
||||
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
||||
func (t *TxPublisher) createRBFCompliantTx(
|
||||
r *monitorRecord) (*monitorRecord, error) {
|
||||
|
||||
f := r.feeFunction
|
||||
|
||||
for {
|
||||
@ -533,7 +535,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
||||
switch {
|
||||
case err == nil:
|
||||
// The tx is valid, store it.
|
||||
t.storeRecord(r.requestID, sweepCtx, r.req, f)
|
||||
record := t.updateRecord(r, sweepCtx)
|
||||
|
||||
log.Infof("Created initial sweep tx=%v for %v inputs: "+
|
||||
"feerate=%v, fee=%v, inputs:\n%v",
|
||||
@ -541,7 +543,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
||||
f.FeeRate(), sweepCtx.fee,
|
||||
inputTypeSummary(r.req.Inputs))
|
||||
|
||||
return nil
|
||||
return record, nil
|
||||
|
||||
// If the error indicates the fees paid is not enough, we will
|
||||
// ask the fee function to increase the fee rate and retry.
|
||||
@ -572,7 +574,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
||||
// cluster these inputs differetly.
|
||||
increased, err = f.Increment()
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -582,7 +584,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
||||
// mempool acceptance.
|
||||
default:
|
||||
log.Debugf("Failed to create RBF-compliant tx: %v", err)
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -645,13 +647,7 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest,
|
||||
// the event channel to the record. Any broadcast-related errors will not be
|
||||
// returned here, instead, they will be put inside the `BumpResult` and
|
||||
// returned to the caller.
|
||||
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
|
||||
// Get the record being monitored.
|
||||
record, ok := t.records.Load(requestID)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("tx record %v not found", requestID)
|
||||
}
|
||||
|
||||
func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) {
|
||||
txid := record.tx.TxHash()
|
||||
|
||||
tx := record.tx
|
||||
@ -698,7 +694,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
|
||||
Fee: record.fee,
|
||||
FeeRate: record.feeFunction.FeeRate(),
|
||||
Err: err,
|
||||
requestID: requestID,
|
||||
requestID: record.requestID,
|
||||
}
|
||||
|
||||
return result, nil
|
||||
@ -1043,7 +1039,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
|
||||
// RBF rules.
|
||||
//
|
||||
// Create the initial tx to be broadcasted.
|
||||
err = t.initializeTx(r)
|
||||
record, err := t.initializeTx(r)
|
||||
if err != nil {
|
||||
log.Errorf("Initial broadcast failed: %v", err)
|
||||
|
||||
@ -1054,7 +1050,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
|
||||
}
|
||||
|
||||
// Successfully created the first tx, now broadcast it.
|
||||
result, err = t.broadcast(r.requestID)
|
||||
result, err = t.broadcast(record)
|
||||
if err != nil {
|
||||
// The broadcast failed, which can only happen if the tx record
|
||||
// cannot be found or the aux sweeper returns an error. In
|
||||
@ -1199,10 +1195,10 @@ func (t *TxPublisher) createAndPublishTx(
|
||||
|
||||
// The tx has been created without any errors, we now register a new
|
||||
// record by overwriting the same requestID.
|
||||
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)
|
||||
record := t.updateRecord(r, sweepCtx)
|
||||
|
||||
// Attempt to broadcast this new tx.
|
||||
result, err := t.broadcast(r.requestID)
|
||||
result, err := t.broadcast(record)
|
||||
if err != nil {
|
||||
log.Infof("Failed to broadcast replacement tx %v: %v",
|
||||
sweepCtx.tx.TxHash(), err)
|
||||
|
@ -313,9 +313,9 @@ func TestInitializeFeeFunction(t *testing.T) {
|
||||
require.Equal(t, feerate, f.FeeRate())
|
||||
}
|
||||
|
||||
// TestStoreRecord correctly increases the request counter and saves the
|
||||
// TestUpdateRecord correctly updates the fields fee and tx, and saves the
|
||||
// record.
|
||||
func TestStoreRecord(t *testing.T) {
|
||||
func TestUpdateRecord(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test input.
|
||||
@ -358,8 +358,15 @@ func TestStoreRecord(t *testing.T) {
|
||||
outpointToTxIndex: utxoIndex,
|
||||
}
|
||||
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: initialCounter,
|
||||
req: req,
|
||||
feeFunction: feeFunc,
|
||||
}
|
||||
|
||||
// Call the method under test.
|
||||
tp.storeRecord(initialCounter, sweepCtx, req, feeFunc)
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
|
||||
// Read the saved record and compare.
|
||||
record, ok := tp.records.Load(initialCounter)
|
||||
@ -676,10 +683,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
|
||||
tc.setupMock()
|
||||
|
||||
// Call the method under test.
|
||||
err := tp.createRBFCompliantTx(record)
|
||||
rec, err := tp.createRBFCompliantTx(record)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
|
||||
if tc.expectedErr != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Assert the returned record has the following fields
|
||||
// populated.
|
||||
require.NotEmpty(t, rec.tx)
|
||||
require.NotEmpty(t, rec.fee)
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -721,13 +737,13 @@ func TestTxPublisherBroadcast(t *testing.T) {
|
||||
outpointToTxIndex: utxoIndex,
|
||||
}
|
||||
|
||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
||||
|
||||
// Quickly check when the requestID cannot be found, an error is
|
||||
// returned.
|
||||
result, err := tp.broadcast(uint64(1000))
|
||||
require.Error(t, err)
|
||||
require.Nil(t, result)
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: requestID,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
rec := tp.updateRecord(record, sweepCtx)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
@ -782,7 +798,7 @@ func TestTxPublisherBroadcast(t *testing.T) {
|
||||
tc.setupMock()
|
||||
|
||||
// Call the method under test.
|
||||
result, err := tp.broadcast(requestID)
|
||||
result, err := tp.broadcast(rec)
|
||||
|
||||
// Check the result is as expected.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
@ -838,7 +854,15 @@ func TestRemoveResult(t *testing.T) {
|
||||
name: "remove on TxConfirmed",
|
||||
setupRecord: func() uint64 {
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
||||
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: rid,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return rid
|
||||
@ -854,7 +878,15 @@ func TestRemoveResult(t *testing.T) {
|
||||
name: "remove on TxFailed",
|
||||
setupRecord: func() uint64 {
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
||||
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: rid,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return rid
|
||||
@ -871,7 +903,15 @@ func TestRemoveResult(t *testing.T) {
|
||||
name: "noop when tx is not confirmed or failed",
|
||||
setupRecord: func() uint64 {
|
||||
rid := requestCounter.Add(1)
|
||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
||||
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: rid,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
tp.subscriberChans.Store(rid, nil)
|
||||
|
||||
return rid
|
||||
@ -937,8 +977,14 @@ func TestNotifyResult(t *testing.T) {
|
||||
fee: fee,
|
||||
outpointToTxIndex: utxoIndex,
|
||||
}
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: requestID,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
|
||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
|
||||
// Create a subscription to the event.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
@ -1250,7 +1296,14 @@ func TestHandleTxConfirmed(t *testing.T) {
|
||||
outpointToTxIndex: utxoIndex,
|
||||
}
|
||||
|
||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
||||
// Create a test record.
|
||||
record := &monitorRecord{
|
||||
requestID: requestID,
|
||||
req: req,
|
||||
feeFunction: m.feeFunc,
|
||||
}
|
||||
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
record, ok := tp.records.Load(requestID)
|
||||
require.True(t, ok)
|
||||
|
||||
@ -1340,7 +1393,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
|
||||
outpointToTxIndex: utxoIndex,
|
||||
}
|
||||
|
||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
||||
tp.updateRecord(record, sweepCtx)
|
||||
|
||||
// Create a subscription to the event.
|
||||
subscriber := make(chan *BumpResult, 1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user