Merge pull request #7824 from yyforyongyu/sweeper-unit-test

input+sweep: make sure input with no fee rate is not added to cluster
This commit is contained in:
Oliver Gugger 2023-10-17 06:36:21 +00:00 committed by GitHub
commit e8d865aba5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 491 additions and 37 deletions

View File

@ -19,6 +19,10 @@
# Bug Fixes
* [Fixed a potential case](https://github.com/lightningnetwork/lnd/pull/7824)
that when sweeping inputs with locktime, an unexpected lower fee rate is
applied.
# New Features
## Functional Enhancements

View File

@ -331,14 +331,14 @@ func (b *Batcher) BatchFund(ctx context.Context,
// settings from the first request as all of them should be equal
// anyway.
firstReq := b.channels[0].fundingReq
feeRateSatPerKVByte := firstReq.FundingFeePerKw.FeePerKVByte()
feeRateSatPerVByte := firstReq.FundingFeePerKw.FeePerVByte()
changeType := walletrpc.ChangeAddressType_CHANGE_ADDRESS_TYPE_P2TR
fundPsbtReq := &walletrpc.FundPsbtRequest{
Template: &walletrpc.FundPsbtRequest_Raw{
Raw: txTemplate,
},
Fees: &walletrpc.FundPsbtRequest_SatPerVbyte{
SatPerVbyte: uint64(feeRateSatPerKVByte) / 1000,
SatPerVbyte: uint64(feeRateSatPerVByte),
},
MinConfs: firstReq.MinConfs,
SpendUnconfirmed: firstReq.MinConfs == 0,

125
input/mocks.go Normal file
View File

@ -0,0 +1,125 @@
package input
import (
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/stretchr/testify/mock"
)
// MockInput implements the `Input` interface and is used by other packages for
// mock testing.
type MockInput struct {
mock.Mock
}
// Compile time assertion that MockInput implements Input.
var _ Input = (*MockInput)(nil)
// Outpoint returns the reference to the output being spent, used to construct
// the corresponding transaction input.
func (m *MockInput) OutPoint() *wire.OutPoint {
args := m.Called()
op := args.Get(0)
if op == nil {
return nil
}
return op.(*wire.OutPoint)
}
// RequiredTxOut returns a non-nil TxOut if input commits to a certain
// transaction output. This is used in the SINGLE|ANYONECANPAY case to make
// sure any presigned input is still valid by including the output.
func (m *MockInput) RequiredTxOut() *wire.TxOut {
args := m.Called()
txOut := args.Get(0)
if txOut == nil {
return nil
}
return txOut.(*wire.TxOut)
}
// RequiredLockTime returns whether this input commits to a tx locktime that
// must be used in the transaction including it.
func (m *MockInput) RequiredLockTime() (uint32, bool) {
args := m.Called()
return args.Get(0).(uint32), args.Bool(1)
}
// WitnessType returns an enum specifying the type of witness that must be
// generated in order to spend this output.
func (m *MockInput) WitnessType() WitnessType {
args := m.Called()
wt := args.Get(0)
if wt == nil {
return nil
}
return wt.(WitnessType)
}
// SignDesc returns a reference to a spendable output's sign descriptor, which
// is used during signing to compute a valid witness that spends this output.
func (m *MockInput) SignDesc() *SignDescriptor {
args := m.Called()
sd := args.Get(0)
if sd == nil {
return nil
}
return sd.(*SignDescriptor)
}
// CraftInputScript returns a valid set of input scripts allowing this output
// to be spent. The returns input scripts should target the input at location
// txIndex within the passed transaction. The input scripts generated by this
// method support spending p2wkh, p2wsh, and also nested p2sh outputs.
func (m *MockInput) CraftInputScript(signer Signer, txn *wire.MsgTx,
hashCache *txscript.TxSigHashes,
prevOutputFetcher txscript.PrevOutputFetcher,
txinIdx int) (*Script, error) {
args := m.Called(signer, txn, hashCache, prevOutputFetcher, txinIdx)
s := args.Get(0)
if s == nil {
return nil, args.Error(1)
}
return s.(*Script), args.Error(1)
}
// BlocksToMaturity returns the relative timelock, as a number of blocks, that
// must be built on top of the confirmation height before the output can be
// spent. For non-CSV locked inputs this is always zero.
func (m *MockInput) BlocksToMaturity() uint32 {
args := m.Called()
return args.Get(0).(uint32)
}
// HeightHint returns the minimum height at which a confirmed spending tx can
// occur.
func (m *MockInput) HeightHint() uint32 {
args := m.Called()
return args.Get(0).(uint32)
}
// UnconfParent returns information about a possibly unconfirmed parent tx.
func (m *MockInput) UnconfParent() *TxInfo {
args := m.Called()
info := args.Get(0)
if info == nil {
return nil
}
return info.(*TxInfo)
}

View File

@ -39,8 +39,8 @@ func DefaultWtClientCfg() *WtClient {
// The sweep fee rate used internally by the tower client is in sats/kw
// but the config exposed to the user is in sats/byte, so we convert the
// default here before exposing it to the user.
sweepSatsPerKvB := wtpolicy.DefaultSweepFeeRate.FeePerKVByte()
sweepFeeRate := uint64(sweepSatsPerKvB / 1000)
sweepSatsPerVB := wtpolicy.DefaultSweepFeeRate.FeePerVByte()
sweepFeeRate := uint64(sweepSatsPerVB)
return &WtClient{
SweepFeeRate: sweepFeeRate,

View File

@ -782,12 +782,12 @@ func (w *WalletKit) PendingSweeps(ctx context.Context,
op := lnrpc.MarshalOutPoint(&pendingInput.OutPoint)
amountSat := uint32(pendingInput.Amount)
satPerVbyte := uint64(pendingInput.LastFeeRate.FeePerKVByte() / 1000)
satPerVbyte := uint64(pendingInput.LastFeeRate.FeePerVByte())
broadcastAttempts := uint32(pendingInput.BroadcastAttempts)
nextBroadcastHeight := uint32(pendingInput.NextBroadcastHeight)
requestedFee := pendingInput.Params.Fee
requestedFeeRate := uint64(requestedFee.FeeRate.FeePerKVByte() / 1000)
requestedFeeRate := uint64(requestedFee.FeeRate.FeePerVByte())
rpcPendingSweeps = append(rpcPendingSweeps, &PendingSweep{
Outpoint: op,

View File

@ -476,15 +476,11 @@ func (c *WatchtowerClient) Policy(ctx context.Context,
}
return &PolicyResponse{
MaxUpdates: uint32(policy.MaxUpdates),
SweepSatPerVbyte: uint32(
policy.SweepFeeRate.FeePerKVByte() / 1000,
),
MaxUpdates: uint32(policy.MaxUpdates),
SweepSatPerVbyte: uint32(policy.SweepFeeRate.FeePerVByte()),
// Deprecated field.
SweepSatPerByte: uint32(
policy.SweepFeeRate.FeePerKVByte() / 1000,
),
SweepSatPerByte: uint32(policy.SweepFeeRate.FeePerVByte()),
}, nil
}
@ -519,7 +515,7 @@ func marshallTower(tower *wtclient.RegisteredTower, policyType PolicyType,
rpcSessions = make([]*TowerSession, 0, len(tower.Sessions))
for _, session := range sessions {
satPerVByte := session.Policy.SweepFeeRate.FeePerKVByte() / 1000
satPerVByte := session.Policy.SweepFeeRate.FeePerVByte()
rpcSessions = append(rpcSessions, &TowerSession{
NumBackups: uint32(ackCounts[session.ID]),
NumPendingBackups: uint32(pendingCounts[session.ID]),

View File

@ -70,6 +70,11 @@ func (s SatPerKWeight) FeePerKVByte() SatPerKVByte {
return SatPerKVByte(s * blockchain.WitnessScaleFactor)
}
// FeePerVByte converts the current fee rate from sat/kw to sat/vb.
func (s SatPerKWeight) FeePerVByte() SatPerVByte {
return SatPerVByte(s * blockchain.WitnessScaleFactor / 1000)
}
// String returns a human-readable string of the fee rate.
func (s SatPerKWeight) String() string {
return fmt.Sprintf("%v sat/kw", int64(s))

View File

@ -1197,10 +1197,10 @@ func (r *rpcServer) EstimateFee(ctx context.Context,
resp := &lnrpc.EstimateFeeResponse{
FeeSat: totalFee,
SatPerVbyte: uint64(feePerKw.FeePerKVByte() / 1000),
SatPerVbyte: uint64(feePerKw.FeePerVByte()),
// Deprecated field.
FeerateSatPerByte: int64(feePerKw.FeePerKVByte() / 1000),
FeerateSatPerByte: int64(feePerKw.FeePerVByte()),
}
rpcsLog.Debugf("[estimatefee] fee estimate for conf target %d: %v",

View File

@ -1059,10 +1059,11 @@ func newServer(cfg *Config, listenAddrs []net.Addr,
}
s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{
FeeEstimator: cc.FeeEstimator,
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
Signer: cc.Wallet.Cfg.Signer,
Wallet: newSweeperWallet(cc.Wallet),
FeeEstimator: cc.FeeEstimator,
DetermineFeePerKw: sweep.DetermineFeePerKw,
GenSweepScript: newSweepPkScriptGen(cc.Wallet),
Signer: cc.Wallet.Cfg.Signer,
Wallet: newSweeperWallet(cc.Wallet),
NewBatchTimer: func() <-chan time.Time {
return time.NewTimer(cfg.Sweeper.BatchWindowDuration).C
},

View File

@ -47,6 +47,10 @@ var (
// request from a client whom did not specify a fee preference.
ErrNoFeePreference = errors.New("no fee preference specified")
// ErrFeePreferenceTooLow is returned when the fee preference gives a
// fee rate that's below the relay fee rate.
ErrFeePreferenceTooLow = errors.New("fee preference too low")
// ErrExclusiveGroupSpend is returned in case a different input of the
// same exclusive group was spent.
ErrExclusiveGroupSpend = errors.New("other member of exclusive group " +
@ -237,12 +241,21 @@ type UtxoSweeper struct {
wg sync.WaitGroup
}
// feeDeterminer defines an alias to the function signature of
// `DetermineFeePerKw`.
type feeDeterminer func(chainfee.Estimator,
FeePreference) (chainfee.SatPerKWeight, error)
// UtxoSweeperConfig contains dependencies of UtxoSweeper.
type UtxoSweeperConfig struct {
// GenSweepScript generates a P2WKH script belonging to the wallet where
// funds can be swept.
GenSweepScript func() ([]byte, error)
// DetermineFeePerKw determines the fee in sat/kw based on the given
// estimator and fee preference.
DetermineFeePerKw feeDeterminer
// FeeEstimator is used when crafting sweep transactions to estimate
// the necessary fee relative to the expected size of the sweep
// transaction.
@ -470,13 +483,16 @@ func (s *UtxoSweeper) feeRateForPreference(
return 0, ErrNoFeePreference
}
feeRate, err := DetermineFeePerKw(s.cfg.FeeEstimator, feePreference)
feeRate, err := s.cfg.DetermineFeePerKw(
s.cfg.FeeEstimator, feePreference,
)
if err != nil {
return 0, err
}
if feeRate < s.relayFeeRate {
return 0, fmt.Errorf("fee preference resulted in invalid fee "+
"rate %v, minimum is %v", feeRate, s.relayFeeRate)
return 0, fmt.Errorf("%w: got %v, minimum is %v",
ErrFeePreferenceTooLow, feeRate, s.relayFeeRate)
}
// If the estimated fee rate is above the maximum allowed fee rate,
@ -912,7 +928,6 @@ func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
pendingInputs) {
locktimes := make(map[uint32]pendingInputs)
inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight)
rem := make(pendingInputs)
// Go through all inputs and check if they require a certain locktime.
@ -924,41 +939,48 @@ func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster,
}
// Check if we already have inputs with this locktime.
p, ok := locktimes[lt]
cluster, ok := locktimes[lt]
if !ok {
p = make(pendingInputs)
cluster = make(pendingInputs)
}
p[op] = input
locktimes[lt] = p
// We also get the preferred fee rate for this input.
// Get the fee rate based on the fee preference. If an error is
// returned, we'll skip sweeping this input for this round of
// cluster creation and retry it when we create the clusters
// from the pending inputs again.
feeRate, err := s.feeRateForPreference(input.params.Fee)
if err != nil {
log.Warnf("Skipping input %v: %v", op, err)
continue
}
log.Debugf("Adding input %v to cluster with locktime=%v, "+
"feeRate=%v", op, lt, feeRate)
// Attach the fee rate to the input.
input.lastFeeRate = feeRate
inputFeeRates[op] = feeRate
// Update the cluster about the updated input.
cluster[op] = input
locktimes[lt] = cluster
}
// We'll then determine the sweep fee rate for each set of inputs by
// calculating the average fee rate of the inputs within each set.
inputClusters := make([]inputCluster, 0, len(locktimes))
for lt, inputs := range locktimes {
for lt, cluster := range locktimes {
lt := lt
var sweepFeeRate chainfee.SatPerKWeight
for op := range inputs {
sweepFeeRate += inputFeeRates[op]
for _, input := range cluster {
sweepFeeRate += input.lastFeeRate
}
sweepFeeRate /= chainfee.SatPerKWeight(len(inputs))
sweepFeeRate /= chainfee.SatPerKWeight(len(cluster))
inputClusters = append(inputClusters, inputCluster{
lockTime: &lt,
sweepFeeRate: sweepFeeRate,
inputs: inputs,
inputs: cluster,
})
}
@ -1599,7 +1621,7 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) (
func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference,
currentBlockHeight uint32) (*wire.MsgTx, error) {
feePerKw, err := DetermineFeePerKw(s.cfg.FeeEstimator, feePref)
feePerKw, err := s.cfg.DetermineFeePerKw(s.cfg.FeeEstimator, feePref)
if err != nil {
return nil, err
}

View File

@ -1,10 +1,12 @@
package sweep
import (
"errors"
"os"
"reflect"
"runtime/debug"
"runtime/pprof"
"sort"
"testing"
"time"
@ -149,6 +151,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext {
},
MaxFeeRate: DefaultMaxFeeRate,
FeeRateBucketSize: DefaultFeeRateBucketSize,
DetermineFeePerKw: DetermineFeePerKw,
})
ctx.sweeper.Start()
@ -2238,3 +2241,301 @@ func TestSweeperShutdownHandling(t *testing.T) {
)
require.Error(t, err)
}
// TestFeeRateForPreference checks `feeRateForPreference` works as expected.
func TestFeeRateForPreference(t *testing.T) {
t.Parallel()
dummyErr := errors.New("dummy")
// Create a test sweeper.
s := New(&UtxoSweeperConfig{})
// errFeeFunc is a mock over DetermineFeePerKw that always return the
// above dummy error.
errFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
chainfee.SatPerKWeight, error) {
return 0, dummyErr
}
// Set the relay fee rate to be 1.
s.relayFeeRate = 1
// smallFeeFunc is a mock over DetermineFeePerKw that always return a
// fee rate that's below the relayFeeRate.
smallFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
chainfee.SatPerKWeight, error) {
return s.relayFeeRate - 1, nil
}
// Set the max fee rate to be 1000 sat/vb.
s.cfg.MaxFeeRate = 1000
// largeFeeFunc is a mock over DetermineFeePerKw that always return a
// fee rate that's larger than the MaxFeeRate.
largeFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
chainfee.SatPerKWeight, error) {
return s.cfg.MaxFeeRate.FeePerKWeight() + 1, nil
}
// validFeeRate is used to test the success case.
validFeeRate := (s.cfg.MaxFeeRate.FeePerKWeight() + s.relayFeeRate) / 2
// normalFeeFunc is a mock over DetermineFeePerKw that always return a
// fee rate that's within the range.
normalFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
chainfee.SatPerKWeight, error) {
return validFeeRate, nil
}
testCases := []struct {
name string
feePref FeePreference
determineFeePerKw feeDeterminer
expectedFeeRate chainfee.SatPerKWeight
expectedErr error
}{
{
// When the fee preference is empty, we should see an
// error.
name: "empty fee preference",
feePref: FeePreference{},
expectedErr: ErrNoFeePreference,
},
{
// When an error is returned from the fee determinor,
// we should return it.
name: "error from DetermineFeePerKw",
feePref: FeePreference{FeeRate: 1},
determineFeePerKw: errFeeFunc,
expectedErr: dummyErr,
},
{
// When DetermineFeePerKw gives a too small value, we
// should return an error.
name: "fee rate below relay fee rate",
feePref: FeePreference{FeeRate: 1},
determineFeePerKw: smallFeeFunc,
expectedErr: ErrFeePreferenceTooLow,
},
{
// When DetermineFeePerKw gives a too large value, we
// should cap it at the max fee rate.
name: "fee rate above max fee rate",
feePref: FeePreference{FeeRate: 1},
determineFeePerKw: largeFeeFunc,
expectedFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(),
},
{
// When DetermineFeePerKw gives a sane fee rate, we
// should return it without any error.
name: "success",
feePref: FeePreference{FeeRate: 1},
determineFeePerKw: normalFeeFunc,
expectedFeeRate: validFeeRate,
},
}
//nolint:paralleltest
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Attach the mocked method.
s.cfg.DetermineFeePerKw = tc.determineFeePerKw
// Call the function under test.
feerate, err := s.feeRateForPreference(tc.feePref)
// Assert the expected feerate.
require.Equal(t, tc.expectedFeeRate, feerate)
// Assert the expected error.
require.ErrorIs(t, err, tc.expectedErr)
})
}
}
// TestClusterByLockTime tests the method clusterByLockTime works as expected.
func TestClusterByLockTime(t *testing.T) {
t.Parallel()
// Create a test param with a dummy fee preference. This is needed so
// `feeRateForPreference` won't throw an error.
param := Params{Fee: FeePreference{ConfTarget: 1}}
// We begin the test by creating three clusters of inputs, the first
// cluster has a locktime of 1, the second has a locktime of 2, and the
// final has no locktime.
lockTime1 := uint32(1)
lockTime2 := uint32(2)
// Create cluster one, which has a locktime of 1.
input1LockTime1 := &input.MockInput{}
input2LockTime1 := &input.MockInput{}
input1LockTime1.On("RequiredLockTime").Return(lockTime1, true)
input2LockTime1.On("RequiredLockTime").Return(lockTime1, true)
// Create cluster two, which has a locktime of 2.
input3LockTime2 := &input.MockInput{}
input4LockTime2 := &input.MockInput{}
input3LockTime2.On("RequiredLockTime").Return(lockTime2, true)
input4LockTime2.On("RequiredLockTime").Return(lockTime2, true)
// Create cluster three, which has no locktime.
input5NoLockTime := &input.MockInput{}
input6NoLockTime := &input.MockInput{}
input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false)
// With the inner Input being mocked, we can now create the pending
// inputs.
input1 := &pendingInput{Input: input1LockTime1, params: param}
input2 := &pendingInput{Input: input2LockTime1, params: param}
input3 := &pendingInput{Input: input3LockTime2, params: param}
input4 := &pendingInput{Input: input4LockTime2, params: param}
input5 := &pendingInput{Input: input5NoLockTime, params: param}
input6 := &pendingInput{Input: input6NoLockTime, params: param}
// Create the pending inputs map, which will be passed to the method
// under test.
//
// NOTE: we don't care the actual outpoint values as long as they are
// unique.
inputs := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
// Create expected clusters so we can shorten the line length in the
// test cases below.
cluster1 := pendingInputs{
wire.OutPoint{Index: 1}: input1,
wire.OutPoint{Index: 2}: input2,
}
cluster2 := pendingInputs{
wire.OutPoint{Index: 3}: input3,
wire.OutPoint{Index: 4}: input4,
}
// cluster3 should be the remaining inputs since they don't have
// locktime.
cluster3 := pendingInputs{
wire.OutPoint{Index: 5}: input5,
wire.OutPoint{Index: 6}: input6,
}
// Set the min fee rate to be 1000 sat/kw.
const minFeeRate = chainfee.SatPerKWeight(1000)
// Create a test sweeper.
s := New(&UtxoSweeperConfig{
MaxFeeRate: minFeeRate.FeePerVByte() * 10,
})
// Set the relay fee to be the minFeeRate. Any fee rate below the
// minFeeRate will cause an error to be returned.
s.relayFeeRate = minFeeRate
// applyFeeRate takes a testing fee rate and makes a mocker over
// DetermineFeePerKw that always return the testing fee rate. This
// mocked method is then attached to the sweeper.
applyFeeRate := func(feeRate chainfee.SatPerKWeight) {
mockFeeFunc := func(_ chainfee.Estimator, _ FeePreference) (
chainfee.SatPerKWeight, error) {
return feeRate, nil
}
s.cfg.DetermineFeePerKw = mockFeeFunc
}
testCases := []struct {
name string
testFeeRate chainfee.SatPerKWeight
expectedClusters []inputCluster
expectedRemainingInputs pendingInputs
}{
{
// Test a successful case where the locktime clusters
// are created and the no-locktime cluster is returned
// as the remaining inputs.
name: "successfully create clusters",
// Use a fee rate above the min value so we don't hit
// an error when performing fee estimation.
//
// TODO(yy): we should customize the returned fee rate
// for each input to further test the averaging logic.
// Or we can split the method into two, one for
// grouping the clusters and the other for averaging
// the fee rates so it's easier to be tested.
testFeeRate: minFeeRate + 1,
expectedClusters: []inputCluster{
{
lockTime: &lockTime1,
sweepFeeRate: minFeeRate + 1,
inputs: cluster1,
},
{
lockTime: &lockTime2,
sweepFeeRate: minFeeRate + 1,
inputs: cluster2,
},
},
expectedRemainingInputs: cluster3,
},
{
// Test that when the input is skipped when the fee
// estimation returns an error.
name: "error from fee estimation",
// Use a fee rate below the min value so we hit an
// error when performing fee estimation.
testFeeRate: minFeeRate - 1,
expectedClusters: []inputCluster{},
// Remaining inputs should stay untouched.
expectedRemainingInputs: cluster3,
},
}
//nolint:paralleltest
for _, tc := range testCases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
// Apply the test fee rate so `feeRateForPreference` is
// mocked to return the specified value.
applyFeeRate(tc.testFeeRate)
// Call the method under test.
clusters, remainingInputs := s.clusterByLockTime(inputs)
// Sort by locktime as the order is not guaranteed.
sort.Slice(clusters, func(i, j int) bool {
return *clusters[i].lockTime <
*clusters[j].lockTime
})
// Validate the values are returned as expected.
require.Equal(t, tc.expectedClusters, clusters)
require.Equal(t, tc.expectedRemainingInputs,
remainingInputs,
)
// Assert the mocked methods are called as expeceted.
input1LockTime1.AssertExpectations(t)
input2LockTime1.AssertExpectations(t)
input3LockTime2.AssertExpectations(t)
input4LockTime2.AssertExpectations(t)
input5NoLockTime.AssertExpectations(t)
input6NoLockTime.AssertExpectations(t)
})
}
}