mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-29 03:01:52 +01:00
sweep: refactor storeRecord
to updateRecord
To make it clear we are only updating fields, which will be handy for the following commit where we start tracking for spending notifications.
This commit is contained in:
parent
7eea7a7e9a
commit
e5f39dd644
@ -441,19 +441,19 @@ func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
|
|||||||
return record
|
return record
|
||||||
}
|
}
|
||||||
|
|
||||||
// storeRecord stores the given record in the records map.
|
// updateRecord updates the given record's tx and fee, and saves it in the
|
||||||
func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
|
// records map.
|
||||||
req *BumpRequest, f FeeFunction) {
|
func (t *TxPublisher) updateRecord(r *monitorRecord,
|
||||||
|
sweepCtx *sweepTxCtx) *monitorRecord {
|
||||||
|
|
||||||
|
r.tx = sweepCtx.tx
|
||||||
|
r.fee = sweepCtx.fee
|
||||||
|
r.outpointToTxIndex = sweepCtx.outpointToTxIndex
|
||||||
|
|
||||||
// Register the record.
|
// Register the record.
|
||||||
t.records.Store(requestID, &monitorRecord{
|
t.records.Store(r.requestID, r)
|
||||||
requestID: requestID,
|
|
||||||
tx: sweepCtx.tx,
|
return r
|
||||||
req: req,
|
|
||||||
feeFunction: f,
|
|
||||||
fee: sweepCtx.fee,
|
|
||||||
outpointToTxIndex: sweepCtx.outpointToTxIndex,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOTE: part of the `chainio.Consumer` interface.
|
// NOTE: part of the `chainio.Consumer` interface.
|
||||||
@ -463,11 +463,11 @@ func (t *TxPublisher) Name() string {
|
|||||||
|
|
||||||
// initializeTx initializes a fee function and creates an RBF-compliant tx. If
|
// initializeTx initializes a fee function and creates an RBF-compliant tx. If
|
||||||
// succeeded, the initial tx is stored in the records map.
|
// succeeded, the initial tx is stored in the records map.
|
||||||
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
|
func (t *TxPublisher) initializeTx(r *monitorRecord) (*monitorRecord, error) {
|
||||||
// Create a fee bumping algorithm to be used for future RBF.
|
// Create a fee bumping algorithm to be used for future RBF.
|
||||||
feeAlgo, err := t.initializeFeeFunction(r.req)
|
feeAlgo, err := t.initializeFeeFunction(r.req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("init fee function: %w", err)
|
return nil, fmt.Errorf("init fee function: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attach the newly created fee function.
|
// Attach the newly created fee function.
|
||||||
@ -481,12 +481,12 @@ func (t *TxPublisher) initializeTx(r *monitorRecord) error {
|
|||||||
|
|
||||||
// Create the initial tx to be broadcasted. This tx is guaranteed to
|
// Create the initial tx to be broadcasted. This tx is guaranteed to
|
||||||
// comply with the RBF restrictions.
|
// comply with the RBF restrictions.
|
||||||
err = t.createRBFCompliantTx(r)
|
record, err := t.createRBFCompliantTx(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create RBF-compliant tx: %w", err)
|
return nil, fmt.Errorf("create RBF-compliant tx: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return record, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// initializeFeeFunction initializes a fee function to be used for this request
|
// initializeFeeFunction initializes a fee function to be used for this request
|
||||||
@ -522,7 +522,9 @@ func (t *TxPublisher) initializeFeeFunction(
|
|||||||
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
|
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
|
||||||
// and redo the process until the tx is valid, or return an error when non-RBF
|
// and redo the process until the tx is valid, or return an error when non-RBF
|
||||||
// related errors occur or the budget has been used up.
|
// related errors occur or the budget has been used up.
|
||||||
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
func (t *TxPublisher) createRBFCompliantTx(
|
||||||
|
r *monitorRecord) (*monitorRecord, error) {
|
||||||
|
|
||||||
f := r.feeFunction
|
f := r.feeFunction
|
||||||
|
|
||||||
for {
|
for {
|
||||||
@ -533,7 +535,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
|||||||
switch {
|
switch {
|
||||||
case err == nil:
|
case err == nil:
|
||||||
// The tx is valid, store it.
|
// The tx is valid, store it.
|
||||||
t.storeRecord(r.requestID, sweepCtx, r.req, f)
|
record := t.updateRecord(r, sweepCtx)
|
||||||
|
|
||||||
log.Infof("Created initial sweep tx=%v for %v inputs: "+
|
log.Infof("Created initial sweep tx=%v for %v inputs: "+
|
||||||
"feerate=%v, fee=%v, inputs:\n%v",
|
"feerate=%v, fee=%v, inputs:\n%v",
|
||||||
@ -541,7 +543,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
|||||||
f.FeeRate(), sweepCtx.fee,
|
f.FeeRate(), sweepCtx.fee,
|
||||||
inputTypeSummary(r.req.Inputs))
|
inputTypeSummary(r.req.Inputs))
|
||||||
|
|
||||||
return nil
|
return record, nil
|
||||||
|
|
||||||
// If the error indicates the fees paid is not enough, we will
|
// If the error indicates the fees paid is not enough, we will
|
||||||
// ask the fee function to increase the fee rate and retry.
|
// ask the fee function to increase the fee rate and retry.
|
||||||
@ -572,7 +574,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
|||||||
// cluster these inputs differetly.
|
// cluster these inputs differetly.
|
||||||
increased, err = f.Increment()
|
increased, err = f.Increment()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -582,7 +584,7 @@ func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
|
|||||||
// mempool acceptance.
|
// mempool acceptance.
|
||||||
default:
|
default:
|
||||||
log.Debugf("Failed to create RBF-compliant tx: %v", err)
|
log.Debugf("Failed to create RBF-compliant tx: %v", err)
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -645,13 +647,7 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest,
|
|||||||
// the event channel to the record. Any broadcast-related errors will not be
|
// the event channel to the record. Any broadcast-related errors will not be
|
||||||
// returned here, instead, they will be put inside the `BumpResult` and
|
// returned here, instead, they will be put inside the `BumpResult` and
|
||||||
// returned to the caller.
|
// returned to the caller.
|
||||||
func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
|
func (t *TxPublisher) broadcast(record *monitorRecord) (*BumpResult, error) {
|
||||||
// Get the record being monitored.
|
|
||||||
record, ok := t.records.Load(requestID)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("tx record %v not found", requestID)
|
|
||||||
}
|
|
||||||
|
|
||||||
txid := record.tx.TxHash()
|
txid := record.tx.TxHash()
|
||||||
|
|
||||||
tx := record.tx
|
tx := record.tx
|
||||||
@ -698,7 +694,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) {
|
|||||||
Fee: record.fee,
|
Fee: record.fee,
|
||||||
FeeRate: record.feeFunction.FeeRate(),
|
FeeRate: record.feeFunction.FeeRate(),
|
||||||
Err: err,
|
Err: err,
|
||||||
requestID: requestID,
|
requestID: record.requestID,
|
||||||
}
|
}
|
||||||
|
|
||||||
return result, nil
|
return result, nil
|
||||||
@ -1043,7 +1039,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
|
|||||||
// RBF rules.
|
// RBF rules.
|
||||||
//
|
//
|
||||||
// Create the initial tx to be broadcasted.
|
// Create the initial tx to be broadcasted.
|
||||||
err = t.initializeTx(r)
|
record, err := t.initializeTx(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Initial broadcast failed: %v", err)
|
log.Errorf("Initial broadcast failed: %v", err)
|
||||||
|
|
||||||
@ -1054,7 +1050,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Successfully created the first tx, now broadcast it.
|
// Successfully created the first tx, now broadcast it.
|
||||||
result, err = t.broadcast(r.requestID)
|
result, err = t.broadcast(record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// The broadcast failed, which can only happen if the tx record
|
// The broadcast failed, which can only happen if the tx record
|
||||||
// cannot be found or the aux sweeper returns an error. In
|
// cannot be found or the aux sweeper returns an error. In
|
||||||
@ -1199,10 +1195,10 @@ func (t *TxPublisher) createAndPublishTx(
|
|||||||
|
|
||||||
// The tx has been created without any errors, we now register a new
|
// The tx has been created without any errors, we now register a new
|
||||||
// record by overwriting the same requestID.
|
// record by overwriting the same requestID.
|
||||||
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)
|
record := t.updateRecord(r, sweepCtx)
|
||||||
|
|
||||||
// Attempt to broadcast this new tx.
|
// Attempt to broadcast this new tx.
|
||||||
result, err := t.broadcast(r.requestID)
|
result, err := t.broadcast(record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Infof("Failed to broadcast replacement tx %v: %v",
|
log.Infof("Failed to broadcast replacement tx %v: %v",
|
||||||
sweepCtx.tx.TxHash(), err)
|
sweepCtx.tx.TxHash(), err)
|
||||||
|
@ -313,9 +313,9 @@ func TestInitializeFeeFunction(t *testing.T) {
|
|||||||
require.Equal(t, feerate, f.FeeRate())
|
require.Equal(t, feerate, f.FeeRate())
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestStoreRecord correctly increases the request counter and saves the
|
// TestUpdateRecord correctly updates the fields fee and tx, and saves the
|
||||||
// record.
|
// record.
|
||||||
func TestStoreRecord(t *testing.T) {
|
func TestUpdateRecord(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Create a test input.
|
// Create a test input.
|
||||||
@ -358,8 +358,15 @@ func TestStoreRecord(t *testing.T) {
|
|||||||
outpointToTxIndex: utxoIndex,
|
outpointToTxIndex: utxoIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: initialCounter,
|
||||||
|
req: req,
|
||||||
|
feeFunction: feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
// Call the method under test.
|
// Call the method under test.
|
||||||
tp.storeRecord(initialCounter, sweepCtx, req, feeFunc)
|
tp.updateRecord(record, sweepCtx)
|
||||||
|
|
||||||
// Read the saved record and compare.
|
// Read the saved record and compare.
|
||||||
record, ok := tp.records.Load(initialCounter)
|
record, ok := tp.records.Load(initialCounter)
|
||||||
@ -676,10 +683,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
|
|||||||
tc.setupMock()
|
tc.setupMock()
|
||||||
|
|
||||||
// Call the method under test.
|
// Call the method under test.
|
||||||
err := tp.createRBFCompliantTx(record)
|
rec, err := tp.createRBFCompliantTx(record)
|
||||||
|
|
||||||
// Check the result is as expected.
|
// Check the result is as expected.
|
||||||
require.ErrorIs(t, err, tc.expectedErr)
|
require.ErrorIs(t, err, tc.expectedErr)
|
||||||
|
|
||||||
|
if tc.expectedErr != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert the returned record has the following fields
|
||||||
|
// populated.
|
||||||
|
require.NotEmpty(t, rec.tx)
|
||||||
|
require.NotEmpty(t, rec.fee)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -721,13 +737,13 @@ func TestTxPublisherBroadcast(t *testing.T) {
|
|||||||
outpointToTxIndex: utxoIndex,
|
outpointToTxIndex: utxoIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
// Quickly check when the requestID cannot be found, an error is
|
requestID: requestID,
|
||||||
// returned.
|
req: req,
|
||||||
result, err := tp.broadcast(uint64(1000))
|
feeFunction: m.feeFunc,
|
||||||
require.Error(t, err)
|
}
|
||||||
require.Nil(t, result)
|
rec := tp.updateRecord(record, sweepCtx)
|
||||||
|
|
||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
@ -782,7 +798,7 @@ func TestTxPublisherBroadcast(t *testing.T) {
|
|||||||
tc.setupMock()
|
tc.setupMock()
|
||||||
|
|
||||||
// Call the method under test.
|
// Call the method under test.
|
||||||
result, err := tp.broadcast(requestID)
|
result, err := tp.broadcast(rec)
|
||||||
|
|
||||||
// Check the result is as expected.
|
// Check the result is as expected.
|
||||||
require.ErrorIs(t, err, tc.expectedErr)
|
require.ErrorIs(t, err, tc.expectedErr)
|
||||||
@ -838,7 +854,15 @@ func TestRemoveResult(t *testing.T) {
|
|||||||
name: "remove on TxConfirmed",
|
name: "remove on TxConfirmed",
|
||||||
setupRecord: func() uint64 {
|
setupRecord: func() uint64 {
|
||||||
rid := requestCounter.Add(1)
|
rid := requestCounter.Add(1)
|
||||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
|
||||||
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: rid,
|
||||||
|
req: req,
|
||||||
|
feeFunction: m.feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tp.updateRecord(record, sweepCtx)
|
||||||
tp.subscriberChans.Store(rid, nil)
|
tp.subscriberChans.Store(rid, nil)
|
||||||
|
|
||||||
return rid
|
return rid
|
||||||
@ -854,7 +878,15 @@ func TestRemoveResult(t *testing.T) {
|
|||||||
name: "remove on TxFailed",
|
name: "remove on TxFailed",
|
||||||
setupRecord: func() uint64 {
|
setupRecord: func() uint64 {
|
||||||
rid := requestCounter.Add(1)
|
rid := requestCounter.Add(1)
|
||||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
|
||||||
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: rid,
|
||||||
|
req: req,
|
||||||
|
feeFunction: m.feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tp.updateRecord(record, sweepCtx)
|
||||||
tp.subscriberChans.Store(rid, nil)
|
tp.subscriberChans.Store(rid, nil)
|
||||||
|
|
||||||
return rid
|
return rid
|
||||||
@ -871,7 +903,15 @@ func TestRemoveResult(t *testing.T) {
|
|||||||
name: "noop when tx is not confirmed or failed",
|
name: "noop when tx is not confirmed or failed",
|
||||||
setupRecord: func() uint64 {
|
setupRecord: func() uint64 {
|
||||||
rid := requestCounter.Add(1)
|
rid := requestCounter.Add(1)
|
||||||
tp.storeRecord(rid, sweepCtx, req, m.feeFunc)
|
|
||||||
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: rid,
|
||||||
|
req: req,
|
||||||
|
feeFunction: m.feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tp.updateRecord(record, sweepCtx)
|
||||||
tp.subscriberChans.Store(rid, nil)
|
tp.subscriberChans.Store(rid, nil)
|
||||||
|
|
||||||
return rid
|
return rid
|
||||||
@ -937,8 +977,14 @@ func TestNotifyResult(t *testing.T) {
|
|||||||
fee: fee,
|
fee: fee,
|
||||||
outpointToTxIndex: utxoIndex,
|
outpointToTxIndex: utxoIndex,
|
||||||
}
|
}
|
||||||
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: requestID,
|
||||||
|
req: req,
|
||||||
|
feeFunction: m.feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
tp.updateRecord(record, sweepCtx)
|
||||||
|
|
||||||
// Create a subscription to the event.
|
// Create a subscription to the event.
|
||||||
subscriber := make(chan *BumpResult, 1)
|
subscriber := make(chan *BumpResult, 1)
|
||||||
@ -1250,7 +1296,14 @@ func TestHandleTxConfirmed(t *testing.T) {
|
|||||||
outpointToTxIndex: utxoIndex,
|
outpointToTxIndex: utxoIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
// Create a test record.
|
||||||
|
record := &monitorRecord{
|
||||||
|
requestID: requestID,
|
||||||
|
req: req,
|
||||||
|
feeFunction: m.feeFunc,
|
||||||
|
}
|
||||||
|
|
||||||
|
tp.updateRecord(record, sweepCtx)
|
||||||
record, ok := tp.records.Load(requestID)
|
record, ok := tp.records.Load(requestID)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
@ -1340,7 +1393,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
|
|||||||
outpointToTxIndex: utxoIndex,
|
outpointToTxIndex: utxoIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
tp.storeRecord(requestID, sweepCtx, req, m.feeFunc)
|
tp.updateRecord(record, sweepCtx)
|
||||||
|
|
||||||
// Create a subscription to the event.
|
// Create a subscription to the event.
|
||||||
subscriber := make(chan *BumpResult, 1)
|
subscriber := make(chan *BumpResult, 1)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user