diff --git a/input/mocks.go b/input/mocks.go new file mode 100644 index 000000000..965489eff --- /dev/null +++ b/input/mocks.go @@ -0,0 +1,125 @@ +package input + +import ( + "github.com/btcsuite/btcd/txscript" + "github.com/btcsuite/btcd/wire" + "github.com/stretchr/testify/mock" +) + +// MockInput implements the `Input` interface and is used by other packages for +// mock testing. +type MockInput struct { + mock.Mock +} + +// Compile time assertion that MockInput implements Input. +var _ Input = (*MockInput)(nil) + +// Outpoint returns the reference to the output being spent, used to construct +// the corresponding transaction input. +func (m *MockInput) OutPoint() *wire.OutPoint { + args := m.Called() + op := args.Get(0) + + if op == nil { + return nil + } + + return op.(*wire.OutPoint) +} + +// RequiredTxOut returns a non-nil TxOut if input commits to a certain +// transaction output. This is used in the SINGLE|ANYONECANPAY case to make +// sure any presigned input is still valid by including the output. +func (m *MockInput) RequiredTxOut() *wire.TxOut { + args := m.Called() + txOut := args.Get(0) + + if txOut == nil { + return nil + } + + return txOut.(*wire.TxOut) +} + +// RequiredLockTime returns whether this input commits to a tx locktime that +// must be used in the transaction including it. +func (m *MockInput) RequiredLockTime() (uint32, bool) { + args := m.Called() + + return args.Get(0).(uint32), args.Bool(1) +} + +// WitnessType returns an enum specifying the type of witness that must be +// generated in order to spend this output. +func (m *MockInput) WitnessType() WitnessType { + args := m.Called() + + wt := args.Get(0) + if wt == nil { + return nil + } + + return wt.(WitnessType) +} + +// SignDesc returns a reference to a spendable output's sign descriptor, which +// is used during signing to compute a valid witness that spends this output. +func (m *MockInput) SignDesc() *SignDescriptor { + args := m.Called() + + sd := args.Get(0) + if sd == nil { + return nil + } + + return sd.(*SignDescriptor) +} + +// CraftInputScript returns a valid set of input scripts allowing this output +// to be spent. The returns input scripts should target the input at location +// txIndex within the passed transaction. The input scripts generated by this +// method support spending p2wkh, p2wsh, and also nested p2sh outputs. +func (m *MockInput) CraftInputScript(signer Signer, txn *wire.MsgTx, + hashCache *txscript.TxSigHashes, + prevOutputFetcher txscript.PrevOutputFetcher, + txinIdx int) (*Script, error) { + + args := m.Called(signer, txn, hashCache, prevOutputFetcher, txinIdx) + + s := args.Get(0) + if s == nil { + return nil, args.Error(1) + } + + return s.(*Script), args.Error(1) +} + +// BlocksToMaturity returns the relative timelock, as a number of blocks, that +// must be built on top of the confirmation height before the output can be +// spent. For non-CSV locked inputs this is always zero. +func (m *MockInput) BlocksToMaturity() uint32 { + args := m.Called() + + return args.Get(0).(uint32) +} + +// HeightHint returns the minimum height at which a confirmed spending tx can +// occur. +func (m *MockInput) HeightHint() uint32 { + args := m.Called() + + return args.Get(0).(uint32) +} + +// UnconfParent returns information about a possibly unconfirmed parent tx. +func (m *MockInput) UnconfParent() *TxInfo { + args := m.Called() + + info := args.Get(0) + if info == nil { + return nil + } + + return info.(*TxInfo) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 345121f68..cda02db9f 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -944,17 +944,25 @@ func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster, cluster = make(pendingInputs) } - cluster[op] = input - locktimes[lt] = cluster - - // We also get the preferred fee rate for this input. + // Get the fee rate based on the fee preference. If an error is + // returned, we'll skip sweeping this input for this round of + // cluster creation and retry it when we create the clusters + // from the pending inputs again. feeRate, err := s.feeRateForPreference(input.params.Fee) if err != nil { log.Warnf("Skipping input %v: %v", op, err) continue } + log.Debugf("Adding input %v to cluster with locktime=%v, "+ + "feeRate=%v", op, lt, feeRate) + + // Attach the fee rate to the input. input.lastFeeRate = feeRate + + // Update the cluster about the updated input. + cluster[op] = input + locktimes[lt] = cluster } // We'll then determine the sweep fee rate for each set of inputs by diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index dcb7a5792..9d86d7f9f 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -6,6 +6,7 @@ import ( "reflect" "runtime/debug" "runtime/pprof" + "sort" "testing" "time" @@ -2357,3 +2358,182 @@ func TestFeeRateForPreference(t *testing.T) { }) } } + +// TestClusterByLockTime tests the method clusterByLockTime works as expected. +func TestClusterByLockTime(t *testing.T) { + t.Parallel() + + // Create a test param with a dummy fee preference. This is needed so + // `feeRateForPreference` won't throw an error. + param := Params{Fee: FeePreference{ConfTarget: 1}} + + // We begin the test by creating three clusters of inputs, the first + // cluster has a locktime of 1, the second has a locktime of 2, and the + // final has no locktime. + lockTime1 := uint32(1) + lockTime2 := uint32(2) + + // Create cluster one, which has a locktime of 1. + input1LockTime1 := &input.MockInput{} + input2LockTime1 := &input.MockInput{} + input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) + input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) + + // Create cluster two, which has a locktime of 2. + input3LockTime2 := &input.MockInput{} + input4LockTime2 := &input.MockInput{} + input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) + input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) + + // Create cluster three, which has no locktime. + input5NoLockTime := &input.MockInput{} + input6NoLockTime := &input.MockInput{} + input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + + // With the inner Input being mocked, we can now create the pending + // inputs. + input1 := &pendingInput{Input: input1LockTime1, params: param} + input2 := &pendingInput{Input: input2LockTime1, params: param} + input3 := &pendingInput{Input: input3LockTime2, params: param} + input4 := &pendingInput{Input: input4LockTime2, params: param} + input5 := &pendingInput{Input: input5NoLockTime, params: param} + input6 := &pendingInput{Input: input6NoLockTime, params: param} + + // Create the pending inputs map, which will be passed to the method + // under test. + // + // NOTE: we don't care the actual outpoint values as long as they are + // unique. + inputs := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + // Create expected clusters so we can shorten the line length in the + // test cases below. + cluster1 := pendingInputs{ + wire.OutPoint{Index: 1}: input1, + wire.OutPoint{Index: 2}: input2, + } + cluster2 := pendingInputs{ + wire.OutPoint{Index: 3}: input3, + wire.OutPoint{Index: 4}: input4, + } + + // cluster3 should be the remaining inputs since they don't have + // locktime. + cluster3 := pendingInputs{ + wire.OutPoint{Index: 5}: input5, + wire.OutPoint{Index: 6}: input6, + } + + // Set the min fee rate to be 1000 sat/kw. + const minFeeRate = chainfee.SatPerKWeight(1000) + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + MaxFeeRate: minFeeRate.FeePerVByte() * 10, + }) + + // Set the relay fee to be the minFeeRate. Any fee rate below the + // minFeeRate will cause an error to be returned. + s.relayFeeRate = minFeeRate + + // applyFeeRate takes a testing fee rate and makes a mocker over + // DetermineFeePerKw that always return the testing fee rate. This + // mocked method is then attached to the sweeper. + applyFeeRate := func(feeRate chainfee.SatPerKWeight) { + mockFeeFunc := func(_ chainfee.Estimator, _ FeePreference) ( + chainfee.SatPerKWeight, error) { + + return feeRate, nil + } + + s.cfg.DetermineFeePerKw = mockFeeFunc + } + + testCases := []struct { + name string + testFeeRate chainfee.SatPerKWeight + expectedClusters []inputCluster + expectedRemainingInputs pendingInputs + }{ + { + // Test a successful case where the locktime clusters + // are created and the no-locktime cluster is returned + // as the remaining inputs. + name: "successfully create clusters", + // Use a fee rate above the min value so we don't hit + // an error when performing fee estimation. + // + // TODO(yy): we should customize the returned fee rate + // for each input to further test the averaging logic. + // Or we can split the method into two, one for + // grouping the clusters and the other for averaging + // the fee rates so it's easier to be tested. + testFeeRate: minFeeRate + 1, + expectedClusters: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: minFeeRate + 1, + inputs: cluster1, + }, + { + lockTime: &lockTime2, + sweepFeeRate: minFeeRate + 1, + inputs: cluster2, + }, + }, + expectedRemainingInputs: cluster3, + }, + { + // Test that when the input is skipped when the fee + // estimation returns an error. + name: "error from fee estimation", + // Use a fee rate below the min value so we hit an + // error when performing fee estimation. + testFeeRate: minFeeRate - 1, + expectedClusters: []inputCluster{}, + // Remaining inputs should stay untouched. + expectedRemainingInputs: cluster3, + }, + } + + for _, tc := range testCases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + // Apply the test fee rate so `feeRateForPreference` is + // mocked to return the specified value. + applyFeeRate(tc.testFeeRate) + + // Call the method under test. + clusters, remainingInputs := s.clusterByLockTime(inputs) + + // Sort by locktime as the order is not guaranteed. + sort.Slice(clusters, func(i, j int) bool { + return *clusters[i].lockTime < + *clusters[j].lockTime + }) + + // Validate the values are returned as expected. + require.Equal(t, tc.expectedClusters, clusters) + require.Equal(t, tc.expectedRemainingInputs, + remainingInputs, + ) + + // Assert the mocked methods are called as expeceted. + input1LockTime1.AssertExpectations(t) + input2LockTime1.AssertExpectations(t) + input3LockTime2.AssertExpectations(t) + input4LockTime2.AssertExpectations(t) + input5NoLockTime.AssertExpectations(t) + input6NoLockTime.AssertExpectations(t) + }) + } +}