diff --git a/contractcourt/utxonursery.go b/contractcourt/utxonursery.go index 6b8742255..c2f0264d3 100644 --- a/contractcourt/utxonursery.go +++ b/contractcourt/utxonursery.go @@ -1406,6 +1406,10 @@ func (k *kidOutput) ConfHeight() uint32 { return k.confHeight } +func (k *kidOutput) RequiredLockTime() (uint32, bool) { + return k.absoluteMaturity, k.absoluteMaturity > 0 +} + // Encode converts a KidOutput struct into a form suitable for on-disk database // storage. Note that the signDescriptor struct field is included so that the // output's witness can be generated by createSweepTx() when the output becomes diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index e0db7eb18..46e932359 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -7,6 +7,7 @@ import ( "sync/atomic" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/rpcclient" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcwallet/chain" @@ -512,17 +513,6 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { txid := record.tx.TxHash() - // Subscribe to its confirmation notification. - confEvent, err := t.cfg.Notifier.RegisterConfirmationsNtfn( - &txid, nil, 1, uint32(t.currentHeight), - ) - if err != nil { - return nil, fmt.Errorf("register confirmation ntfn: %w", err) - } - - // Attach the confirmation event channel to the record. - record.confEvent = confEvent - tx := record.tx log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", txid, len(tx.TxIn), t.currentHeight) @@ -534,7 +524,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { // Publish the sweeping tx with customized label. If the publish fails, // this error will be saved in the `BumpResult` and it will be removed // from being monitored. - err = t.cfg.Wallet.PublishTransaction( + err := t.cfg.Wallet.PublishTransaction( tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { @@ -638,9 +628,6 @@ type monitorRecord struct { // req is the original request. req *BumpRequest - // confEvent is the subscription to the confirmation event of the tx. - confEvent *chainntnfs.ConfirmationEvent - // feeFunction is the fee bumping algorithm used by the publisher. feeFunction FeeFunction @@ -648,6 +635,283 @@ type monitorRecord struct { fee btcutil.Amount } +// Start starts the publisher by subscribing to block epoch updates and kicking +// off the monitor loop. +func (t *TxPublisher) Start() error { + log.Info("TxPublisher starting...") + defer log.Debugf("TxPublisher started") + + blockEvent, err := t.cfg.Notifier.RegisterBlockEpochNtfn(nil) + if err != nil { + return fmt.Errorf("register block epoch ntfn: %w", err) + } + + t.wg.Add(1) + go t.monitor(blockEvent) + + return nil +} + +// Stop stops the publisher and waits for the monitor loop to exit. +func (t *TxPublisher) Stop() { + log.Info("TxPublisher stopping...") + defer log.Debugf("TxPublisher stopped") + + close(t.quit) + + t.wg.Wait() +} + +// monitor is the main loop driven by new blocks. Whevenr a new block arrives, +// it will examine all the txns being monitored, and check if any of them needs +// to be bumped. If so, it will attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine. +func (t *TxPublisher) monitor(blockEvent *chainntnfs.BlockEpochEvent) { + defer blockEvent.Cancel() + defer t.wg.Done() + + for { + select { + case epoch, ok := <-blockEvent.Epochs: + if !ok { + // We should stop the publisher before stopping + // the chain service. Otherwise it indicates an + // error. + log.Error("Block epoch channel closed, exit " + + "monitor") + + return + } + + log.Debugf("TxPublisher received new block: %v", + epoch.Height) + + // Update the best known height for the publisher. + t.currentHeight = epoch.Height + + // Check all monitored txns to see if any of them needs + // to be bumped. + t.processRecords() + + case <-t.quit: + log.Debug("Fee bumper stopped, exit monitor") + return + } + } +} + +// processRecords checks all the txns being monitored, and checks if any of +// them needs to be bumped. If so, it will attempt to bump the fee of the tx. +func (t *TxPublisher) processRecords() { + // confirmedRecords stores a map of the records which have been + // confirmed. + confirmedRecords := make(map[uint64]*monitorRecord) + + // feeBumpRecords stores a map of the records which need to be bumped. + feeBumpRecords := make(map[uint64]*monitorRecord) + + // visitor is a helper closure that visits each record and divides them + // into two groups. + visitor := func(requestID uint64, r *monitorRecord) error { + log.Tracef("Checking monitor recordID=%v for tx=%v", requestID, + r.tx.TxHash()) + + // If the tx is already confirmed, we can stop monitoring it. + if t.isConfirmed(r.tx.TxHash()) { + confirmedRecords[requestID] = r + + // Move to the next record. + return nil + } + + feeBumpRecords[requestID] = r + + // Return nil to move to the next record. + return nil + } + + // Iterate through all the records and divide them into two groups. + t.records.ForEach(visitor) + + // For records that are confirmed, we'll notify the caller about this + // result. + for requestID, r := range confirmedRecords { + rec := r + + log.Debugf("Tx=%v is confirmed", r.tx.TxHash()) + t.wg.Add(1) + go t.handleTxConfirmed(rec, requestID) + } + + // Get the current height to be used in the following goroutines. + currentHeight := t.currentHeight + + // For records that are not confirmed, we perform a fee bump if needed. + for requestID, r := range feeBumpRecords { + rec := r + + log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash()) + t.wg.Add(1) + go t.handleFeeBumpTx(requestID, rec, currentHeight) + } +} + +// handleTxConfirmed is called when a monitored tx is confirmed. It will +// notify the subscriber then remove the record from the maps . +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) { + defer t.wg.Done() + + // Create a result that will be sent to the resultChan which is + // listened by the caller. + result := &BumpResult{ + Event: TxConfirmed, + Tx: r.tx, + requestID: requestID, + Fee: r.fee, + FeeRate: r.feeFunction.FeeRate(), + } + + // Notify that this tx is confirmed and remove the record from the map. + t.handleResult(result) +} + +// handleFeeBumpTx checks if the tx needs to be bumped, and if so, it will +// attempt to bump the fee of the tx. +// +// NOTE: Must be run as a goroutine to avoid blocking on sending the result. +func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord, + currentHeight int32) { + + defer t.wg.Done() + + oldTxid := r.tx.TxHash() + + // Get the current conf target for this record. + confTarget := calcCurrentConfTarget(currentHeight, r.req.DeadlineHeight) + + // Ask the fee function whether a bump is needed. We expect the fee + // function to increase its returned fee rate after calling this + // method. + increased, err := r.feeFunction.IncreaseFeeRate(confTarget) + if err != nil { + // TODO(yy): send this error back to the sweeper so it can + // re-group the inputs? + log.Errorf("Failed to increase fee rate for tx %v at "+ + "height=%v: %v", oldTxid, t.currentHeight, err) + + return + } + + // If the fee rate was not increased, there's no need to bump the fee. + if !increased { + log.Tracef("Skip bumping tx %v at height=%v", oldTxid, + t.currentHeight) + + return + } + + // The fee function now has a new fee rate, we will use it to bump the + // fee of the tx. + result, err := t.createAndPublishTx(requestID, r) + if err != nil { + log.Errorf("Failed to bump tx %v: %v", oldTxid, err) + + return + } + + // Notify the new result. + t.handleResult(result) +} + +// createAndPublishTx creates a new tx with a higher fee rate and publishes it +// to the network. It will update the record with the new tx and fee rate if +// successfully created, and return the result when published successfully. +func (t *TxPublisher) createAndPublishTx(requestID uint64, + r *monitorRecord) (*BumpResult, error) { + + // Fetch the old tx. + oldTx := r.tx + + // Create a new tx with the new fee rate. + // + // NOTE: The fee function is expected to have increased its returned + // fee rate after calling the SkipFeeBump method. So we can use it + // directly here. + tx, fee, err := t.createAndCheckTx(r.req, r.feeFunction) + + // If the tx doesn't not have enought budget, we will return a result + // so the sweeper can handle it by re-clustering the utxos. + if errors.Is(err, ErrNotEnoughBudget) { + log.Warnf("Fail to fee bump tx %v: %v", oldTx.TxHash(), err) + + return &BumpResult{ + Event: TxFailed, + Tx: oldTx, + Err: err, + requestID: requestID, + }, nil + } + + // If the error is not budget related, we will return an error and let + // the fee bumper retry it at next block. + // + // NOTE: we can check the RBF error here and ask the fee function to + // recalculate the fee rate. However, this would defeat the purpose of + // using a deadline based fee function: + // - if the deadline is far away, there's no rush to RBF the tx. + // - if the deadline is close, we expect the fee function to give us a + // higher fee rate. If the fee rate cannot satisfy the RBF rules, it + // means the budget is not enough. + if err != nil { + log.Infof("Failed to bump tx %v: %v", oldTx.TxHash(), err) + return nil, err + } + + // Register a new record by overwriting the same requestID. + t.records.Store(requestID, &monitorRecord{ + tx: tx, + req: r.req, + feeFunction: r.feeFunction, + fee: fee, + }) + + // Attempt to broadcast this new tx. + result, err := t.broadcast(requestID) + if err != nil { + return nil, err + } + + // A successful replacement tx is created, attach the old tx. + result.ReplacedTx = oldTx + + // If the new tx failed to be published, we will return the result so + // the caller can handle it. + if result.Event == TxFailed { + return result, nil + } + + log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), tx.TxHash()) + + // Otherwise, it's a successful RBF, set the event and return. + result.Event = TxReplaced + + return result, nil +} + +// isConfirmed checks the btcwallet to see whether the tx is confirmed. +func (t *TxPublisher) isConfirmed(txid chainhash.Hash) bool { + details, err := t.cfg.Wallet.GetTransactionDetails(&txid) + if err != nil { + log.Warnf("Failed to get tx details for %v: %v", txid, err) + return false + } + + return details.NumConfirmations > 0 +} + // calcCurrentConfTarget calculates the current confirmation target based on // the deadline height. The conf target is capped at 0 if the deadline has // already been past. diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 308a69a57..f3b67f3bd 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -612,15 +612,11 @@ func TestTxPublisherBroadcast(t *testing.T) { // Create a test tx. tx := &wire.MsgTx{LockTime: 1} - txid := tx.TxHash() // Create a test feerate and return it from the mock fee function. feerate := chainfee.SatPerKWeight(1000) m.feeFunc.On("FeeRate").Return(feerate) - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - // Create a testing record and put it in the map. fee := btcutil.Amount(1000) requestID := tp.storeRecord(tx, req, m.feeFunc, fee) @@ -631,41 +627,17 @@ func TestTxPublisherBroadcast(t *testing.T) { require.Error(t, err) require.Nil(t, result) - // Define params to be used in RegisterConfirmationsNtfn. Not important - // for this test. - var pkScript []byte - confs := uint32(1) - height := uint32(tp.currentHeight) - testCases := []struct { name string setupMock func() expectedErr error expectedResult *BumpResult }{ - { - // When the notifier cannot register this spend, an - // error should be returned - name: "fail to register nftn", - setupMock: func() { - // Mock the RegisterConfirmationsNtfn to fail. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - nil, errDummy).Once() - }, - expectedErr: errDummy, - expectedResult: nil, - }, { // When the wallet cannot publish this tx, the error // should be put inside the result. name: "fail to publish", setupMock: func() { - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - confEvent, nil).Once() - // Mock the wallet to fail to publish. m.wallet.On("PublishTransaction", tx, mock.Anything).Return( @@ -685,11 +657,6 @@ func TestTxPublisherBroadcast(t *testing.T) { // When nothing goes wrong, the result is returned. name: "publish success", setupMock: func() { - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - &txid, pkScript, confs, height).Return( - confEvent, nil).Once() - // Mock the wallet to publish successfully. m.wallet.On("PublishTransaction", tx, mock.Anything).Return(nil).Once() @@ -910,14 +877,6 @@ func TestBroadcastSuccess(t *testing.T) { // Mock the testmempoolaccept to pass. m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(confEvent, nil).Once() - // Mock the wallet to publish successfully. m.wallet.On("PublishTransaction", mock.Anything, mock.Anything).Return(nil).Once() @@ -1007,14 +966,6 @@ func TestBroadcastFail(t *testing.T) { // Mock the testmempoolaccept again, this time it passes. m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() - // Create a test conf event. - confEvent := &chainntnfs.ConfirmationEvent{} - - // Mock the RegisterConfirmationsNtfn to pass. - m.notifier.On("RegisterConfirmationsNtfn", - mock.Anything, mock.Anything, mock.Anything, mock.Anything, - ).Return(confEvent, nil).Once() - // Mock the wallet to fail on publish. m.wallet.On("PublishTransaction", mock.Anything, mock.Anything).Return(errDummy).Once() @@ -1039,3 +990,418 @@ func TestBroadcastFail(t *testing.T) { require.Equal(t, 0, tp.records.Len()) require.Equal(t, 0, tp.subscriberChans.Len()) } + +// TestCreateAnPublishFail checks all the error cases are handled properly in +// the method createAndPublish. +func TestCreateAnPublishFail(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + + // Overwrite the budget to make it smaller than the fee. + req.Budget = 100 + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Call the createAndPublish method. + result, err := tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxFailed and the error is set in the + // result. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, ErrNotEnoughBudget) + require.Equal(t, requestID, result.requestID) + + // Increase the budget and call it again. This time we will mock an + // error to be returned from CheckMempoolAcceptance. + req.Budget = 1000 + + // Mock the testmempoolaccept to return an error. + m.wallet.On("CheckMempoolAcceptance", + mock.Anything).Return(lnwallet.ErrMempoolFee).Once() + + // Call the createAndPublish method and expect an error. + result, err = tp.createAndPublishTx(requestID, record) + require.ErrorIs(t, err, lnwallet.ErrMempoolFee) + require.Nil(t, result) +} + +// TestCreateAnPublishSuccess checks the expected result is returned from the +// method createAndPublish. +func TestCreateAnPublishSuccess(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test requestID. + requestID := uint64(1) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: &wire.MsgTx{}, + } + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish and return an error. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(errDummy).Once() + + // Call the createAndPublish method and expect a failure result. + result, err := tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxFailed and the error is set. + require.Equal(t, TxFailed, result.Event) + require.ErrorIs(t, result.Err, errDummy) + + // Although the replacement tx was failed to be published, the record + // should be stored. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + _, found := tp.records.Load(requestID) + require.True(t, found) + + // We now check a successful RBF. + // + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the createAndPublish method and expect a success result. + result, err = tp.createAndPublishTx(requestID, record) + require.NoError(t, err) + + // We expect the result to be TxReplaced and the error is nil. + require.Equal(t, TxReplaced, result.Event) + require.Nil(t, result.Err) + + // Check the Tx and ReplacedTx are set. + require.NotNil(t, result.Tx) + require.NotNil(t, result.ReplacedTx) + + // Check the record is stored. + _, found = tp.records.Load(requestID) + require.True(t, found) +} + +// TestHandleTxConfirmed checks the expected result is returned from the method +// handleTxConfirmed. +func TestHandleTxConfirmed(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test bump request. + req := createTestBumpRequest() + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + record, ok := tp.records.Load(requestID) + require.True(t, ok) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Mock the fee function to return a fee rate. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleTxConfirmed(record, requestID) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxConfirmed and the tx is set. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx, result.Tx) + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + require.Equal(t, record.fee, result.Fee) + require.Equal(t, feerate, result.FeeRate) + } + + // We expect the record to be removed from the maps. + _, found := tp.records.Load(requestID) + require.False(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.False(t, found) +} + +// TestHandleFeeBumpTx validates handleFeeBumpTx behaves as expected. +func TestHandleFeeBumpTx(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create a test tx. + tx := &wire.MsgTx{LockTime: 1} + + // Create a test current height. + testHeight := int32(800000) + + // Create a testing monitor record. + req := createTestBumpRequest() + record := &monitorRecord{ + req: req, + feeFunction: m.feeFunc, + tx: tx, + } + + // Create a testing record and put it in the map. + fee := btcutil.Amount(1000) + requestID := tp.storeRecord(tx, req, m.feeFunc, fee) + + // Create a subscription to the event. + subscriber := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID, subscriber) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // Mock the fee function to skip the bump due to error. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return( + false, errDummy).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to skip the bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(false, nil).Once() + + // Call the method and expect no result received. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + // Check there's no result sent back. + select { + case <-time.After(time.Second): + case result := <-subscriber: + t.Fatalf("unexpected result received: %v", result) + } + + // Mock the fee function to perform the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil) + + // Mock the signer to always return a valid script. + // + // NOTE: we are not testing the utility of creating valid txes here, so + // this is fine to be mocked. This behaves essentially as skipping the + // Signer check and alaways assume the tx has a valid sig. + script := &input.Script{} + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(script, nil) + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil) + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call the method and expect a result to be received. + // + // NOTE: must be called in a goroutine in case it blocks. + tp.wg.Add(1) + go tp.handleFeeBumpTx(requestID, record, testHeight) + + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriber to receive result") + + case result := <-subscriber: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx, result.Tx) + require.Equal(t, tx, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID, result.requestID) + } + + // We expect the record to NOT be removed from the maps. + _, found := tp.records.Load(requestID) + require.True(t, found) + _, found = tp.subscriberChans.Load(requestID) + require.True(t, found) +} + +// TestProcessRecords validates processRecords behaves as expected. +func TestProcessRecords(t *testing.T) { + t.Parallel() + + // Create a publisher using the mocks. + tp, m := createTestPublisher(t) + + // Create testing objects. + requestID1 := uint64(1) + req1 := createTestBumpRequest() + tx1 := &wire.MsgTx{LockTime: 1} + txid1 := tx1.TxHash() + + requestID2 := uint64(2) + req2 := createTestBumpRequest() + tx2 := &wire.MsgTx{LockTime: 2} + txid2 := tx2.TxHash() + + // Create a monitor record that's confirmed. + recordConfirmed := &monitorRecord{ + req: req1, + feeFunction: m.feeFunc, + tx: tx1, + } + m.wallet.On("GetTransactionDetails", &txid1).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 1, + }, nil, + ).Once() + + // Create a monitor record that's not confirmed. We know it's not + // confirmed because the num of confirms is zero. + recordFeeBump := &monitorRecord{ + req: req2, + feeFunction: m.feeFunc, + tx: tx2, + } + m.wallet.On("GetTransactionDetails", &txid2).Return( + &lnwallet.TransactionDetail{ + NumConfirmations: 0, + }, nil, + ).Once() + + // Setup the initial publisher state by adding the records to the maps. + subscriberConfirmed := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID1, subscriberConfirmed) + tp.records.Store(requestID1, recordConfirmed) + + subscriberReplaced := make(chan *BumpResult, 1) + tp.subscriberChans.Store(requestID2, subscriberReplaced) + tp.records.Store(requestID2, recordFeeBump) + + // Create a test feerate and return it from the mock fee function. + feerate := chainfee.SatPerKWeight(1000) + m.feeFunc.On("FeeRate").Return(feerate) + + // The following methods should only be called once when creating the + // replacement tx. + // + // Mock the fee function to NOT skip the fee bump. + m.feeFunc.On("IncreaseFeeRate", mock.Anything).Return(true, nil).Once() + + // Mock the signer to always return a valid script. + m.signer.On("ComputeInputScript", mock.Anything, + mock.Anything).Return(&input.Script{}, nil).Once() + + // Mock the testmempoolaccept to return nil. + m.wallet.On("CheckMempoolAcceptance", mock.Anything).Return(nil).Once() + + // Mock the wallet to publish successfully. + m.wallet.On("PublishTransaction", + mock.Anything, mock.Anything).Return(nil).Once() + + // Call processRecords and expect the results are notified back. + tp.processRecords() + + // We expect two results to be received. One for the confirmed tx and + // one for the replaced tx. + // + // Check the confirmed tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberConfirmed") + + case result := <-subscriberConfirmed: + // We expect the result to be TxConfirmed. + require.Equal(t, TxConfirmed, result.Event) + require.Equal(t, tx1, result.Tx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID1, result.requestID) + } + + // Now check the replaced tx result. + select { + case <-time.After(time.Second): + t.Fatal("timeout waiting for subscriberReplaced") + + case result := <-subscriberReplaced: + // We expect the result to be TxReplaced. + require.Equal(t, TxReplaced, result.Event) + + // The new tx and old tx should be properly set. + require.NotEqual(t, tx2, result.Tx) + require.Equal(t, tx2, result.ReplacedTx) + + // No error should be set. + require.Nil(t, result.Err) + require.Equal(t, requestID2, result.requestID) + } +} diff --git a/sweep/interface.go b/sweep/interface.go index e58cc8507..a6e5d2153 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -46,4 +46,9 @@ type Wallet interface { // policies and returns an error if it cannot be accepted into the // mempool. CheckMempoolAcceptance(tx *wire.MsgTx) error + + // GetTransactionDetails returns a detailed description of a tx given + // its transaction hash. + GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) } diff --git a/sweep/mock_test.go b/sweep/mock_test.go index 86edbacef..6b23953c3 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -175,6 +175,14 @@ func (b *mockBackend) FetchTx(chainhash.Hash) (*wire.MsgTx, error) { func (b *mockBackend) CancelRebroadcast(tx chainhash.Hash) { } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (b *mockBackend) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + return nil, nil +} + // mockFeeEstimator implements a mock fee estimator. It closely resembles // lnwallet.StaticFeeEstimator with the addition that fees can be changed for // testing purposes in a thread safe manner. @@ -418,6 +426,20 @@ func (m *MockWallet) CancelRebroadcast(tx chainhash.Hash) { m.Called(tx) } +// GetTransactionDetails returns a detailed description of a tx given its +// transaction hash. +func (m *MockWallet) GetTransactionDetails(txHash *chainhash.Hash) ( + *lnwallet.TransactionDetail, error) { + + args := m.Called(txHash) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(*lnwallet.TransactionDetail), args.Error(1) +} + // MockInputSet is a mock implementation of the InputSet interface. type MockInputSet struct { mock.Mock