mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-11-10 06:07:16 +01:00
sweep: add missing output to the weight estimation
When overlay channels are used the extra output needs to be considered.
This commit is contained in:
@@ -1668,6 +1668,16 @@ func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey,
|
||||
return 0, noChange, noLocktime, err
|
||||
}
|
||||
|
||||
// We also add the extra change output to the change pk scripts.
|
||||
//
|
||||
// NOTE: The weight estimation will not be quite accurate because the
|
||||
// witness data is greater when overlay channels are used. But that
|
||||
// shouldn't be a problem since we will increase the fee rate
|
||||
// incrementally via the fee function.
|
||||
extraChangeOut.WhenSome(func(o SweepOutput) {
|
||||
changePkScripts = append(changePkScripts, o.TxOut.PkScript)
|
||||
})
|
||||
|
||||
// Creating a weight estimator with nil outputs and zero max fee rate.
|
||||
// We don't allow adding customized outputs in the sweeping tx, and the
|
||||
// fee rate is already being managed before we get here.
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/fn/v2"
|
||||
"github.com/lightningnetwork/lnd/input"
|
||||
"github.com/lightningnetwork/lnd/keychain"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwallet"
|
||||
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
|
||||
"github.com/stretchr/testify/mock"
|
||||
@@ -2109,3 +2110,200 @@ func createTestSpendEvent(tx *wire.MsgTx) *chainntnfs.SpendEvent {
|
||||
Cancel: func() {},
|
||||
}
|
||||
}
|
||||
|
||||
// TestPrepareSweepTx tests the prepareSweepTx function behavior.
|
||||
func TestPrepareSweepTx(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Create test inputs with different values.
|
||||
inp1 := createTestInput(1000000, input.WitnessKeyHash)
|
||||
inp2 := createTestInput(2000000, input.WitnessKeyHash)
|
||||
|
||||
// Test fee rate and height.
|
||||
feeRate := chainfee.SatPerKWeight(1000)
|
||||
currentHeight := int32(800000)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
inputs []input.Input
|
||||
changePkScript lnwallet.AddrWithKey
|
||||
feeRate chainfee.SatPerKWeight
|
||||
currentHeight int32
|
||||
auxSweeper fn.Option[AuxSweeper]
|
||||
expectedErr error
|
||||
checkResults func(t *testing.T, fee btcutil.Amount,
|
||||
changeOuts fn.Option[[]SweepOutput],
|
||||
locktime fn.Option[int32])
|
||||
}{
|
||||
{
|
||||
name: "successful sweep with change - no " +
|
||||
"extra output",
|
||||
inputs: []input.Input{&inp1, &inp2},
|
||||
changePkScript: changePkScript,
|
||||
feeRate: feeRate,
|
||||
currentHeight: currentHeight,
|
||||
auxSweeper: fn.None[AuxSweeper](),
|
||||
expectedErr: nil,
|
||||
checkResults: func(t *testing.T, fee btcutil.Amount,
|
||||
changeOuts fn.Option[[]SweepOutput],
|
||||
locktime fn.Option[int32]) {
|
||||
|
||||
// Calculate expected weight - only regular
|
||||
// change output, no extra.
|
||||
expectedWeight, err := calcSweepTxWeight(
|
||||
[]input.Input{&inp1, &inp2},
|
||||
[][]byte{
|
||||
changePkScript.DeliveryAddress,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expected fee based on fee rate and weight.
|
||||
expectedFee := feeRate.FeeForWeight(
|
||||
expectedWeight,
|
||||
)
|
||||
|
||||
require.Equal(t, fee, expectedFee)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "successful sweep with extra output",
|
||||
inputs: []input.Input{&inp1, &inp2},
|
||||
changePkScript: changePkScript,
|
||||
feeRate: feeRate,
|
||||
currentHeight: currentHeight,
|
||||
auxSweeper: fn.Some[AuxSweeper](&MockAuxSweeper{}),
|
||||
expectedErr: nil,
|
||||
checkResults: func(t *testing.T, fee btcutil.Amount,
|
||||
changeOuts fn.Option[[]SweepOutput],
|
||||
locktime fn.Option[int32]) {
|
||||
|
||||
// Calculate expected weight - includes both
|
||||
// regular change and extra output.
|
||||
expectedWeight, err := calcSweepTxWeight(
|
||||
[]input.Input{&inp1, &inp2},
|
||||
[][]byte{changePkScript.DeliveryAddress,
|
||||
changePkScript.DeliveryAddress},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Expected fee based on fee rate and weight.
|
||||
expectedFee := feeRate.FeeForWeight(
|
||||
expectedWeight,
|
||||
)
|
||||
|
||||
require.Equal(t, fee, expectedFee)
|
||||
|
||||
// Should have change outputs (both regular
|
||||
// and extra).
|
||||
require.True(t, changeOuts.IsSome())
|
||||
outputs := changeOuts.UnwrapOr([]SweepOutput{})
|
||||
require.Equal(t, 2, len(outputs))
|
||||
|
||||
// Check if extra output is present.
|
||||
hasExtra := false
|
||||
for _, out := range outputs {
|
||||
if out.IsExtra {
|
||||
hasExtra = true
|
||||
break
|
||||
}
|
||||
}
|
||||
require.True(
|
||||
t, hasExtra, "Should have extra output",
|
||||
)
|
||||
|
||||
// Locktime should be None since no inputs
|
||||
// require locktime.
|
||||
require.True(t, locktime.IsNone())
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "insufficient inputs",
|
||||
inputs: []input.Input{},
|
||||
changePkScript: changePkScript,
|
||||
feeRate: feeRate,
|
||||
currentHeight: currentHeight,
|
||||
auxSweeper: fn.None[AuxSweeper](),
|
||||
expectedErr: ErrNotEnoughInputs,
|
||||
},
|
||||
{
|
||||
name: "high fee rate causes insufficient " +
|
||||
"inputs",
|
||||
inputs: []input.Input{&inp1},
|
||||
changePkScript: changePkScript,
|
||||
feeRate: chainfee.SatPerKWeight(10000000),
|
||||
currentHeight: currentHeight,
|
||||
auxSweeper: fn.None[AuxSweeper](),
|
||||
expectedErr: ErrNotEnoughInputs,
|
||||
},
|
||||
{
|
||||
name: "immature locktime",
|
||||
inputs: []input.Input{
|
||||
createTestInputWithLocktime(
|
||||
1000000, input.WitnessKeyHash,
|
||||
uint32(currentHeight+100),
|
||||
),
|
||||
},
|
||||
changePkScript: changePkScript,
|
||||
feeRate: feeRate,
|
||||
currentHeight: currentHeight,
|
||||
auxSweeper: fn.None[AuxSweeper](),
|
||||
expectedErr: ErrLocktimeImmature,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fee, changeOuts, locktime, err := prepareSweepTx(
|
||||
tc.inputs,
|
||||
tc.changePkScript,
|
||||
tc.feeRate,
|
||||
tc.currentHeight,
|
||||
tc.auxSweeper,
|
||||
)
|
||||
|
||||
// Check error expectations.
|
||||
if tc.expectedErr != nil {
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
return
|
||||
}
|
||||
|
||||
// For successful cases, run additional checks.
|
||||
require.NoError(t, err)
|
||||
if tc.checkResults != nil {
|
||||
tc.checkResults(t, fee, changeOuts, locktime)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// createTestInputWithLocktime creates a test input with a specific locktime
|
||||
// requirement.
|
||||
func createTestInputWithLocktime(value int64, witnessType input.WitnessType,
|
||||
locktime uint32) *input.BaseInput {
|
||||
|
||||
// Create a unique test identifier based on input count.
|
||||
hash := chainhash.Hash{}
|
||||
hash[lntypes.HashSize-1] = byte(testInputCount.Add(1))
|
||||
|
||||
// Use NewCsvInputWithCltv to create an input with locktime requirement.
|
||||
inp := input.NewCsvInputWithCltv(
|
||||
&wire.OutPoint{
|
||||
Hash: hash,
|
||||
},
|
||||
witnessType,
|
||||
&input.SignDescriptor{
|
||||
Output: &wire.TxOut{
|
||||
Value: value,
|
||||
},
|
||||
KeyDesc: keychain.KeyDescriptor{
|
||||
PubKey: testPubKey,
|
||||
},
|
||||
},
|
||||
1, 0, locktime,
|
||||
)
|
||||
|
||||
return inp
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user