diff --git a/breacharbiter.go b/breacharbiter.go index ab1a8bd54..420ca665c 100644 --- a/breacharbiter.go +++ b/breacharbiter.go @@ -866,6 +866,19 @@ func (bo *breachedOutput) OutPoint() *wire.OutPoint { return &bo.outpoint } +// RequiredTxOut returns a non-nil TxOut if input commits to a certain +// transaction output. This is used in the SINGLE|ANYONECANPAY case to make +// sure any presigned input is still valid by including the output. +func (bo *breachedOutput) RequiredTxOut() *wire.TxOut { + return nil +} + +// RequiredLockTime returns whether this input commits to a tx locktime that +// must be used in the transaction including it. +func (bo *breachedOutput) RequiredLockTime() (uint32, bool) { + return 0, false +} + // WitnessType returns the type of witness that must be generated to spend the // breached output. func (bo *breachedOutput) WitnessType() input.WitnessType { diff --git a/input/input.go b/input/input.go index 2e3a71c0b..7c0d79f66 100644 --- a/input/input.go +++ b/input/input.go @@ -15,6 +15,16 @@ type Input interface { // construct the corresponding transaction input. OutPoint() *wire.OutPoint + // RequiredTxOut returns a non-nil TxOut if input commits to a certain + // transaction output. This is used in the SINGLE|ANYONECANPAY case to + // make sure any presigned input is still valid by including the + // output. + RequiredTxOut() *wire.TxOut + + // RequiredLockTime returns whether this input commits to a tx locktime + // that must be used in the transaction including it. + RequiredLockTime() (uint32, bool) + // WitnessType returns an enum specifying the type of witness that must // be generated in order to spend this output. WitnessType() WitnessType @@ -75,6 +85,18 @@ func (i *inputKit) OutPoint() *wire.OutPoint { return &i.outpoint } +// RequiredTxOut returns a nil for the base input type. +func (i *inputKit) RequiredTxOut() *wire.TxOut { + return nil +} + +// RequiredLockTime returns whether this input commits to a tx locktime that +// must be used in the transaction including it. This will be false for the +// base input type since we can re-sign for any lock time. +func (i *inputKit) RequiredLockTime() (uint32, bool) { + return 0, false +} + // WitnessType returns the type of witness that must be generated to spend the // breached output. func (i *inputKit) WitnessType() WitnessType { diff --git a/input/size.go b/input/size.go index 6cebc2824..c19647162 100644 --- a/input/size.go +++ b/input/size.go @@ -533,6 +533,14 @@ func (twe *TxWeightEstimator) AddNestedP2WSHInput(witnessSize int) *TxWeightEsti return twe } +// AddTxOutput adds a known TxOut to the weight estimator. +func (twe *TxWeightEstimator) AddTxOutput(txOut *wire.TxOut) *TxWeightEstimator { + twe.outputSize += txOut.SerializeSize() + twe.outputCount++ + + return twe +} + // AddP2PKHOutput updates the weight estimate to account for an additional P2PKH // output. func (twe *TxWeightEstimator) AddP2PKHOutput() *TxWeightEstimator { diff --git a/rpcserver.go b/rpcserver.go index 580c15006..dc9188f32 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -1254,7 +1254,8 @@ func (r *rpcServer) SendCoins(ctx context.Context, // single transaction. This will be generated in a concurrent // safe manner, so no need to worry about locking. sweepTxPkg, err := sweep.CraftSweepAllTx( - feePerKw, uint32(bestHeight), targetAddr, wallet, + feePerKw, lnwallet.DefaultDustLimit(), + uint32(bestHeight), targetAddr, wallet, wallet.WalletController, wallet.WalletController, r.server.cc.FeeEstimator, r.server.cc.Signer, ) diff --git a/sweep/backend_mock_test.go b/sweep/backend_mock_test.go index dea018e34..a466b2726 100644 --- a/sweep/backend_mock_test.go +++ b/sweep/backend_mock_test.go @@ -27,6 +27,7 @@ type mockBackend struct { publishChan chan wire.MsgTx walletUtxos []*lnwallet.Utxo + utxoCnt int } func newMockBackend(t *testing.T, notifier *MockNotifier) *mockBackend { @@ -88,6 +89,16 @@ func (b *mockBackend) PublishTransaction(tx *wire.MsgTx, _ string) error { func (b *mockBackend) ListUnspentWitness(minconfirms, maxconfirms int32) ( []*lnwallet.Utxo, error) { + b.lock.Lock() + defer b.lock.Unlock() + + // Each time we list output, we increment the utxo counter, to + // ensure we don't return the same outpoint every time. + b.utxoCnt++ + + for i := range b.walletUtxos { + b.walletUtxos[i].OutPoint.Hash[0] = byte(b.utxoCnt) + } return b.walletUtxos, nil } diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 501ed0784..a1b3cf163 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -144,6 +144,7 @@ type pendingInputs = map[wire.OutPoint]*pendingInput // inputCluster is a helper struct to gather a set of pending inputs that should // be swept with the specified fee rate. type inputCluster struct { + lockTime *uint32 sweepFeeRate chainfee.SatPerKWeight inputs pendingInputs } @@ -647,7 +648,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // this to ensure any inputs which have had their fee // rate bumped are broadcast first in order enforce the // RBF policy. - inputClusters := s.clusterBySweepFeeRate() + inputClusters := s.createInputClusters() sort.Slice(inputClusters, func(i, j int) bool { return inputClusters[i].sweepFeeRate > inputClusters[j].sweepFeeRate @@ -750,17 +751,100 @@ func (s *UtxoSweeper) bucketForFeeRate( return 1 + int(feeRate-s.relayFeeRate)/s.cfg.FeeRateBucketSize } +// createInputClusters creates a list of input clusters from the set of pending +// inputs known by the UtxoSweeper. It clusters inputs by +// 1) Required tx locktime +// 2) Similar fee rates +func (s *UtxoSweeper) createInputClusters() []inputCluster { + inputs := s.pendingInputs + + // We start by getting the inputs clusters by locktime. Since the + // inputs commit to the locktime, they can only be clustered together + // if the locktime is equal. + lockTimeClusters, nonLockTimeInputs := s.clusterByLockTime(inputs) + + // Cluster the the remaining inputs by sweep fee rate. + feeClusters := s.clusterBySweepFeeRate(nonLockTimeInputs) + + // Since the inputs that we clustered by fee rate don't commit to a + // specific locktime, we can try to merge a locktime cluster with a fee + // cluster. + return zipClusters(lockTimeClusters, feeClusters) +} + +// clusterByLockTime takes the given set of pending inputs and clusters those +// with equal locktime together. Each cluster contains a sweep fee rate, which +// is determined by calculating the average fee rate of all inputs within that +// cluster. In addition to the created clusters, inputs that did not specify a +// required lock time are returned. +func (s *UtxoSweeper) clusterByLockTime(inputs pendingInputs) ([]inputCluster, + pendingInputs) { + + locktimes := make(map[uint32]pendingInputs) + inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) + rem := make(pendingInputs) + + // Go through all inputs and check if they require a certain locktime. + for op, input := range inputs { + lt, ok := input.RequiredLockTime() + if !ok { + rem[op] = input + continue + } + + // Check if we already have inputs with this locktime. + p, ok := locktimes[lt] + if !ok { + p = make(pendingInputs) + } + + p[op] = input + locktimes[lt] = p + + // We also get the preferred fee rate for this input. + feeRate, err := s.feeRateForPreference(input.params.Fee) + if err != nil { + log.Warnf("Skipping input %v: %v", op, err) + continue + } + + input.lastFeeRate = feeRate + inputFeeRates[op] = feeRate + } + + // We'll then determine the sweep fee rate for each set of inputs by + // calculating the average fee rate of the inputs within each set. + inputClusters := make([]inputCluster, 0, len(locktimes)) + for lt, inputs := range locktimes { + lt := lt + + var sweepFeeRate chainfee.SatPerKWeight + for op := range inputs { + sweepFeeRate += inputFeeRates[op] + } + + sweepFeeRate /= chainfee.SatPerKWeight(len(inputs)) + inputClusters = append(inputClusters, inputCluster{ + lockTime: <, + sweepFeeRate: sweepFeeRate, + inputs: inputs, + }) + } + + return inputClusters, rem +} + // clusterBySweepFeeRate takes the set of pending inputs within the UtxoSweeper // and clusters those together with similar fee rates. Each cluster contains a // sweep fee rate, which is determined by calculating the average fee rate of // all inputs within that cluster. -func (s *UtxoSweeper) clusterBySweepFeeRate() []inputCluster { +func (s *UtxoSweeper) clusterBySweepFeeRate(inputs pendingInputs) []inputCluster { bucketInputs := make(map[int]*bucketList) inputFeeRates := make(map[wire.OutPoint]chainfee.SatPerKWeight) // First, we'll group together all inputs with similar fee rates. This // is done by determining the fee rate bucket they should belong in. - for op, input := range s.pendingInputs { + for op, input := range inputs { feeRate, err := s.feeRateForPreference(input.params.Fee) if err != nil { log.Warnf("Skipping input %v: %v", op, err) @@ -824,6 +908,99 @@ func (s *UtxoSweeper) clusterBySweepFeeRate() []inputCluster { return inputClusters } +// zipClusters merges pairwise clusters from as and bs such that cluster a from +// as is merged with a cluster from bs that has at least the fee rate of a. +// This to ensure we don't delay confirmation by decreasing the fee rate (the +// lock time inputs are typically second level HTLC transactions, that are time +// sensitive). +func zipClusters(as, bs []inputCluster) []inputCluster { + // Sort the clusters by decreasing fee rates. + sort.Slice(as, func(i, j int) bool { + return as[i].sweepFeeRate > + as[j].sweepFeeRate + }) + sort.Slice(bs, func(i, j int) bool { + return bs[i].sweepFeeRate > + bs[j].sweepFeeRate + }) + + var ( + finalClusters []inputCluster + j int + ) + + // Go through each cluster in as, and merge with the next one from bs + // if it has at least the fee rate needed. + for i := range as { + a := as[i] + + switch { + + // If the fee rate for the next one from bs is at least a's, we + // merge. + case j < len(bs) && bs[j].sweepFeeRate >= a.sweepFeeRate: + merged := mergeClusters(a, bs[j]) + finalClusters = append(finalClusters, merged...) + + // Increment j for the next round. + j++ + + // We did not merge, meaning all the remining clusters from bs + // have lower fee rate. Instead we add a directly to the final + // clusters. + default: + finalClusters = append(finalClusters, a) + } + } + + // Add any remaining clusters from bs. + for ; j < len(bs); j++ { + b := bs[j] + finalClusters = append(finalClusters, b) + } + + return finalClusters +} + +// mergeClusters attempts to merge cluster a and b if they are compatible. The +// new cluster will have the locktime set if a or b had a locktime set, and a +// sweep fee rate that is the maximum of a and b's. If the two clusters are not +// compatible, they will be returned unchanged. +func mergeClusters(a, b inputCluster) []inputCluster { + newCluster := inputCluster{} + + switch { + + // Incompatible locktimes, return the sets without merging them. + case a.lockTime != nil && b.lockTime != nil && *a.lockTime != *b.lockTime: + return []inputCluster{a, b} + + case a.lockTime != nil: + newCluster.lockTime = a.lockTime + + case b.lockTime != nil: + newCluster.lockTime = b.lockTime + } + + if a.sweepFeeRate > b.sweepFeeRate { + newCluster.sweepFeeRate = a.sweepFeeRate + } else { + newCluster.sweepFeeRate = b.sweepFeeRate + } + + newCluster.inputs = make(pendingInputs) + + for op, in := range a.inputs { + newCluster.inputs[op] = in + } + + for op, in := range b.inputs { + newCluster.inputs[op] = in + } + + return []inputCluster{newCluster} +} + // scheduleSweep starts the sweep timer to create an opportunity for more inputs // to be added. func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error { @@ -836,7 +1013,7 @@ func (s *UtxoSweeper) scheduleSweep(currentHeight int32) error { // We'll only start our timer once we have inputs we're able to sweep. startTimer := false - for _, cluster := range s.clusterBySweepFeeRate() { + for _, cluster := range s.createInputClusters() { // Examine pending inputs and try to construct lists of inputs. // We don't need to obtain the coin selection lock, because we // just need an indication as to whether we can sweep. More @@ -988,7 +1165,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, // Create sweep tx. tx, err := createSweepTx( inputs, s.currentOutputScript, uint32(currentHeight), feeRate, - s.cfg.Signer, + dustLimit(s.relayFeeRate), s.cfg.Signer, ) if err != nil { return fmt.Errorf("create sweep tx: %v", err) @@ -1278,7 +1455,8 @@ func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference, } return createSweepTx( - inputs, pkScript, currentBlockHeight, feePerKw, s.cfg.Signer, + inputs, pkScript, currentBlockHeight, feePerKw, + dustLimit(s.relayFeeRate), s.cfg.Signer, ) } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 2f71c35bd..deea56f4c 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -2,6 +2,7 @@ package sweep import ( "os" + "reflect" "runtime/debug" "runtime/pprof" "testing" @@ -11,6 +12,7 @@ import ( "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" + "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/build" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/keychain" @@ -62,7 +64,7 @@ var ( func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput { hash := chainhash.Hash{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - byte(testInputCount)} + byte(testInputCount + 1)} input := input.MakeBaseInput( &wire.OutPoint{ @@ -88,7 +90,7 @@ func createTestInput(value int64, witnessType input.WitnessType) input.BaseInput func init() { // Create a set of test spendable inputs. - for i := 0; i < 5; i++ { + for i := 0; i < 20; i++ { input := createTestInput(int64(10000+i*500), input.CommitmentTimeLock) @@ -104,7 +106,7 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { backend := newMockBackend(t, notifier) backend.walletUtxos = []*lnwallet.Utxo{ { - Value: btcutil.Amount(10000), + Value: btcutil.Amount(1_000_000), AddressType: lnwallet.WitnessPubKey, }, } @@ -491,8 +493,9 @@ func TestWalletUtxo(t *testing.T) { "inputs instead", len(sweepTx.TxIn)) } - // Calculate expected output value based on wallet utxo of 10000 sats. - expectedOutputValue := int64(294 + 10000 - 180) + // Calculate expected output value based on wallet utxo of 1_000_000 + // sats. + expectedOutputValue := int64(294 + 1_000_000 - 180) if sweepTx.TxOut[0].Value != expectedOutputValue { t.Fatalf("Expected output value of %v, but got %v", expectedOutputValue, sweepTx.TxOut[0].Value) @@ -1367,8 +1370,8 @@ func TestCpfp(t *testing.T) { // package, making a total of 1059. At 5000 sat/kw, the required fee for // the package is 5295 sats. The parent already paid 900 sats, so there // is 4395 sat remaining to be paid. The expected output value is - // therefore: 10000 + 330 - 4395 = 5935. - require.Equal(t, int64(5935), tx.TxOut[0].Value) + // therefore: 1_000_000 + 330 - 4395 = 995 935. + require.Equal(t, int64(995_935), tx.TxOut[0].Value) // Mine the tx and assert that the result is passed back. ctx.backend.mine() @@ -1376,3 +1379,703 @@ func TestCpfp(t *testing.T) { ctx.finish(1) } + +var ( + testInputsA = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + } + + testInputsB = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } + + testInputsC = pendingInputs{ + wire.OutPoint{Hash: chainhash.Hash{}, Index: 0}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 1}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 2}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 10}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 11}: &pendingInput{}, + wire.OutPoint{Hash: chainhash.Hash{}, Index: 12}: &pendingInput{}, + } +) + +// TestMergeClusters check that we properly can merge clusters together, +// according to their required locktime. +func TestMergeClusters(t *testing.T) { + t.Parallel() + + lockTime1 := uint32(100) + lockTime2 := uint32(200) + + testCases := []struct { + name string + a inputCluster + b inputCluster + res []inputCluster + }{ + { + name: "max fee rate", + a: inputCluster{ + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "same locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 7000, + inputs: testInputsC, + }, + }, + }, + { + name: "diff locktime", + a: inputCluster{ + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + b: inputCluster{ + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + res: []inputCluster{ + { + lockTime: &lockTime1, + sweepFeeRate: 5000, + inputs: testInputsA, + }, + { + lockTime: &lockTime2, + sweepFeeRate: 7000, + inputs: testInputsB, + }, + }, + }, + } + + for _, test := range testCases { + merged := mergeClusters(test.a, test.b) + if !reflect.DeepEqual(merged, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(merged)) + } + } +} + +// TestZipClusters tests that we can merge lists of inputs clusters correctly. +func TestZipClusters(t *testing.T) { + t.Parallel() + + createCluster := func(inp pendingInputs, f chainfee.SatPerKWeight) inputCluster { + return inputCluster{ + sweepFeeRate: f, + inputs: inp, + } + } + + testCases := []struct { + name string + as []inputCluster + bs []inputCluster + res []inputCluster + }{ + { + name: "merge A into B", + as: []inputCluster{ + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + }, + }, + { + name: "A can't merge with B", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsA, 7000), + createCluster(testInputsB, 5000), + }, + }, + { + name: "empty bs", + as: []inputCluster{ + createCluster(testInputsA, 7000), + }, + bs: []inputCluster{}, + res: []inputCluster{ + createCluster(testInputsA, 7000), + }, + }, + { + name: "empty as", + as: []inputCluster{}, + bs: []inputCluster{ + createCluster(testInputsB, 5000), + }, + res: []inputCluster{ + createCluster(testInputsB, 5000), + }, + }, + + { + name: "zip 3xA into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + createCluster(testInputsA, 5000), + }, + bs: []inputCluster{ + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + createCluster(testInputsB, 7000), + }, + res: []inputCluster{ + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + createCluster(testInputsC, 7000), + }, + }, + { + name: "zip A into 3xB", + as: []inputCluster{ + createCluster(testInputsA, 2500), + }, + bs: []inputCluster{ + createCluster(testInputsB, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + res: []inputCluster{ + createCluster(testInputsC, 3000), + createCluster(testInputsB, 2000), + createCluster(testInputsB, 1000), + }, + }, + } + + for _, test := range testCases { + zipped := zipClusters(test.as, test.bs) + if !reflect.DeepEqual(zipped, test.res) { + t.Fatalf("[%s] unexpected result: %v", + test.name, spew.Sdump(zipped)) + } + } +} + +type testInput struct { + *input.BaseInput + + locktime *uint32 + reqTxOut *wire.TxOut +} + +func (i *testInput) RequiredLockTime() (uint32, bool) { + if i.locktime != nil { + return *i.locktime, true + } + + return 0, false +} + +func (i *testInput) RequiredTxOut() *wire.TxOut { + return i.reqTxOut +} + +// TestLockTimes checks that the sweeper properly groups inputs requiring the +// same locktime together into sweep transactions. +func TestLockTimes(t *testing.T) { + ctx := createSweeperTestContext(t) + + // We increase the number of max inputs to a tx so that won't + // impact our test. + ctx.sweeper.cfg.MaxInputsPerTx = 100 + + // We will set up the lock times in such a way that we expect the + // sweeper to divide the inputs into 4 diffeerent transactions. + const numSweeps = 4 + + // Sweep 8 inputs, using 4 different lock times. + var ( + results []chan Result + inputs = make(map[wire.OutPoint]input.Input) + ) + for i := 0; i < numSweeps*2; i++ { + lt := uint32(10 + (i % numSweeps)) + inp := &testInput{ + BaseInput: spendableInputs[i], + locktime: <, + } + + result, err := ctx.sweeper.SweepInput( + inp, Params{ + Fee: FeePreference{ConfTarget: 6}, + }, + ) + if err != nil { + t.Fatal(err) + } + results = append(results, result) + + op := inp.OutPoint() + inputs[*op] = inp + } + + // We also add 3 regular inputs that don't require any specific lock + // time. + for i := 0; i < 3; i++ { + inp := spendableInputs[i+numSweeps*2] + result, err := ctx.sweeper.SweepInput( + inp, Params{ + Fee: FeePreference{ConfTarget: 6}, + }, + ) + if err != nil { + t.Fatal(err) + } + + results = append(results, result) + + op := inp.OutPoint() + inputs[*op] = inp + } + + // We expect all inputs to be published in separate transactions, even + // though they share the same fee preference. + ctx.tick() + + // Check the sweeps transactions, ensuring all inputs are there, and + // all the locktimes are satisfied. + for i := 0; i < numSweeps; i++ { + sweepTx := ctx.receiveTx() + if len(sweepTx.TxOut) != 1 { + t.Fatal("expected a single tx out in the sweep tx") + } + + for _, txIn := range sweepTx.TxIn { + op := txIn.PreviousOutPoint + inp, ok := inputs[op] + if !ok { + t.Fatalf("Unexpected outpoint: %v", op) + } + + delete(inputs, op) + + // If this input had a required locktime, ensure the tx + // has that set correctly. + lt, ok := inp.RequiredLockTime() + if !ok { + continue + } + + if lt != sweepTx.LockTime { + t.Fatalf("Input required locktime %v, sweep "+ + "tx had locktime %v", lt, sweepTx.LockTime) + } + + } + } + + // The should be no inputs not foud in any of the sweeps. + if len(inputs) != 0 { + t.Fatalf("had unsweeped inputs") + } + + // Mine the first sweeps + ctx.backend.mine() + + // Results should all come back. + for i := range results { + result := <-results[i] + if result.Err != nil { + t.Fatal("expected input to be swept") + } + } +} + +// TestRequiredTxOuts checks that inputs having a required TxOut gets swept with +// sweep transactions paying into these outputs. +func TestRequiredTxOuts(t *testing.T) { + // Create some test inputs and locktime vars. + var inputs []*input.BaseInput + for i := 0; i < 20; i++ { + input := createTestInput( + int64(btcutil.SatoshiPerBitcoin+i*500), + input.CommitmentTimeLock, + ) + + inputs = append(inputs, &input) + } + + locktime1 := uint32(51) + locktime2 := uint32(52) + locktime3 := uint32(53) + + testCases := []struct { + name string + inputs []*testInput + assertSweeps func(*testing.T, map[wire.OutPoint]*testInput, + []*wire.MsgTx) + }{ + { + // Single input with a required TX out that is smaller. + // We expect a change output to be added. + name: "single input, leftover change", + inputs: []*testInput{ + { + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: 100000, + }, + }, + }, + + // Since the required output value is small, we expect + // the rest after fees to go into a change output. + assertSweeps: func(t *testing.T, + _ map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 1, len(tx.TxIn)) + + // We should have two outputs, the required + // output must be the first one. + require.Equal(t, 2, len(tx.TxOut)) + out := tx.TxOut[0] + require.Equal(t, []byte("aaa"), out.PkScript) + require.Equal(t, int64(100000), out.Value) + }, + }, + { + // An input committing to a slightly smaller output, so + // it will pay its own fees. + name: "single input, no change", + inputs: []*testInput{ + { + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + + // Fee will be about 5340 sats. + // Subtract a bit more to + // ensure no dust change output + // is manifested. + Value: inputs[0].SignDesc().Output.Value - 5600, + }, + }, + }, + + // We expect this single input/output pair. + assertSweeps: func(t *testing.T, + _ map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 1, len(tx.TxIn)) + + require.Equal(t, 1, len(tx.TxOut)) + out := tx.TxOut[0] + require.Equal(t, []byte("aaa"), out.PkScript) + require.Equal( + t, + inputs[0].SignDesc().Output.Value-5600, + out.Value, + ) + }, + }, + { + // An input committing to an output of equal value, just + // add input to pay fees. + name: "single input, extra fee input", + inputs: []*testInput{ + { + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: inputs[0].SignDesc().Output.Value, + }, + }, + }, + + // We expect an extra input and output. + assertSweeps: func(t *testing.T, + _ map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 2, len(tx.TxIn)) + + require.Equal(t, 2, len(tx.TxOut)) + out := tx.TxOut[0] + require.Equal(t, []byte("aaa"), out.PkScript) + require.Equal( + t, inputs[0].SignDesc().Output.Value, + out.Value, + ) + }, + }, + { + // Three inputs added, should be combined into a single + // sweep. + name: "three inputs", + inputs: []*testInput{ + { + BaseInput: inputs[0], + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: inputs[0].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[1], + reqTxOut: &wire.TxOut{ + PkScript: []byte("bbb"), + Value: inputs[1].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[2], + reqTxOut: &wire.TxOut{ + PkScript: []byte("ccc"), + Value: inputs[2].SignDesc().Output.Value, + }, + }, + }, + + // We expect an extra input and output to pay fees. + assertSweeps: func(t *testing.T, + testInputs map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 1, len(txs)) + + tx := txs[0] + require.Equal(t, 4, len(tx.TxIn)) + require.Equal(t, 4, len(tx.TxOut)) + + // The inputs and outputs must be in the same + // order. + for i, in := range tx.TxIn { + // Last one is the change input/output + // pair, so we'll skip it. + if i == 3 { + continue + } + + // Get this input to ensure the output + // on index i coresponsd to this one. + inp := testInputs[in.PreviousOutPoint] + require.NotNil(t, inp) + + require.Equal( + t, tx.TxOut[i].Value, + inp.SignDesc().Output.Value, + ) + } + }, + }, + { + // Six inputs added, which 3 different locktimes. + // Should result in 3 sweeps. + name: "six inputs", + inputs: []*testInput{ + { + BaseInput: inputs[0], + locktime: &locktime1, + reqTxOut: &wire.TxOut{ + PkScript: []byte("aaa"), + Value: inputs[0].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[1], + locktime: &locktime1, + reqTxOut: &wire.TxOut{ + PkScript: []byte("bbb"), + Value: inputs[1].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[2], + locktime: &locktime2, + reqTxOut: &wire.TxOut{ + PkScript: []byte("ccc"), + Value: inputs[2].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[3], + locktime: &locktime2, + reqTxOut: &wire.TxOut{ + PkScript: []byte("ddd"), + Value: inputs[3].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[4], + locktime: &locktime3, + reqTxOut: &wire.TxOut{ + PkScript: []byte("eee"), + Value: inputs[4].SignDesc().Output.Value, + }, + }, + { + BaseInput: inputs[5], + locktime: &locktime3, + reqTxOut: &wire.TxOut{ + PkScript: []byte("fff"), + Value: inputs[5].SignDesc().Output.Value, + }, + }, + }, + + // We expect three sweeps, each having two of our + // inputs, one extra input and output to pay fees. + assertSweeps: func(t *testing.T, + testInputs map[wire.OutPoint]*testInput, + txs []*wire.MsgTx) { + + require.Equal(t, 3, len(txs)) + + for _, tx := range txs { + require.Equal(t, 3, len(tx.TxIn)) + require.Equal(t, 3, len(tx.TxOut)) + + // The inputs and outputs must be in + // the same order. + for i, in := range tx.TxIn { + // Last one is the change + // output, so we'll skip it. + if i == 2 { + continue + } + + // Get this input to ensure the + // output on index i coresponsd + // to this one. + inp := testInputs[in.PreviousOutPoint] + require.NotNil(t, inp) + + require.Equal( + t, tx.TxOut[i].Value, + inp.SignDesc().Output.Value, + ) + + // Check that the locktimes are + // kept intact. + require.Equal( + t, tx.LockTime, + *inp.locktime, + ) + } + } + }, + }, + } + + for _, testCase := range testCases { + testCase := testCase + + t.Run(testCase.name, func(t *testing.T) { + ctx := createSweeperTestContext(t) + + // We increase the number of max inputs to a tx so that + // won't impact our test. + ctx.sweeper.cfg.MaxInputsPerTx = 100 + + // Sweep all test inputs. + var ( + inputs = make(map[wire.OutPoint]*testInput) + results = make(map[wire.OutPoint]chan Result) + ) + for _, inp := range testCase.inputs { + result, err := ctx.sweeper.SweepInput( + inp, Params{ + Fee: FeePreference{ConfTarget: 6}, + }, + ) + if err != nil { + t.Fatal(err) + } + + op := inp.OutPoint() + results[*op] = result + inputs[*op] = inp + } + + // Tick, which should trigger a sweep of all inputs. + ctx.tick() + + // Check the sweeps transactions, ensuring all inputs + // are there, and all the locktimes are satisfied. + var sweeps []*wire.MsgTx + Loop: + for { + select { + case tx := <-ctx.publishChan: + sweeps = append(sweeps, &tx) + case <-time.After(200 * time.Millisecond): + break Loop + } + } + + // Mine the sweeps. + ctx.backend.mine() + + // Results should all come back. + for _, resultChan := range results { + result := <-resultChan + if result.Err != nil { + t.Fatalf("expected input to be "+ + "swept: %v", result.Err) + } + } + + // Assert the transactions are what we expect. + testCase.assertSweeps(t, inputs, sweeps) + }) + } +} diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index 708369006..b05b5ab07 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -31,15 +31,22 @@ const ( ) type txInputSetState struct { - // weightEstimate is the (worst case) tx weight with the current set of - // inputs. - weightEstimate *weightEstimator + // feeRate is the fee rate to use for the sweep transaction. + feeRate chainfee.SatPerKWeight // inputTotal is the total value of all inputs. inputTotal btcutil.Amount - // outputValue is the value of the tx output. - outputValue btcutil.Amount + // requiredOutput is the sum of the outputs committed to by the inputs. + requiredOutput btcutil.Amount + + // changeOutput is the value of the change output. This will be what is + // left over after subtracting the requiredOutput and the tx fee from + // the inputTotal. + // + // NOTE: This might be below the dust limit, or even negative since it + // is the change remaining in csse we pay the fee for a change output. + changeOutput btcutil.Amount // inputs is the set of tx inputs. inputs []input.Input @@ -52,11 +59,42 @@ type txInputSetState struct { force bool } +// weightEstimate is the (worst case) tx weight with the current set of +// inputs. It takes a parameter whether to add a change output or not. +func (t *txInputSetState) weightEstimate(change bool) *weightEstimator { + weightEstimate := newWeightEstimator(t.feeRate) + for _, i := range t.inputs { + // Can ignore error, because it has already been checked when + // calculating the yields. + _ = weightEstimate.add(i) + + r := i.RequiredTxOut() + if r != nil { + weightEstimate.addOutput(r) + } + } + + // Add a change output to the weight estimate if requested. + if change { + weightEstimate.addP2WKHOutput() + } + + return weightEstimate +} + +// totalOutput is the total amount left for us after paying fees. +// +// NOTE: This might be dust. +func (t *txInputSetState) totalOutput() btcutil.Amount { + return t.requiredOutput + t.changeOutput +} + func (t *txInputSetState) clone() txInputSetState { s := txInputSetState{ - weightEstimate: t.weightEstimate.clone(), + feeRate: t.feeRate, inputTotal: t.inputTotal, - outputValue: t.outputValue, + changeOutput: t.changeOutput, + requiredOutput: t.requiredOutput, walletInputTotal: t.walletInputTotal, force: t.force, inputs: make([]input.Input, len(t.inputs)), @@ -83,17 +121,21 @@ type txInputSet struct { wallet Wallet } +func dustLimit(relayFee chainfee.SatPerKWeight) btcutil.Amount { + return txrules.GetDustThreshold( + input.P2WPKHSize, + btcutil.Amount(relayFee.FeePerKVByte()), + ) +} + // newTxInputSet constructs a new, empty input set. func newTxInputSet(wallet Wallet, feePerKW, relayFee chainfee.SatPerKWeight, maxInputs int) *txInputSet { - dustLimit := txrules.GetDustThreshold( - input.P2WPKHSize, - btcutil.Amount(relayFee.FeePerKVByte()), - ) + dustLimit := dustLimit(relayFee) state := txInputSetState{ - weightEstimate: newWeightEstimator(feePerKW), + feeRate: feePerKW, } b := txInputSet{ @@ -103,16 +145,36 @@ func newTxInputSet(wallet Wallet, feePerKW, txInputSetState: state, } - // Add the sweep tx output to the weight estimate. - b.weightEstimate.addP2WKHOutput() - return &b } -// dustLimitReached returns true if we've accumulated enough inputs to meet the -// dust limit. -func (t *txInputSet) dustLimitReached() bool { - return t.outputValue >= t.dustLimit +// enoughInput returns true if we've accumulated enough inputs to pay the fees +// and have at least one output that meets the dust limit. +func (t *txInputSet) enoughInput() bool { + // If we have a change output above dust, then we certainly have enough + // inputs to the transaction. + if t.changeOutput >= t.dustLimit { + return true + } + + // We did not have enough input for a change output. Check if we have + // enough input to pay the fees for a transaction with no change + // output. + fee := t.weightEstimate(false).fee() + if t.inputTotal < t.requiredOutput+fee { + return false + } + + // We could pay the fees, but we still need at least one output to be + // above the dust limit for the tx to be valid (we assume that these + // required outputs only get added if they are above dust) + for _, inp := range t.inputs { + if inp.RequiredTxOut() != nil { + return true + } + } + + return false } // add adds a new input to the set. It returns a bool indicating whether the @@ -127,28 +189,35 @@ func (t *txInputSet) addToState(inp input.Input, constraints addConstraints) *tx return nil } + // If the input comes with a required tx out that is below dust, we + // won't add it. + reqOut := inp.RequiredTxOut() + if reqOut != nil && btcutil.Amount(reqOut.Value) < t.dustLimit { + return nil + } + // Clone the current set state. s := t.clone() // Add the new input. s.inputs = append(s.inputs, inp) - // Can ignore error, because it has already been checked when - // calculating the yields. - _ = s.weightEstimate.add(inp) - // Add the value of the new input. value := btcutil.Amount(inp.SignDesc().Output.Value) s.inputTotal += value // Recalculate the tx fee. - fee := s.weightEstimate.fee() + fee := s.weightEstimate(true).fee() // Calculate the new output value. - s.outputValue = s.inputTotal - fee + if reqOut != nil { + s.requiredOutput += btcutil.Amount(reqOut.Value) + } + s.changeOutput = s.inputTotal - s.requiredOutput - fee - // Calculate the yield of this input from the change in tx output value. - inputYield := s.outputValue - t.outputValue + // Calculate the yield of this input from the change in total tx output + // value. + inputYield := s.totalOutput() - t.totalOutput() switch constraints { @@ -188,11 +257,11 @@ func (t *txInputSet) addToState(inp input.Input, constraints addConstraints) *tx // value of the wallet input and what we get out of this // transaction. To prevent attaching and locking a big utxo for // very little benefit. - if !s.force && s.walletInputTotal >= s.outputValue { + if !s.force && s.walletInputTotal >= s.totalOutput() { log.Debugf("Rejecting wallet input of %v, because it "+ "would make a negative yielding transaction "+ "(%v)", - value, s.outputValue-s.walletInputTotal) + value, s.totalOutput()-s.walletInputTotal) return nil } @@ -246,8 +315,9 @@ func (t *txInputSet) addPositiveYieldInputs(sweepableInputs []txInput) { // tryAddWalletInputsIfNeeded retrieves utxos from the wallet and tries adding as // many as required to bring the tx output value above the given minimum. func (t *txInputSet) tryAddWalletInputsIfNeeded() error { - // If we've already reached the dust limit, no action is needed. - if t.dustLimitReached() { + // If we've already have enough to pay the transaction fees and have at + // least one output materialize, no action is needed. + if t.enoughInput() { return nil } @@ -271,7 +341,7 @@ func (t *txInputSet) tryAddWalletInputsIfNeeded() error { } // Return if we've reached the minimum output amount. - if t.dustLimitReached() { + if t.enoughInput() { return nil } } diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index d9e98f733..2f72b3673 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -3,9 +3,11 @@ package sweep import ( "testing" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/stretchr/testify/require" ) // TestTxInputSet tests adding various sized inputs to the set. @@ -34,12 +36,15 @@ func TestTxInputSet(t *testing.T) { t.Fatal("expected add of positively yielding input to succeed") } + fee := set.weightEstimate(true).fee() + require.Equal(t, btcutil.Amount(439), fee) + // The tx output should now be 700-439 = 261 sats. The dust limit isn't // reached yet. - if set.outputValue != 261 { + if set.totalOutput() != 261 { t.Fatal("unexpected output value") } - if set.dustLimitReached() { + if set.enoughInput() { t.Fatal("expected dust limit not yet to be reached") } @@ -48,10 +53,10 @@ func TestTxInputSet(t *testing.T) { if !set.add(createP2WKHInput(1000), constraintsRegular) { t.Fatal("expected add of positively yielding input to succeed") } - if set.outputValue != 988 { + if set.totalOutput() != 988 { t.Fatal("unexpected output value") } - if !set.dustLimitReached() { + if !set.enoughInput() { t.Fatal("expected dust limit to be reached") } } @@ -73,7 +78,7 @@ func TestTxInputSetFromWallet(t *testing.T) { if !set.add(createP2WKHInput(700), constraintsRegular) { t.Fatal("expected add of positively yielding input to succeed") } - if set.dustLimitReached() { + if set.enoughInput() { t.Fatal("expected dust limit not yet to be reached") } @@ -92,7 +97,7 @@ func TestTxInputSetFromWallet(t *testing.T) { t.Fatal(err) } - if !set.dustLimitReached() { + if !set.enoughInput() { t.Fatal("expected dust limit to be reached") } } @@ -117,3 +122,129 @@ func (m *mockWallet) ListUnspentWitness(minconfirms, maxconfirms int32) ( }, }, nil } + +type reqInput struct { + input.Input + + txOut *wire.TxOut +} + +func (r *reqInput) RequiredTxOut() *wire.TxOut { + return r.txOut +} + +// TestTxInputSetRequiredOutput tests that the tx input set behaves as expected +// when we add inputs that have required tx outs. +func TestTxInputSetRequiredOutput(t *testing.T) { + const ( + feeRate = 1000 + relayFee = 300 + maxInputs = 10 + ) + set := newTxInputSet(nil, feeRate, relayFee, maxInputs) + if set.dustLimit != 537 { + t.Fatalf("incorrect dust limit") + } + + // Attempt to add an input with a required txout below the dust limit. + // This should fail since we cannot trim such outputs. + inp := &reqInput{ + Input: createP2WKHInput(500), + txOut: &wire.TxOut{ + Value: 500, + PkScript: make([]byte, 33), + }, + } + require.False(t, set.add(inp, constraintsRegular), + "expected adding dust required tx out to fail") + + // Create a 1000 sat input that also has a required TxOut of 1000 sat. + // The fee to sweep this input to a P2WKH output is 439 sats. + inp = &reqInput{ + Input: createP2WKHInput(1000), + txOut: &wire.TxOut{ + Value: 1000, + PkScript: make([]byte, 22), + }, + } + require.True(t, set.add(inp, constraintsRegular), "failed adding input") + + // The fee needed to pay for this input and output should be 439 sats. + fee := set.weightEstimate(false).fee() + require.Equal(t, btcutil.Amount(439), fee) + + // Since the tx set currently pays no fees, we expect the current + // change to actually be negative, since this is what it would cost us + // in fees to add a change output. + feeWithChange := set.weightEstimate(true).fee() + if set.changeOutput != -feeWithChange { + t.Fatalf("expected negative change of %v, had %v", + -feeWithChange, set.changeOutput) + } + + // This should also be reflected by not having enough input. + require.False(t, set.enoughInput()) + + // Get a weight estimate without change output, and add an additional + // input to it. + dummyInput := createP2WKHInput(1000) + weight := set.weightEstimate(false) + require.NoError(t, weight.add(dummyInput)) + + // Now we add a an input that is large enough to pay the fee for the + // transaction without a change output, but not large enough to afford + // adding a change output. + extraInput1 := weight.fee() + 100 + require.True(t, set.add(createP2WKHInput(extraInput1), constraintsRegular), + "expected add of positively yielding input to succeed") + + // The change should be negative, since we would have to add a change + // output, which we cannot yet afford. + if set.changeOutput >= 0 { + t.Fatal("expected change to be negaitve") + } + + // Even though we cannot afford a change output, the tx set is valid, + // since we can pay the fees without the change output. + require.True(t, set.enoughInput()) + + // Get another weight estimate, this time with a change output, and + // figure out how much we must add to afford a change output. + weight = set.weightEstimate(true) + require.NoError(t, weight.add(dummyInput)) + + // We add what is left to reach this value. + extraInput2 := weight.fee() - extraInput1 + 100 + + // Add this input, which should result in the change now being 100 sats. + require.True(t, set.add(createP2WKHInput(extraInput2), constraintsRegular)) + + // The change should be 100, since this is what is left after paying + // fees in case of a change output. + change := set.changeOutput + if change != 100 { + t.Fatalf("expected change be 100, was %v", change) + } + + // Even though the change output is dust, we have enough for fees, and + // we have an output, so it should be considered enough to craft a + // valid sweep transaction. + require.True(t, set.enoughInput()) + + // Finally we add an input that should push the change output above the + // dust limit. + weight = set.weightEstimate(true) + require.NoError(t, weight.add(dummyInput)) + + // We expect the change to everything that is left after paying the tx + // fee. + extraInput3 := weight.fee() - extraInput1 - extraInput2 + 1000 + require.True(t, set.add(createP2WKHInput(extraInput3), constraintsRegular)) + + change = set.changeOutput + if change != 1000 { + t.Fatalf("expected change to be %v, had %v", 1000, change) + + } + require.True(t, set.enoughInput()) +} diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index f023c3178..670a22551 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -110,17 +110,19 @@ func generateInputPartitionings(sweepableInputs []txInput, // the dust limit, stop sweeping. Because of the sorting, // continuing with the remaining inputs will only lead to sets // with an even lower output value. - if !txInputs.dustLimitReached() { - log.Debugf("Set value %v below dust limit of %v", - txInputs.outputValue, txInputs.dustLimit) + if !txInputs.enoughInput() { + log.Debugf("Set value %v (r=%v, c=%v) below dust "+ + "limit of %v", txInputs.totalOutput(), + txInputs.requiredOutput, txInputs.changeOutput, + txInputs.dustLimit) return sets, nil } log.Infof("Candidate sweep set of size=%v (+%v wallet inputs), "+ "has yield=%v, weight=%v", inputCount, len(txInputs.inputs)-inputCount, - txInputs.outputValue-txInputs.walletInputTotal, - txInputs.weightEstimate.weight()) + txInputs.totalOutput()-txInputs.walletInputTotal, + txInputs.weightEstimate(true).weight()) sets = append(sets, txInputs.inputs) sweepableInputs = sweepableInputs[inputCount:] @@ -132,39 +134,93 @@ func generateInputPartitionings(sweepableInputs []txInput, // createSweepTx builds a signed tx spending the inputs to a the output script. func createSweepTx(inputs []input.Input, outputPkScript []byte, currentBlockHeight uint32, feePerKw chainfee.SatPerKWeight, - signer input.Signer) (*wire.MsgTx, error) { + dustLimit btcutil.Amount, signer input.Signer) (*wire.MsgTx, error) { inputs, estimator := getWeightEstimate(inputs, feePerKw) txFee := estimator.fee() - // Sum up the total value contained in the inputs. - var totalSum btcutil.Amount + // Create the sweep transaction that we will be building. We use + // version 2 as it is required for CSV. + sweepTx := wire.NewMsgTx(2) + + // Track whether any of the inputs require a certain locktime. + locktime := int32(-1) + + // We start by adding all inputs that commit to an output. We do this + // since the input and output index must stay the same for the + // signatures to be valid. + var ( + totalInput btcutil.Amount + requiredOutput btcutil.Amount + ) for _, o := range inputs { - totalSum += btcutil.Amount(o.SignDesc().Output.Value) + if o.RequiredTxOut() == nil { + continue + } + + sweepTx.AddTxIn(&wire.TxIn{ + PreviousOutPoint: *o.OutPoint(), + Sequence: o.BlocksToMaturity(), + }) + sweepTx.AddTxOut(o.RequiredTxOut()) + + if lt, ok := o.RequiredLockTime(); ok { + // If another input commits to a different locktime, + // they cannot be combined in the same transcation. + if locktime != -1 && locktime != int32(lt) { + return nil, fmt.Errorf("incompatible locktime") + } + + locktime = int32(lt) + } + + totalInput += btcutil.Amount(o.SignDesc().Output.Value) + requiredOutput += btcutil.Amount(o.RequiredTxOut().Value) } - // Sweep as much possible, after subtracting txn fees. - sweepAmt := int64(totalSum - txFee) + // Sum up the value contained in the remaining inputs, and add them to + // the sweep transaction. + for _, o := range inputs { + if o.RequiredTxOut() != nil { + continue + } - // Create the sweep transaction that we will be building. We use - // version 2 as it is required for CSV. The txn will sweep the amount - // after fees to the pkscript generated above. - sweepTx := wire.NewMsgTx(2) - sweepTx.AddTxOut(&wire.TxOut{ - PkScript: outputPkScript, - Value: sweepAmt, - }) - - sweepTx.LockTime = currentBlockHeight - - // Add all inputs to the sweep transaction. Ensure that for each - // csvInput, we set the sequence number properly. - for _, input := range inputs { sweepTx.AddTxIn(&wire.TxIn{ - PreviousOutPoint: *input.OutPoint(), - Sequence: input.BlocksToMaturity(), + PreviousOutPoint: *o.OutPoint(), + Sequence: o.BlocksToMaturity(), }) + + if lt, ok := o.RequiredLockTime(); ok { + if locktime != -1 && locktime != int32(lt) { + return nil, fmt.Errorf("incompatible locktime") + } + + locktime = int32(lt) + } + + totalInput += btcutil.Amount(o.SignDesc().Output.Value) + } + + // The value remaining after the required output and fees, go to + // change. Not that this fee is what we would have to pay in case the + // sweep tx has a change output. + changeAmt := totalInput - requiredOutput - txFee + + // The txn will sweep the amount after fees to the pkscript generated + // above. + if changeAmt >= dustLimit { + sweepTx.AddTxOut(&wire.TxOut{ + PkScript: outputPkScript, + Value: int64(changeAmt), + }) + } + + // We'll default to using the current block height as locktime, if none + // of the inputs commits to a different locktime. + sweepTx.LockTime = currentBlockHeight + if locktime != -1 { + sweepTx.LockTime = uint32(locktime) } // Before signing the transaction, check to ensure that it meets some @@ -233,7 +289,12 @@ func getWeightEstimate(inputs []input.Input, feeRate chainfee.SatPerKWeight) ( weightEstimate := newWeightEstimator(feeRate) // Our sweep transaction will pay to a single segwit p2wkh address, - // ensure it contributes to our weight estimate. + // ensure it contributes to our weight estimate. If the inputs we add + // have required TxOuts, then this will be our change address. Note + // that if we have required TxOuts, we might end up creating a sweep tx + // without a change output. It is okay to add the change output to the + // weight estimate regardless, since the estimated fee will just be + // subtracted from this already dust output, and trimmed. weightEstimate.addP2WKHOutput() // For each output, use its witness type to determine the estimate @@ -252,6 +313,12 @@ func getWeightEstimate(inputs []input.Input, feeRate chainfee.SatPerKWeight) ( continue } + // If this input comes with a committed output, add that as + // well. + if inp.RequiredTxOut() != nil { + weightEstimate.addOutput(inp.RequiredTxOut()) + } + sweepInputs = append(sweepInputs, inp) } diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go index aa6d658d7..06c41e5a5 100644 --- a/sweep/walletsweep.go +++ b/sweep/walletsweep.go @@ -153,10 +153,10 @@ type WalletSweepPackage struct { // by the delivery address. The sweep transaction will be crafted with the // target fee rate, and will use the utxoSource and outpointLocker as sources // for wallet funds. -func CraftSweepAllTx(feeRate chainfee.SatPerKWeight, blockHeight uint32, - deliveryAddr btcutil.Address, coinSelectLocker CoinSelectionLocker, - utxoSource UtxoSource, outpointLocker OutpointLocker, - feeEstimator chainfee.Estimator, +func CraftSweepAllTx(feeRate chainfee.SatPerKWeight, dustLimit btcutil.Amount, + blockHeight uint32, deliveryAddr btcutil.Address, + coinSelectLocker CoinSelectionLocker, utxoSource UtxoSource, + outpointLocker OutpointLocker, feeEstimator chainfee.Estimator, signer input.Signer) (*WalletSweepPackage, error) { // TODO(roasbeef): turn off ATPL as well when available? @@ -273,7 +273,8 @@ func CraftSweepAllTx(feeRate chainfee.SatPerKWeight, blockHeight uint32, // Finally, we'll ask the sweeper to craft a sweep transaction which // respects our fee preference and targets all the UTXOs of the wallet. sweepTx, err := createSweepTx( - inputsToSweep, deliveryPkScript, blockHeight, feeRate, signer, + inputsToSweep, deliveryPkScript, blockHeight, feeRate, + dustLimit, signer, ) if err != nil { unlockOutputs() diff --git a/sweep/walletsweep_test.go b/sweep/walletsweep_test.go index 4e0987579..032f30c00 100644 --- a/sweep/walletsweep_test.go +++ b/sweep/walletsweep_test.go @@ -288,7 +288,7 @@ func TestCraftSweepAllTxCoinSelectFail(t *testing.T) { utxoLocker := newMockOutpointLocker() _, err := CraftSweepAllTx( - 0, 100, nil, coinSelectLocker, utxoSource, utxoLocker, nil, nil, + 0, 100, 10, nil, coinSelectLocker, utxoSource, utxoLocker, nil, nil, ) // Since we instructed the coin select locker to fail above, we should @@ -313,7 +313,7 @@ func TestCraftSweepAllTxUnknownWitnessType(t *testing.T) { utxoLocker := newMockOutpointLocker() _, err := CraftSweepAllTx( - 0, 100, nil, coinSelectLocker, utxoSource, utxoLocker, nil, nil, + 0, 100, 10, nil, coinSelectLocker, utxoSource, utxoLocker, nil, nil, ) // Since passed in a p2wsh output, which is unknown, we should fail to @@ -347,7 +347,7 @@ func TestCraftSweepAllTx(t *testing.T) { utxoLocker := newMockOutpointLocker() sweepPkg, err := CraftSweepAllTx( - 0, 100, deliveryAddr, coinSelectLocker, utxoSource, utxoLocker, + 0, 100, 10, deliveryAddr, coinSelectLocker, utxoSource, utxoLocker, feeEstimator, signer, ) if err != nil { diff --git a/sweep/weight_estimator.go b/sweep/weight_estimator.go index 5a4068158..011094fef 100644 --- a/sweep/weight_estimator.go +++ b/sweep/weight_estimator.go @@ -2,6 +2,7 @@ package sweep import ( "github.com/btcsuite/btcd/chaincfg/chainhash" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" @@ -25,22 +26,6 @@ func newWeightEstimator(feeRate chainfee.SatPerKWeight) *weightEstimator { } } -// clone returns a copy of this weight estimator. -func (w *weightEstimator) clone() *weightEstimator { - parents := make(map[chainhash.Hash]struct{}, len(w.parents)) - for hash := range w.parents { - parents[hash] = struct{}{} - } - - return &weightEstimator{ - estimator: w.estimator, - feeRate: w.feeRate, - parents: parents, - parentsFee: w.parentsFee, - parentsWeight: w.parentsWeight, - } -} - // add adds the weight of the given input to the weight estimate. func (w *weightEstimator) add(inp input.Input) error { // If there is a parent tx, add the parent's fee and weight. @@ -92,6 +77,12 @@ func (w *weightEstimator) addP2WKHOutput() { w.estimator.AddP2WKHOutput() } +// addOutput updates the weight estimate to account for the known +// output given. +func (w *weightEstimator) addOutput(txOut *wire.TxOut) { + w.estimator.AddTxOutput(txOut) +} + // weight gets the estimated weight of the transaction. func (w *weightEstimator) weight() int { return w.estimator.Weight() diff --git a/sweep/weight_estimator_test.go b/sweep/weight_estimator_test.go index e4f1d489b..f64b8b897 100644 --- a/sweep/weight_estimator_test.go +++ b/sweep/weight_estimator_test.go @@ -3,7 +3,10 @@ package sweep import ( "testing" + "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/require" @@ -77,3 +80,32 @@ func TestWeightEstimator(t *testing.T) { require.Equal(t, expectedFee, w.fee()) } + +// TestWeightEstimatorAddOutput tests that adding the raw P2WKH output to the +// estimator yield the same result as an estimated add. +func TestWeightEstimatorAddOutput(t *testing.T) { + testFeeRate := chainfee.SatPerKWeight(20000) + + p2wkhAddr, err := btcutil.NewAddressWitnessPubKeyHash( + make([]byte, 20), &chaincfg.MainNetParams, + ) + require.NoError(t, err) + + p2wkhScript, err := txscript.PayToAddrScript(p2wkhAddr) + require.NoError(t, err) + + // Create two estimators, add the raw P2WKH out to one. + txOut := &wire.TxOut{ + PkScript: p2wkhScript, + Value: 10000, + } + + w1 := newWeightEstimator(testFeeRate) + w1.addOutput(txOut) + + w2 := newWeightEstimator(testFeeRate) + w2.addP2WKHOutput() + + // Estimate hhould be the same. + require.Equal(t, w1.weight(), w2.weight()) +}