diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go new file mode 100644 index 000000000..c40614961 --- /dev/null +++ b/sweep/fee_bumper.go @@ -0,0 +1,142 @@ +package sweep + +import ( + "errors" + "fmt" + + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet/chainfee" +) + +var ( + // ErrInvalidBumpResult is returned when the bump result is invalid. + ErrInvalidBumpResult = errors.New("invalid bump result") +) + +// Bumper defines an interface that can be used by other subsystems for fee +// bumping. +type Bumper interface { + // Broadcast is used to publish the tx created from the given inputs + // specified in the request. It handles the tx creation, broadcasts it, + // and monitors its confirmation status for potential fee bumping. It + // returns a chan that the caller can use to receive updates about the + // broadcast result and potential RBF attempts. + Broadcast(req *BumpRequest) (<-chan *BumpResult, error) +} + +// BumpEvent represents the event of a fee bumping attempt. +type BumpEvent uint8 + +const ( + // TxPublished is sent when the broadcast attempt is finished. + TxPublished BumpEvent = iota + + // TxFailed is sent when the broadcast attempt fails. + TxFailed + + // TxReplaced is sent when the original tx is replaced by a new one. + TxReplaced + + // TxConfirmed is sent when the tx is confirmed. + TxConfirmed + + // sentinalEvent is used to check if an event is unknown. + sentinalEvent +) + +// String returns a human-readable string for the event. +func (e BumpEvent) String() string { + switch e { + case TxPublished: + return "Published" + case TxFailed: + return "Failed" + case TxReplaced: + return "Replaced" + case TxConfirmed: + return "Confirmed" + default: + return "Unknown" + } +} + +// Unknown returns true if the event is unknown. +func (e BumpEvent) Unknown() bool { + return e >= sentinalEvent +} + +// BumpRequest is used by the caller to give the Bumper the necessary info to +// create and manage potential fee bumps for a set of inputs. +type BumpRequest struct { + // Budget givens the total amount that can be used as fees by these + // inputs. + Budget btcutil.Amount + + // Inputs is the set of inputs to sweep. + Inputs []input.Input + + // DeadlineHeight is the block height at which the tx should be + // confirmed. + DeadlineHeight int32 + + // DeliveryAddress is the script to send the change output to. + DeliveryAddress []byte + + // MaxFeeRate is the maximum fee rate that can be used for fee bumping. + MaxFeeRate chainfee.SatPerKWeight +} + +// BumpResult is used by the Bumper to send updates about the tx being +// broadcast. +type BumpResult struct { + // Event is the type of event that the result is for. + Event BumpEvent + + // Tx is the tx being broadcast. + Tx *wire.MsgTx + + // ReplacedTx is the old, replaced tx if a fee bump is attempted. + ReplacedTx *wire.MsgTx + + // FeeRate is the fee rate used for the new tx. + FeeRate chainfee.SatPerKWeight + + // Fee is the fee paid by the new tx. + Fee btcutil.Amount + + // Err is the error that occurred during the broadcast. + Err error +} + +// Validate validates the BumpResult so it's safe to use. +func (b *BumpResult) Validate() error { + // Every result must have a tx. + if b.Tx == nil { + return fmt.Errorf("%w: nil tx", ErrInvalidBumpResult) + } + + // Every result must have a known event. + if b.Event.Unknown() { + return fmt.Errorf("%w: unknown event", ErrInvalidBumpResult) + } + + // If it's a replacing event, it must have a replaced tx. + if b.Event == TxReplaced && b.ReplacedTx == nil { + return fmt.Errorf("%w: nil replacing tx", ErrInvalidBumpResult) + } + + // If it's a failed event, it must have an error. + if b.Event == TxFailed && b.Err == nil { + return fmt.Errorf("%w: nil error", ErrInvalidBumpResult) + } + + // If it's a confirmed event, it must have a fee rate and fee. + if b.Event == TxConfirmed && (b.FeeRate == 0 || b.Fee == 0) { + return fmt.Errorf("%w: missing fee rate or fee", + ErrInvalidBumpResult) + } + + return nil +} diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go new file mode 100644 index 000000000..22c247b2c --- /dev/null +++ b/sweep/fee_bumper_test.go @@ -0,0 +1,52 @@ +package sweep + +import ( + "testing" + + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/require" +) + +// TestBumpResultValidate tests the validate method of the BumpResult struct. +func TestBumpResultValidate(t *testing.T) { + t.Parallel() + + // An empty result will give an error. + b := BumpResult{} + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Unknown event type will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: sentinalEvent, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A replacing event without a new tx will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxReplaced, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A failed event without a failure reason will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxFailed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // A confirmed event without fee info will give an error. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxConfirmed, + } + require.ErrorIs(t, b.Validate(), ErrInvalidBumpResult) + + // Test a valid result. + b = BumpResult{ + Tx: &wire.MsgTx{}, + Event: TxPublished, + } + require.NoError(t, b.Validate()) +} diff --git a/sweep/mock_test.go b/sweep/mock_test.go index f908cf6db..270c3844e 100644 --- a/sweep/mock_test.go +++ b/sweep/mock_test.go @@ -462,3 +462,22 @@ func (m *MockInputSet) Budget() btcutil.Amount { return args.Get(0).(btcutil.Amount) } + +// MockBumper is a mock implementation of the interface Bumper. +type MockBumper struct { + mock.Mock +} + +// Compile-time constraint to ensure MockBumper implements Bumper. +var _ Bumper = (*MockBumper)(nil) + +// Broadcast broadcasts the transaction to the network. +func (m *MockBumper) Broadcast(req *BumpRequest) (<-chan *BumpResult, error) { + args := m.Called(req) + + if args.Get(0) == nil { + return nil, args.Error(1) + } + + return args.Get(0).(chan *BumpResult), args.Error(1) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 659c1d340..505292912 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -13,7 +13,6 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" - "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" ) @@ -41,6 +40,12 @@ var ( // an input is included in a publish attempt before giving up and // returning an error to the caller. DefaultMaxSweepAttempts = 10 + + // DefaultDeadlineDelta defines a default deadline delta (1 week) to be + // used when sweeping inputs with no deadline pressure. + // + // TODO(yy): make this configurable. + DefaultDeadlineDelta = int32(1008) ) // Params contains the parameters that control the sweeping process. @@ -317,6 +322,10 @@ type UtxoSweeper struct { // currentHeight is the best known height of the main chain. This is // updated whenever a new block epoch is received. currentHeight int32 + + // bumpResultChan is a channel that receives broadcast results from the + // TxPublisher. + bumpResultChan chan *BumpResult } // UtxoSweeperConfig contains dependencies of UtxoSweeper. @@ -364,6 +373,10 @@ type UtxoSweeperConfig struct { // Aggregator is used to group inputs into clusters based on its // implemention-specific strategy. Aggregator UtxoAggregator + + // Publisher is used to publish the sweep tx crafted here and monitors + // it for potential fee bumps. + Publisher Bumper } // Result is the struct that is pushed through the result channel. Callers can @@ -397,6 +410,7 @@ func New(cfg *UtxoSweeperConfig) *UtxoSweeper { pendingSweepsReqs: make(chan *pendingSweepsReq), quit: make(chan struct{}), pendingInputs: make(pendingInputs), + bumpResultChan: make(chan *BumpResult, 100), } } @@ -670,11 +684,16 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { err: err, } - // A new block comes in, update the bestHeight. - // - // TODO(yy): this is where we check our published transactions - // and perform RBF if needed. We'd also like to consult our fee - // bumper to get an updated fee rate. + case result := <-s.bumpResultChan: + // Handle the bump event. + err := s.handleBumpEvent(result) + if err != nil { + log.Errorf("Failed to handle bump event: %v", + err) + } + + // A new block comes in, update the bestHeight, perform a check + // over all pending inputs and publish sweeping txns if needed. case epoch, ok := <-blockEpochs: if !ok { // We should stop the sweeper before stopping @@ -779,8 +798,8 @@ func (s *UtxoSweeper) signalResult(pi *pendingInput, result Result) { } } -// sweep takes a set of preselected inputs, creates a sweep tx and publishes the -// tx. The output address is only marked as used if the publish succeeds. +// sweep takes a set of preselected inputs, creates a sweep tx and publishes +// the tx. The output address is only marked as used if the publish succeeds. func (s *UtxoSweeper) sweep(set InputSet) error { // Generate an output script if there isn't an unused script available. if s.currentOutputScript == nil { @@ -791,20 +810,21 @@ func (s *UtxoSweeper) sweep(set InputSet) error { s.currentOutputScript = pkScript } - // Create sweep tx. - tx, fee, err := createSweepTx( - set.Inputs(), nil, s.currentOutputScript, - uint32(s.currentHeight), set.FeeRate(), - s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, - ) - if err != nil { - return fmt.Errorf("create sweep tx: %w", err) - } + // Create a default deadline height, and replace it with set's + // DeadlineHeight if it's set. + deadlineHeight := s.currentHeight + DefaultDeadlineDelta + deadlineHeight = set.DeadlineHeight().UnwrapOr(deadlineHeight) - tr := &TxRecord{ - Txid: tx.TxHash(), - FeeRate: uint64(set.FeeRate()), - Fee: uint64(fee), + // Create a fee bump request and ask the publisher to broadcast it. The + // publisher will then take over and start monitoring the tx for + // potential fee bump. + req := &BumpRequest{ + Inputs: set.Inputs(), + Budget: set.Budget(), + DeadlineHeight: deadlineHeight, + DeliveryAddress: s.currentOutputScript, + MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), + // TODO(yy): pass the strategy here. } // Reschedule the inputs that we just tried to sweep. This is done in @@ -812,13 +832,9 @@ func (s *UtxoSweeper) sweep(set InputSet) error { // publish attempts and rescue them in the next sweep. s.markInputsPendingPublish(set) - log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", - tx.TxHash(), len(tx.TxIn), s.currentHeight) - - // Publish the sweeping tx with customized label. - err = s.cfg.Wallet.PublishTransaction( - tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), - ) + // Broadcast will return a read-only chan that we will listen to for + // this publish result and future RBF attempt. + resp, err := s.cfg.Publisher.Broadcast(req) if err != nil { outpoints := make([]wire.OutPoint, len(set.Inputs())) for i, inp := range set.Inputs() { @@ -831,16 +847,11 @@ func (s *UtxoSweeper) sweep(set InputSet) error { return err } - // Inputs have been successfully published so we update their states. - err = s.markInputsPublished(tr, tx.TxIn) - if err != nil { - return err - } - - // If there's no error, remove the output script. Otherwise keep it so - // that it can be reused for the next transaction and causes no address - // inflation. - s.currentOutputScript = nil + // Successfully sent the broadcast attempt, we now handle the result by + // subscribing to the result chan and listen for future updates about + // this tx. + s.wg.Add(1) + go s.monitorFeeBumpResult(resp) return nil } @@ -1557,3 +1568,167 @@ func (s *UtxoSweeper) sweepPendingInputs(inputs pendingInputs) { } } } + +// monitorFeeBumpResult subscribes to the passed result chan to listen for +// future updates about the sweeping tx. +// +// NOTE: must run as a goroutine. +func (s *UtxoSweeper) monitorFeeBumpResult(resultChan <-chan *BumpResult) { + defer s.wg.Done() + + for { + select { + case r := <-resultChan: + // Validate the result is valid. + if err := r.Validate(); err != nil { + log.Errorf("Received invalid result: %v", err) + continue + } + + // Send the result back to the main event loop. + select { + case s.bumpResultChan <- r: + case <-s.quit: + log.Debug("Sweeper shutting down, skip " + + "sending bump result") + + return + } + + // The sweeping tx has been confirmed, we can exit the + // monitor now. + // + // TODO(yy): can instead remove the spend subscription + // in sweeper and rely solely on this event to mark + // inputs as Swept? + if r.Event == TxConfirmed || r.Event == TxFailed { + log.Debugf("Received %v for sweep tx %v, exit "+ + "fee bump monitor", r.Event, + r.Tx.TxHash()) + + return + } + + case <-s.quit: + log.Debugf("Sweeper shutting down, exit fee " + + "bump handler") + + return + } + } +} + +// handleBumpEventTxFailed handles the case where the tx has been failed to +// publish. +func (s *UtxoSweeper) handleBumpEventTxFailed(r *BumpResult) error { + tx, err := r.Tx, r.Err + + log.Errorf("Fee bump attempt failed for tx=%v: %v", tx.TxHash(), err) + + outpoints := make([]wire.OutPoint, 0, len(tx.TxIn)) + for _, inp := range tx.TxIn { + outpoints = append(outpoints, inp.PreviousOutPoint) + } + + // TODO(yy): should we also remove the failed tx from db? + s.markInputsPublishFailed(outpoints) + + return err +} + +// handleBumpEventTxReplaced handles the case where the sweeping tx has been +// replaced by a new one. +func (s *UtxoSweeper) handleBumpEventTxReplaced(r *BumpResult) error { + oldTx := r.ReplacedTx + newTx := r.Tx + + // Prepare a new record to replace the old one. + tr := &TxRecord{ + Txid: newTx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Get the old record for logging purpose. + oldTxid := oldTx.TxHash() + record, err := s.cfg.Store.GetTx(oldTxid) + if err != nil { + log.Errorf("Fetch tx record for %v: %v", oldTxid, err) + return err + } + + log.Infof("RBFed tx=%v(fee=%v, feerate=%v) with new tx=%v(fee=%v, "+ + "feerate=%v)", record.Txid, record.Fee, record.FeeRate, + tr.Txid, tr.Fee, tr.FeeRate) + + // The old sweeping tx has been replaced by a new one, we will update + // the tx record in the sweeper db. + // + // TODO(yy): we may also need to update the inputs in this tx to a new + // state. Suppose a replacing tx only spends a subset of the inputs + // here, we'd end up with the rest being marked as `StatePublished` and + // won't be aggregated in the next sweep. Atm it's fine as we always + // RBF the same input set. + if err := s.cfg.Store.DeleteTx(oldTxid); err != nil { + log.Errorf("Delete tx record for %v: %v", oldTxid, err) + return err + } + + // Mark the inputs as published using the replacing tx. + return s.markInputsPublished(tr, r.Tx.TxIn) +} + +// handleBumpEventTxPublished handles the case where the sweeping tx has been +// successfully published. +func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { + tx := r.Tx + tr := &TxRecord{ + Txid: tx.TxHash(), + FeeRate: uint64(r.FeeRate), + Fee: uint64(r.Fee), + } + + // Inputs have been successfully published so we update their + // states. + err := s.markInputsPublished(tr, tx.TxIn) + if err != nil { + return err + } + + log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v", + tx.TxHash(), len(tx.TxIn), s.currentHeight) + + // If there's no error, remove the output script. Otherwise + // keep it so that it can be reused for the next transaction + // and causes no address inflation. + s.currentOutputScript = nil + + return nil +} + +// handleBumpEvent handles the result sent from the bumper based on its event +// type. +// +// NOTE: TxConfirmed event is not handled, since we already subscribe to the +// input's spending event, we don't need to do anything here. +func (s *UtxoSweeper) handleBumpEvent(r *BumpResult) error { + log.Debugf("Received bump event [%v] for tx %v", r.Event, r.Tx.TxHash()) + + switch r.Event { + // The tx has been published, we update the inputs' state and create a + // record to be stored in the sweeper db. + case TxPublished: + return s.handleBumpEventTxPublished(r) + + // The tx has failed, we update the inputs' state. + case TxFailed: + return s.handleBumpEventTxFailed(r) + + // The tx has been replaced, we will remove the old tx and replace it + // with the new one. + case TxReplaced: + return s.handleBumpEventTxReplaced(r) + } + + return nil +} diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index f6891db92..db26a4e75 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -33,6 +33,8 @@ var ( testMaxInputsPerTx = uint32(3) defaultFeePref = Params{Fee: FeeEstimateInfo{ConfTarget: 1}} + + errDummy = errors.New("dummy error") ) type sweeperTestContext struct { @@ -137,6 +139,12 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { currentHeight: mockChainHeight, } + // Create a mock fee bumper. + mockBumper := &MockBumper{} + t.Cleanup(func() { + mockBumper.AssertExpectations(t) + }) + ctx.sweeper = New(&UtxoSweeperConfig{ Notifier: notifier, Wallet: backend, @@ -153,6 +161,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { MaxSweepAttempts: testMaxSweepAttempts, MaxFeeRate: DefaultMaxFeeRate, Aggregator: aggregator, + Publisher: mockBumper, }) ctx.sweeper.Start() @@ -2410,16 +2419,27 @@ func TestSweepPendingInputs(t *testing.T) { // Create a mock wallet and aggregator. wallet := &MockWallet{} + defer wallet.AssertExpectations(t) + aggregator := &mockUtxoAggregator{} + defer aggregator.AssertExpectations(t) + + publisher := &MockBumper{} + defer publisher.AssertExpectations(t) // Create a test sweeper. s := New(&UtxoSweeperConfig{ Wallet: wallet, Aggregator: aggregator, + Publisher: publisher, + GenSweepScript: func() ([]byte, error) { + return testPubKey.SerializeCompressed(), nil + }, }) // Create an input set that needs wallet inputs. setNeedWallet := &MockInputSet{} + defer setNeedWallet.AssertExpectations(t) // Mock this set to ask for wallet input. setNeedWallet.On("NeedWalletInput").Return(true).Once() @@ -2430,15 +2450,18 @@ func TestSweepPendingInputs(t *testing.T) { // Create an input set that doesn't need wallet inputs. normalSet := &MockInputSet{} + defer normalSet.AssertExpectations(t) + normalSet.On("NeedWalletInput").Return(false).Once() // Mock the methods used in `sweep`. This is not important for this // unit test. - feeRate := chainfee.SatPerKWeight(1000) - setNeedWallet.On("Inputs").Return(nil).Once() - setNeedWallet.On("FeeRate").Return(feeRate).Once() - normalSet.On("Inputs").Return(nil).Once() - normalSet.On("FeeRate").Return(feeRate).Once() + setNeedWallet.On("Inputs").Return(nil).Times(4) + setNeedWallet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + setNeedWallet.On("Budget").Return(btcutil.Amount(1)).Once() + normalSet.On("Inputs").Return(nil).Times(4) + normalSet.On("DeadlineHeight").Return(fn.None[int32]()).Once() + normalSet.On("Budget").Return(btcutil.Amount(1)).Once() // Make pending inputs for testing. We don't need real values here as // the returned clusters are mocked. @@ -2449,19 +2472,369 @@ func TestSweepPendingInputs(t *testing.T) { setNeedWallet, normalSet, }) - // Set change output script to an invalid value. This should cause the + // Mock `Broadcast` to return an error. This should cause the // `createSweepTx` inside `sweep` to fail. This is done so we can // terminate the method early as we are only interested in testing the // workflow in `sweepPendingInputs`. We don't need to test `sweep` here // as it should be tested in its own unit test. - s.currentOutputScript = []byte{1} + dummyErr := errors.New("dummy error") + publisher.On("Broadcast", mock.Anything).Return(nil, dummyErr).Twice() // Call the method under test. s.sweepPendingInputs(pis) - - // Assert mocked methods are called as expected. - wallet.AssertExpectations(t) - aggregator.AssertExpectations(t) - setNeedWallet.AssertExpectations(t) - normalSet.AssertExpectations(t) +} + +// TestHandleBumpEventTxFailed checks that the sweeper correctly handles the +// case where the bump event tx fails to be published. +func TestHandleBumpEventTxFailed(t *testing.T) { + t.Parallel() + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{}) + + var ( + // Create four testing outpoints. + op1 = wire.OutPoint{Hash: chainhash.Hash{1}} + op2 = wire.OutPoint{Hash: chainhash.Hash{2}} + op3 = wire.OutPoint{Hash: chainhash.Hash{3}} + opNotExist = wire.OutPoint{Hash: chainhash.Hash{4}} + ) + + // Create three mock inputs. + input1 := &input.MockInput{} + defer input1.AssertExpectations(t) + + input2 := &input.MockInput{} + defer input2.AssertExpectations(t) + + input3 := &input.MockInput{} + defer input3.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op1: &pendingInput{Input: input1, state: StatePendingPublish}, + op2: &pendingInput{Input: input2, state: StatePendingPublish}, + op3: &pendingInput{Input: input3, state: StatePendingPublish}, + } + + // Create a testing tx that spends the first two inputs. + tx := &wire.MsgTx{ + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op1}, + {PreviousOutPoint: op2}, + {PreviousOutPoint: opNotExist}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + // Call the method under test. + err := s.handleBumpEvent(br) + require.ErrorIs(t, err, errDummy) + + // Assert the states of the first two inputs are updated. + require.Equal(t, StatePublishFailed, s.pendingInputs[op1].state) + require.Equal(t, StatePublishFailed, s.pendingInputs[op2].state) + + // Assert the state of the third input is not updated. + require.Equal(t, StatePendingPublish, s.pendingInputs[op3].state) + + // Assert the non-existing input is not added to the pending inputs. + require.NotContains(t, s.pendingInputs, opNotExist) +} + +// TestHandleBumpEventTxReplaced checks that the sweeper correctly handles the +// case where the bump event tx is replaced. +func TestHandleBumpEventTxReplaced(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a replacement tx. + replacementTx := &wire.MsgTx{ + LockTime: 2, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: replacementTx, + ReplacedTx: tx, + Event: TxReplaced, + } + + // Mock the store to return an error. + dummyErr := errors.New("dummy error") + store.On("GetTx", tx.TxHash()).Return(nil, dummyErr).Once() + + // Call the method under test and assert the error is returned. + err := s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + + // Mock an error returned when deleting the old tx record. + store.On("DeleteTx", tx.TxHash()).Return(dummyErr).Once() + + // Call the method under test and assert the error is returned. + err = s.handleBumpEventTxReplaced(br) + require.ErrorIs(t, err, dummyErr) + + // Mock the store to return the old tx record and delete it without + // error. + store.On("GetTx", tx.TxHash()).Return(&TxRecord{ + Txid: tx.TxHash(), + }, nil).Once() + store.On("DeleteTx", tx.TxHash()).Return(nil).Once() + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: replacementTx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err = s.handleBumpEventTxReplaced(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} + +// TestHandleBumpEventTxPublished checks that the sweeper correctly handles the +// case where the bump event tx is published. +func TestHandleBumpEventTxPublished(t *testing.T) { + t.Parallel() + + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + // Create a testing bump result. + br := &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + // Mock the store to save the new tx record. + store.On("StoreTx", &TxRecord{ + Txid: tx.TxHash(), + Published: true, + }).Return(nil).Once() + + // Call the method under test. + err := s.handleBumpEventTxPublished(br) + require.NoError(t, err) + + // Assert the state of the input is updated. + require.Equal(t, StatePublished, s.pendingInputs[op].state) +} + +// TestMonitorFeeBumpResult checks that the fee bump monitor loop correctly +// exits when the sweeper is stopped, the tx is confirmed or failed. +func TestMonitorFeeBumpResult(t *testing.T) { + // Create a mock store. + store := &MockSweeperStore{} + defer store.AssertExpectations(t) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: store, + }) + + // Create a testing outpoint. + op := wire.OutPoint{Hash: chainhash.Hash{1}} + + // Create a mock input. + inp := &input.MockInput{} + defer inp.AssertExpectations(t) + + // Construct the initial state for the sweeper. + s.pendingInputs = pendingInputs{ + op: &pendingInput{Input: inp, state: StatePendingPublish}, + } + + // Create a testing tx that spends the input. + tx := &wire.MsgTx{ + LockTime: 1, + TxIn: []*wire.TxIn{ + {PreviousOutPoint: op}, + }, + } + + testCases := []struct { + name string + setupResultChan func() <-chan *BumpResult + shouldExit bool + }{ + { + // When a tx confirmed event is received, we expect to + // exit the monitor loop. + name: "tx confirmed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxConfirmed, + Fee: 10000, + FeeRate: 100, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When a tx failed event is received, we expect to + // exit the monitor loop. + name: "tx failed", + // We send a result with TxConfirmed event to the + // result channel. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxFailed, + Err: errDummy, + } + + return resultChan + }, + shouldExit: true, + }, + { + // When processing non-confirmed events, the monitor + // should not exit. + name: "no exit on normal event", + // We send a result with TxPublished and mock the + // method `StoreTx` to return nil. + setupResultChan: func() <-chan *BumpResult { + // Create a result chan. + resultChan := make(chan *BumpResult, 1) + resultChan <- &BumpResult{ + Tx: tx, + Event: TxPublished, + } + + return resultChan + }, + shouldExit: false, + }, { + // When the sweeper is shutting down, the monitor loop + // should exit. + name: "exit on sweeper shutdown", + // We don't send anything but quit the sweeper. + setupResultChan: func() <-chan *BumpResult { + close(s.quit) + + return nil + }, + shouldExit: true, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Setup the testing result channel. + resultChan := tc.setupResultChan() + + // Create a done chan that's used to signal the monitor + // has exited. + done := make(chan struct{}) + + s.wg.Add(1) + go func() { + s.monitorFeeBumpResult(resultChan) + close(done) + }() + + // The monitor is expected to exit, we check it's done + // in one second or fail. + if tc.shouldExit { + select { + case <-done: + case <-time.After(1 * time.Second): + require.Fail(t, "monitor not exited") + } + + return + } + + // The monitor should not exit, check it doesn't close + // the `done` channel within one second. + select { + case <-done: + require.Fail(t, "monitor exited") + case <-time.After(1 * time.Second): + } + }) + } }