mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-05-12 04:40:05 +02:00
Merge pull request #7824 from yyforyongyu/sweeper-unit-test
input+sweep: make sure input with no fee rate is not added to cluster
This commit is contained in:
commit
e8d865aba5
@ -19,6 +19,10 @@
|
||||
|
||||
# Bug Fixes
|
||||
|
||||
* [Fixed a potential case](https://github.com/lightningnetwork/lnd/pull/7824)
|
||||
that when sweeping inputs with locktime, an unexpected lower fee rate is
|
||||
applied.
|
||||
|
||||
# New Features
|
||||
## Functional Enhancements
|
||||
|
||||
|
@ -331,14 +331,14 @@ func (b *Batcher) BatchFund(ctx context.Context,
|
||||
// settings from the first request as all of them should be equal
|
||||
// anyway.
|
||||
firstReq := b.channels[0].fundingReq
|
||||
feeRateSatPerKVByte := firstReq.FundingFeePerKw.FeePerKVByte()
|
||||
feeRateSatPerVByte := firstReq.FundingFeePerKw.FeePerVByte()
|
||||
changeType := walletrpc.ChangeAddressType_CHANGE_ADDRESS_TYPE_P2TR
|
||||
fundPsbtReq := &walletrpc.FundPsbtRequest{
|
||||
Template: &walletrpc.FundPsbtRequest_Raw{
|
||||
Raw: txTemplate,
|
||||
},
|
||||
Fees: &walletrpc.FundPsbtRequest_SatPerVbyte{
|
||||
SatPerVbyte: uint64(feeRateSatPerKVByte) / 1000,
|
||||
SatPerVbyte: uint64(feeRateSatPerVByte),
|
||||
},
|
||||
MinConfs: firstReq.MinConfs,
|
||||
SpendUnconfirmed: firstReq.MinConfs == 0,
|
||||
|
125
input/mocks.go
Normal file
125
input/mocks.go
Normal file
@ -0,0 +1,125 @@
|
||||
package input
|
||||
|
||||
import (
|
||||
"github.com/btcsuite/btcd/txscript"
|
||||
"github.com/btcsuite/btcd/wire"
|
||||
"github.com/stretchr/testify/mock"
|
||||
)
|
||||
|
||||
// MockInput implements the `Input` interface and is used by other packages for
|
||||
// mock testing.
|
||||
type MockInput struct {
|
||||
mock.Mock
|
||||
}
|
||||
|
||||
// Compile time assertion that MockInput implements Input.
|
||||
var _ Input = (*MockInput)(nil)
|
||||
|
||||
// Outpoint returns the reference to the output being spent, used to construct
|
||||
// the corresponding transaction input.
|
||||
func (m *MockInput) OutPoint() *wire.OutPoint {
|
||||
args := m.Called()
|
||||
op := args.Get(0)
|
||||
|
||||
if op == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return op.(*wire.OutPoint)
|
||||
}
|
||||
|
||||
// RequiredTxOut returns a non-nil TxOut if input commits to a certain
|
||||
// transaction output. This is used in the SINGLE|ANYONECANPAY case to make
|
||||
// sure any presigned input is still valid by including the output.
|
||||
func (m *MockInput) RequiredTxOut() *wire.TxOut {
|
||||
args := m.Called()
|
||||
txOut := args.Get(0)
|
||||
|
||||
if txOut == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return txOut.(*wire.TxOut)
|
||||
}
|
||||
|
||||
// RequiredLockTime returns whether this input commits to a tx locktime that
|
||||
// must be used in the transaction including it.
|
||||
func (m *MockInput) RequiredLockTime() (uint32, bool) {
|
||||
args := m.Called()
|
||||
|
||||
return args.Get(0).(uint32), args.Bool(1)
|
||||
}
|
||||
|
||||
// WitnessType returns an enum specifying the type of witness that must be
|
||||
// generated in order to spend this output.
|
||||
func (m *MockInput) WitnessType() WitnessType {
|
||||
args := m.Called()
|
||||
|
||||
wt := args.Get(0)
|
||||
if wt == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return wt.(WitnessType)
|
||||
}
|
||||
|
||||
// SignDesc returns a reference to a spendable output's sign descriptor, which
|
||||
// is used during signing to compute a valid witness that spends this output.
|
||||
func (m *MockInput) SignDesc() *SignDescriptor {
|
||||
args := m.Called()
|
||||
|
||||
sd := args.Get(0)
|
||||
if sd == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return sd.(*SignDescriptor)
|
||||
}
|
||||
|
||||
// CraftInputScript returns a valid set of input scripts allowing this output
|
||||
// to be spent. The returns input scripts should target the input at location
|
||||
// txIndex within the passed transaction. The input scripts generated by this
|
||||
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
|
||||
func (m *MockInput) CraftInputScript(signer Signer, txn *wire.MsgTx,
|
||||
hashCache *txscript.TxSigHashes,
|
||||
prevOutputFetcher txscript.PrevOutputFetcher,
|
||||
txinIdx int) (*Script, error) {
|
||||
|
||||
args := m.Called(signer, txn, hashCache, prevOutputFetcher, txinIdx)
|
||||
|
||||
s := args.Get(0)
|
||||
if s == nil {
|
||||
return nil, args.Error(1)
|
||||
}
|
||||
|
||||
return s.(*Script), args.Error(1)
|
||||
}
|
||||
|
||||
// BlocksToMaturity returns the relative timelock, as a number of blocks, that
|
||||
// must be built on top of the confirmation height before the output can be
|
||||
// spent. For non-CSV locked inputs this is always zero.
|
||||
func (m *MockInput) BlocksToMaturity() uint32 {
|
||||
args := m.Called()
|
||||
|
||||
return args.Get(0).(uint32)
|
||||
}
|
||||
|
||||
// HeightHint returns the minimum height at which a confirmed spending tx can
|
||||
// occur.
|
||||
func (m *MockInput) HeightHint() uint32 {
|
||||
args := m.Called()
|
||||
|
||||
return args.Get(0).(uint32)
|
||||
}
|
||||
|
||||
// UnconfParent returns information about a possibly unconfirmed parent tx.
|
||||
func (m *MockInput) UnconfParent() *TxInfo {
|
||||
args := m.Called()
|
||||
|
||||
info := args.Get(0)
|
||||
if info == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return info.(*TxInfo)
|
||||
}
|
@ -39,8 +39,8 @@ func DefaultWtClientCfg() *WtClient {
|
||||
// The sweep fee rate used internally by the tower client is in sats/kw
|
||||
// but the config exposed to the user is in sats/byte, so we convert the
|
||||
// default here before exposing it to the user.
|
||||
sweepSatsPerKvB := wtpolicy.DefaultSweepFeeRate.FeePerKVByte()
|
||||
sweepFeeRate := uint64(sweepSatsPerKvB / 1000)
|
||||
sweepSatsPerVB := wtpolicy.DefaultSweepFeeRate.FeePerVByte()
|
||||
sweepFeeRate := uint64(sweepSatsPerVB)
|
||||
|
||||
return &WtClient{
|
||||
SweepFeeRate: sweepFeeRate,
|
||||
|
@ -782,12 +782,12 @@ func (w *WalletKit) PendingSweeps(ctx context.Context,
|
||||
|
||||
op := lnrpc.MarshalOutPoint(&pendingInput.OutPoint)
|
||||
amountSat := uint32(pendingInput.Amount)
|
||||
satPerVbyte := uint64(pendingInput.LastFeeRate.FeePerKVByte() / 1000)
|
||||
satPerVbyte := uint64(pendingInput.LastFeeRate.FeePerVByte())
|
||||
broadcastAttempts := uint32(pendingInput.BroadcastAttempts)
|
||||
nextBroadcastHeight := uint32(pendingInput.NextBroadcastHeight)
|
||||
|
||||
requestedFee := pendingInput.Params.Fee
|
||||
requestedFeeRate := uint64(requestedFee.FeeRate.FeePerKVByte() / 1000)
|
||||
requestedFeeRate := uint64(requestedFee.FeeRate.FeePerVByte())
|
||||
|
||||
rpcPendingSweeps = append(rpcPendingSweeps, &PendingSweep{
|
||||
Outpoint: op,
|
||||
|
@ -476,15 +476,11 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
|
||||
}
|
||||
|
||||
return &PolicyResponse{
|
||||
MaxUpdates: uint32(policy.MaxUpdates),
|
||||
SweepSatPerVbyte: uint32(
|
||||
policy.SweepFeeRate.FeePerKVByte() / 1000,
|
||||
),
|
||||
MaxUpdates: uint32(policy.MaxUpdates),
|
||||
SweepSatPerVbyte: uint32(policy.SweepFeeRate.FeePerVByte()),
|
||||
|
||||
// Deprecated field.
|
||||
SweepSatPerByte: uint32(
|
||||
policy.SweepFeeRate.FeePerKVByte() / 1000,
|
||||
),
|
||||
SweepSatPerByte: uint32(policy.SweepFeeRate.FeePerVByte()),
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -519,7 +515,7 @@ func marshallTower(tower *wtclient.RegisteredTower, policyType PolicyType,
|
||||
|
||||
rpcSessions = make([]*TowerSession, 0, len(tower.Sessions))
|
||||
for _, session := range sessions {
|
||||
satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000
|
||||
satPerVByte := session.Policy.SweepFeeRate.FeePerVByte()
|
||||
rpcSessions = append(rpcSessions, &TowerSession{
|
||||
NumBackups: uint32(ackCounts[session.ID]),
|
||||
NumPendingBackups: uint32(pendingCounts[session.ID]),
|
||||
|
@ -70,6 +70,11 @@ func (s SatPerKWeight) FeePerKVByte() SatPerKVByte {
|
||||
return SatPerKVByte(s * blockchain.WitnessScaleFactor)
|
||||
}
|
||||
|
||||
// FeePerVByte converts the current fee rate from sat/kw to sat/vb.
|
||||
func (s SatPerKWeight) FeePerVByte() SatPerVByte {
|
||||
return SatPerVByte(s * blockchain.WitnessScaleFactor / 1000)
|
||||
}
|
||||
|
||||
// String returns a human-readable string of the fee rate.
|
||||
func (s SatPerKWeight) String() string {
|
||||
return fmt.Sprintf("%v sat/kw", int64(s))
|
||||
|
@ -1197,10 +1197,10 @@ func (r *rpcServer) EstimateFee(ctx context.Context,
|
||||
|
||||
resp := &lnrpc.EstimateFeeResponse{
|
||||
FeeSat: totalFee,
|
||||
SatPerVbyte: uint64(feePerKw.FeePerKVByte() / 1000),
|
||||
SatPerVbyte: uint64(feePerKw.FeePerVByte()),
|
||||
|
||||
// Deprecated field.
|
||||
FeerateSatPerByte: int64(feePerKw.FeePerKVByte() / 1000),
|
||||
FeerateSatPerByte: int64(feePerKw.FeePerVByte()),
|
||||
}
|
||||
|
||||
rpcsLog.Debugf("[estimatefee] fee estimate for conf target %d: %v",
|
||||
|
@ -1059,10 +1059,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
|
||||
}
|
||||
|
||||
s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{
|
||||
FeeEstimator: cc.FeeEstimator,
|
||||
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
|
||||
Signer: cc.Wallet.Cfg.Signer,
|
||||
Wallet: newSweeperWallet(cc.Wallet),
|
||||
FeeEstimator: cc.FeeEstimator,
|
||||
DetermineFeePerKw: sweep.DetermineFeePerKw,
|
||||
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
|
||||
Signer: cc.Wallet.Cfg.Signer,
|
||||
Wallet: newSweeperWallet(cc.Wallet),
|
||||
NewBatchTimer: func() <-chan time.Time {
|
||||
return time.NewTimer(cfg.Sweeper.BatchWindowDuration).C
|
||||
},
|
||||
|
@ -47,6 +47,10 @@ var (
|
||||
// request from a client whom did not specify a fee preference.
|
||||
ErrNoFeePreference = errors.New("no fee preference specified")
|
||||
|
||||
// ErrFeePreferenceTooLow is returned when the fee preference gives a
|
||||
// fee rate that's below the relay fee rate.
|
||||
ErrFeePreferenceTooLow = errors.New("fee preference too low")
|
||||
|
||||
// ErrExclusiveGroupSpend is returned in case a different input of the
|
||||
// same exclusive group was spent.
|
||||
ErrExclusiveGroupSpend = errors.New("other member of exclusive group " +
|
||||
@ -237,12 +241,21 @@ type UtxoSweeper struct {
|
||||
wg sync.WaitGroup
|
||||
}
|
||||
|
||||
// feeDeterminer defines an alias to the function signature of
|
||||
// `DetermineFeePerKw`.
|
||||
type feeDeterminer func(chainfee.Estimator,
|
||||
FeePreference) (chainfee.SatPerKWeight, error)
|
||||
|
||||
// UtxoSweeperConfig contains dependencies of UtxoSweeper.
|
||||
type UtxoSweeperConfig struct {
|
||||
// GenSweepScript generates a P2WKH script belonging to the wallet where
|
||||
// funds can be swept.
|
||||
GenSweepScript func() ([]byte, error)
|
||||
|
||||
// DetermineFeePerKw determines the fee in sat/kw based on the given
|
||||
// estimator and fee preference.
|
||||
DetermineFeePerKw feeDeterminer
|
||||
|
||||
// FeeEstimator is used when crafting sweep transactions to estimate
|
||||
// the necessary fee relative to the expected size of the sweep
|
||||
// transaction.
|
||||
@ -470,13 +483,16 @@ func (s *UtxoSweeper) feeRateForPreference(
|
||||
return 0, ErrNoFeePreference
|
||||
}
|
||||
|
||||
feeRate, err := DetermineFeePerKw(s.cfg.FeeEstimator, feePreference)
|
||||
feeRate, err := s.cfg.DetermineFeePerKw(
|
||||
s.cfg.FeeEstimator, feePreference,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if feeRate < s.relayFeeRate {
|
||||
return 0, fmt.Errorf("fee preference resulted in invalid fee "+
|
||||
"rate %v, minimum is %v", feeRate, s.relayFeeRate)
|
||||
return 0, fmt.Errorf("%w: got %v, minimum is %v",
|
||||
ErrFeePreferenceTooLow, feeRate, s.relayFeeRate)
|
||||
}
|
||||
|
||||
// If the estimated fee rate is above the maximum allowed fee rate,
|
||||
@ -912,7 +928,6 @@ func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
|
||||
pendingInputs) {
|
||||
|
||||
locktimes := make(map[uint32]pendingInputs)
|
||||
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
|
||||
rem := make(pendingInputs)
|
||||
|
||||
// Go through all inputs and check if they require a certain locktime.
|
||||
@ -924,41 +939,48 @@ func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
|
||||
}
|
||||
|
||||
// Check if we already have inputs with this locktime.
|
||||
p, ok := locktimes[lt]
|
||||
cluster, ok := locktimes[lt]
|
||||
if !ok {
|
||||
p = make(pendingInputs)
|
||||
cluster = make(pendingInputs)
|
||||
}
|
||||
|
||||
p[op] = input
|
||||
locktimes[lt] = p
|
||||
|
||||
// We also get the preferred fee rate for this input.
|
||||
// Get the fee rate based on the fee preference. If an error is
|
||||
// returned, we'll skip sweeping this input for this round of
|
||||
// cluster creation and retry it when we create the clusters
|
||||
// from the pending inputs again.
|
||||
feeRate, err := s.feeRateForPreference(input.params.Fee)
|
||||
if err != nil {
|
||||
log.Warnf("Skipping input %v: %v", op, err)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf("Adding input %v to cluster with locktime=%v, "+
|
||||
"feeRate=%v", op, lt, feeRate)
|
||||
|
||||
// Attach the fee rate to the input.
|
||||
input.lastFeeRate = feeRate
|
||||
inputFeeRates[op] = feeRate
|
||||
|
||||
// Update the cluster about the updated input.
|
||||
cluster[op] = input
|
||||
locktimes[lt] = cluster
|
||||
}
|
||||
|
||||
// We'll then determine the sweep fee rate for each set of inputs by
|
||||
// calculating the average fee rate of the inputs within each set.
|
||||
inputClusters := make([]inputCluster, 0, len(locktimes))
|
||||
for lt, inputs := range locktimes {
|
||||
for lt, cluster := range locktimes {
|
||||
lt := lt
|
||||
|
||||
var sweepFeeRate chainfee.SatPerKWeight
|
||||
for op := range inputs {
|
||||
sweepFeeRate += inputFeeRates[op]
|
||||
for _, input := range cluster {
|
||||
sweepFeeRate += input.lastFeeRate
|
||||
}
|
||||
|
||||
sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
|
||||
sweepFeeRate /= chainfee.SatPerKWeight(len(cluster))
|
||||
inputClusters = append(inputClusters, inputCluster{
|
||||
lockTime: <,
|
||||
sweepFeeRate: sweepFeeRate,
|
||||
inputs: inputs,
|
||||
inputs: cluster,
|
||||
})
|
||||
}
|
||||
|
||||
@ -1599,7 +1621,7 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) (
|
||||
func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference,
|
||||
currentBlockHeight uint32) (*wire.MsgTx, error) {
|
||||
|
||||
feePerKw, err := DetermineFeePerKw(s.cfg.FeeEstimator, feePref)
|
||||
feePerKw, err := s.cfg.DetermineFeePerKw(s.cfg.FeeEstimator, feePref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -1,10 +1,12 @@
|
||||
package sweep
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"reflect"
|
||||
"runtime/debug"
|
||||
"runtime/pprof"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -149,6 +151,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
|
||||
},
|
||||
MaxFeeRate: DefaultMaxFeeRate,
|
||||
FeeRateBucketSize: DefaultFeeRateBucketSize,
|
||||
DetermineFeePerKw: DetermineFeePerKw,
|
||||
})
|
||||
|
||||
ctx.sweeper.Start()
|
||||
@ -2238,3 +2241,301 @@ func TestSweeperShutdownHandling(t *testing.T) {
|
||||
)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// TestFeeRateForPreference checks `feeRateForPreference` works as expected.
|
||||
func TestFeeRateForPreference(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dummyErr := errors.New("dummy")
|
||||
|
||||
// Create a test sweeper.
|
||||
s := New(&UtxoSweeperConfig{})
|
||||
|
||||
// errFeeFunc is a mock over DetermineFeePerKw that always return the
|
||||
// above dummy error.
|
||||
errFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
|
||||
chainfee.SatPerKWeight, error) {
|
||||
|
||||
return 0, dummyErr
|
||||
}
|
||||
|
||||
// Set the relay fee rate to be 1.
|
||||
s.relayFeeRate = 1
|
||||
|
||||
// smallFeeFunc is a mock over DetermineFeePerKw that always return a
|
||||
// fee rate that's below the relayFeeRate.
|
||||
smallFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
|
||||
chainfee.SatPerKWeight, error) {
|
||||
|
||||
return s.relayFeeRate - 1, nil
|
||||
}
|
||||
|
||||
// Set the max fee rate to be 1000 sat/vb.
|
||||
s.cfg.MaxFeeRate = 1000
|
||||
|
||||
// largeFeeFunc is a mock over DetermineFeePerKw that always return a
|
||||
// fee rate that's larger than the MaxFeeRate.
|
||||
largeFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
|
||||
chainfee.SatPerKWeight, error) {
|
||||
|
||||
return s.cfg.MaxFeeRate.FeePerKWeight() + 1, nil
|
||||
}
|
||||
|
||||
// validFeeRate is used to test the success case.
|
||||
validFeeRate := (s.cfg.MaxFeeRate.FeePerKWeight() + s.relayFeeRate) / 2
|
||||
|
||||
// normalFeeFunc is a mock over DetermineFeePerKw that always return a
|
||||
// fee rate that's within the range.
|
||||
normalFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
|
||||
chainfee.SatPerKWeight, error) {
|
||||
|
||||
return validFeeRate, nil
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
feePref FeePreference
|
||||
determineFeePerKw feeDeterminer
|
||||
expectedFeeRate chainfee.SatPerKWeight
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
// When the fee preference is empty, we should see an
|
||||
// error.
|
||||
name: "empty fee preference",
|
||||
feePref: FeePreference{},
|
||||
expectedErr: ErrNoFeePreference,
|
||||
},
|
||||
{
|
||||
// When an error is returned from the fee determinor,
|
||||
// we should return it.
|
||||
name: "error from DetermineFeePerKw",
|
||||
feePref: FeePreference{FeeRate: 1},
|
||||
determineFeePerKw: errFeeFunc,
|
||||
expectedErr: dummyErr,
|
||||
},
|
||||
{
|
||||
// When DetermineFeePerKw gives a too small value, we
|
||||
// should return an error.
|
||||
name: "fee rate below relay fee rate",
|
||||
feePref: FeePreference{FeeRate: 1},
|
||||
determineFeePerKw: smallFeeFunc,
|
||||
expectedErr: ErrFeePreferenceTooLow,
|
||||
},
|
||||
{
|
||||
// When DetermineFeePerKw gives a too large value, we
|
||||
// should cap it at the max fee rate.
|
||||
name: "fee rate above max fee rate",
|
||||
feePref: FeePreference{FeeRate: 1},
|
||||
determineFeePerKw: largeFeeFunc,
|
||||
expectedFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(),
|
||||
},
|
||||
{
|
||||
// When DetermineFeePerKw gives a sane fee rate, we
|
||||
// should return it without any error.
|
||||
name: "success",
|
||||
feePref: FeePreference{FeeRate: 1},
|
||||
determineFeePerKw: normalFeeFunc,
|
||||
expectedFeeRate: validFeeRate,
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:paralleltest
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Attach the mocked method.
|
||||
s.cfg.DetermineFeePerKw = tc.determineFeePerKw
|
||||
|
||||
// Call the function under test.
|
||||
feerate, err := s.feeRateForPreference(tc.feePref)
|
||||
|
||||
// Assert the expected feerate.
|
||||
require.Equal(t, tc.expectedFeeRate, feerate)
|
||||
|
||||
// Assert the expected error.
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestClusterByLockTime tests the method clusterByLockTime works as expected.
|
||||
func TestClusterByLockTime(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create a test param with a dummy fee preference. This is needed so
|
||||
// `feeRateForPreference` won't throw an error.
|
||||
param := Params{Fee: FeePreference{ConfTarget: 1}}
|
||||
|
||||
// We begin the test by creating three clusters of inputs, the first
|
||||
// cluster has a locktime of 1, the second has a locktime of 2, and the
|
||||
// final has no locktime.
|
||||
lockTime1 := uint32(1)
|
||||
lockTime2 := uint32(2)
|
||||
|
||||
// Create cluster one, which has a locktime of 1.
|
||||
input1LockTime1 := &input.MockInput{}
|
||||
input2LockTime1 := &input.MockInput{}
|
||||
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
|
||||
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
|
||||
|
||||
// Create cluster two, which has a locktime of 2.
|
||||
input3LockTime2 := &input.MockInput{}
|
||||
input4LockTime2 := &input.MockInput{}
|
||||
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
|
||||
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
|
||||
|
||||
// Create cluster three, which has no locktime.
|
||||
input5NoLockTime := &input.MockInput{}
|
||||
input6NoLockTime := &input.MockInput{}
|
||||
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
|
||||
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
|
||||
|
||||
// With the inner Input being mocked, we can now create the pending
|
||||
// inputs.
|
||||
input1 := &pendingInput{Input: input1LockTime1, params: param}
|
||||
input2 := &pendingInput{Input: input2LockTime1, params: param}
|
||||
input3 := &pendingInput{Input: input3LockTime2, params: param}
|
||||
input4 := &pendingInput{Input: input4LockTime2, params: param}
|
||||
input5 := &pendingInput{Input: input5NoLockTime, params: param}
|
||||
input6 := &pendingInput{Input: input6NoLockTime, params: param}
|
||||
|
||||
// Create the pending inputs map, which will be passed to the method
|
||||
// under test.
|
||||
//
|
||||
// NOTE: we don't care the actual outpoint values as long as they are
|
||||
// unique.
|
||||
inputs := pendingInputs{
|
||||
wire.OutPoint{Index: 1}: input1,
|
||||
wire.OutPoint{Index: 2}: input2,
|
||||
wire.OutPoint{Index: 3}: input3,
|
||||
wire.OutPoint{Index: 4}: input4,
|
||||
wire.OutPoint{Index: 5}: input5,
|
||||
wire.OutPoint{Index: 6}: input6,
|
||||
}
|
||||
|
||||
// Create expected clusters so we can shorten the line length in the
|
||||
// test cases below.
|
||||
cluster1 := pendingInputs{
|
||||
wire.OutPoint{Index: 1}: input1,
|
||||
wire.OutPoint{Index: 2}: input2,
|
||||
}
|
||||
cluster2 := pendingInputs{
|
||||
wire.OutPoint{Index: 3}: input3,
|
||||
wire.OutPoint{Index: 4}: input4,
|
||||
}
|
||||
|
||||
// cluster3 should be the remaining inputs since they don't have
|
||||
// locktime.
|
||||
cluster3 := pendingInputs{
|
||||
wire.OutPoint{Index: 5}: input5,
|
||||
wire.OutPoint{Index: 6}: input6,
|
||||
}
|
||||
|
||||
// Set the min fee rate to be 1000 sat/kw.
|
||||
const minFeeRate = chainfee.SatPerKWeight(1000)
|
||||
|
||||
// Create a test sweeper.
|
||||
s := New(&UtxoSweeperConfig{
|
||||
MaxFeeRate: minFeeRate.FeePerVByte() * 10,
|
||||
})
|
||||
|
||||
// Set the relay fee to be the minFeeRate. Any fee rate below the
|
||||
// minFeeRate will cause an error to be returned.
|
||||
s.relayFeeRate = minFeeRate
|
||||
|
||||
// applyFeeRate takes a testing fee rate and makes a mocker over
|
||||
// DetermineFeePerKw that always return the testing fee rate. This
|
||||
// mocked method is then attached to the sweeper.
|
||||
applyFeeRate := func(feeRate chainfee.SatPerKWeight) {
|
||||
mockFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
|
||||
chainfee.SatPerKWeight, error) {
|
||||
|
||||
return feeRate, nil
|
||||
}
|
||||
|
||||
s.cfg.DetermineFeePerKw = mockFeeFunc
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFeeRate chainfee.SatPerKWeight
|
||||
expectedClusters []inputCluster
|
||||
expectedRemainingInputs pendingInputs
|
||||
}{
|
||||
{
|
||||
// Test a successful case where the locktime clusters
|
||||
// are created and the no-locktime cluster is returned
|
||||
// as the remaining inputs.
|
||||
name: "successfully create clusters",
|
||||
// Use a fee rate above the min value so we don't hit
|
||||
// an error when performing fee estimation.
|
||||
//
|
||||
// TODO(yy): we should customize the returned fee rate
|
||||
// for each input to further test the averaging logic.
|
||||
// Or we can split the method into two, one for
|
||||
// grouping the clusters and the other for averaging
|
||||
// the fee rates so it's easier to be tested.
|
||||
testFeeRate: minFeeRate + 1,
|
||||
expectedClusters: []inputCluster{
|
||||
{
|
||||
lockTime: &lockTime1,
|
||||
sweepFeeRate: minFeeRate + 1,
|
||||
inputs: cluster1,
|
||||
},
|
||||
{
|
||||
lockTime: &lockTime2,
|
||||
sweepFeeRate: minFeeRate + 1,
|
||||
inputs: cluster2,
|
||||
},
|
||||
},
|
||||
expectedRemainingInputs: cluster3,
|
||||
},
|
||||
{
|
||||
// Test that when the input is skipped when the fee
|
||||
// estimation returns an error.
|
||||
name: "error from fee estimation",
|
||||
// Use a fee rate below the min value so we hit an
|
||||
// error when performing fee estimation.
|
||||
testFeeRate: minFeeRate - 1,
|
||||
expectedClusters: []inputCluster{},
|
||||
// Remaining inputs should stay untouched.
|
||||
expectedRemainingInputs: cluster3,
|
||||
},
|
||||
}
|
||||
|
||||
//nolint:paralleltest
|
||||
for _, tc := range testCases {
|
||||
tc := tc
|
||||
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Apply the test fee rate so `feeRateForPreference` is
|
||||
// mocked to return the specified value.
|
||||
applyFeeRate(tc.testFeeRate)
|
||||
|
||||
// Call the method under test.
|
||||
clusters, remainingInputs := s.clusterByLockTime(inputs)
|
||||
|
||||
// Sort by locktime as the order is not guaranteed.
|
||||
sort.Slice(clusters, func(i, j int) bool {
|
||||
return *clusters[i].lockTime <
|
||||
*clusters[j].lockTime
|
||||
})
|
||||
|
||||
// Validate the values are returned as expected.
|
||||
require.Equal(t, tc.expectedClusters, clusters)
|
||||
require.Equal(t, tc.expectedRemainingInputs,
|
||||
remainingInputs,
|
||||
)
|
||||
|
||||
// Assert the mocked methods are called as expeceted.
|
||||
input1LockTime1.AssertExpectations(t)
|
||||
input2LockTime1.AssertExpectations(t)
|
||||
input3LockTime2.AssertExpectations(t)
|
||||
input4LockTime2.AssertExpectations(t)
|
||||
input5NoLockTime.AssertExpectations(t)
|
||||
input6NoLockTime.AssertExpectations(t)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user