sweep: add missing output to the weight estimation

When overlay channels are used the extra output needs to be
considered.
This commit is contained in:
ziggie
2025-09-02 09:59:06 +02:00
parent 22ac4082a4
commit d257198365
2 changed files with 208 additions and 0 deletions

View File

@@ -1668,6 +1668,16 @@ func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey,
return 0, noChange, noLocktime, err
}
// We also add the extra change output to the change pk scripts.
//
// NOTE: The weight estimation will not be quite accurate because the
// witness data is greater when overlay channels are used. But that
// shouldn't be a problem since we will increase the fee rate
// incrementally via the fee function.
extraChangeOut.WhenSome(func(o SweepOutput) {
changePkScripts = append(changePkScripts, o.TxOut.PkScript)
})
// Creating a weight estimator with nil outputs and zero max fee rate.
// We don't allow adding customized outputs in the sweeping tx, and the
// fee rate is already being managed before we get here.

View File

@@ -14,6 +14,7 @@ import (
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
"github.com/stretchr/testify/mock"
@@ -2109,3 +2110,200 @@ func createTestSpendEvent(tx *wire.MsgTx) *chainntnfs.SpendEvent {
Cancel: func() {},
}
}
// TestPrepareSweepTx tests the prepareSweepTx function behavior.
func TestPrepareSweepTx(t *testing.T) {
t.Parallel()
// Create test inputs with different values.
inp1 := createTestInput(1000000, input.WitnessKeyHash)
inp2 := createTestInput(2000000, input.WitnessKeyHash)
// Test fee rate and height.
feeRate := chainfee.SatPerKWeight(1000)
currentHeight := int32(800000)
testCases := []struct {
name string
inputs []input.Input
changePkScript lnwallet.AddrWithKey
feeRate chainfee.SatPerKWeight
currentHeight int32
auxSweeper fn.Option[AuxSweeper]
expectedErr error
checkResults func(t *testing.T, fee btcutil.Amount,
changeOuts fn.Option[[]SweepOutput],
locktime fn.Option[int32])
}{
{
name: "successful sweep with change - no " +
"extra output",
inputs: []input.Input{&inp1, &inp2},
changePkScript: changePkScript,
feeRate: feeRate,
currentHeight: currentHeight,
auxSweeper: fn.None[AuxSweeper](),
expectedErr: nil,
checkResults: func(t *testing.T, fee btcutil.Amount,
changeOuts fn.Option[[]SweepOutput],
locktime fn.Option[int32]) {
// Calculate expected weight - only regular
// change output, no extra.
expectedWeight, err := calcSweepTxWeight(
[]input.Input{&inp1, &inp2},
[][]byte{
changePkScript.DeliveryAddress,
},
)
require.NoError(t, err)
// Expected fee based on fee rate and weight.
expectedFee := feeRate.FeeForWeight(
expectedWeight,
)
require.Equal(t, fee, expectedFee)
},
},
{
name: "successful sweep with extra output",
inputs: []input.Input{&inp1, &inp2},
changePkScript: changePkScript,
feeRate: feeRate,
currentHeight: currentHeight,
auxSweeper: fn.Some[AuxSweeper](&MockAuxSweeper{}),
expectedErr: nil,
checkResults: func(t *testing.T, fee btcutil.Amount,
changeOuts fn.Option[[]SweepOutput],
locktime fn.Option[int32]) {
// Calculate expected weight - includes both
// regular change and extra output.
expectedWeight, err := calcSweepTxWeight(
[]input.Input{&inp1, &inp2},
[][]byte{changePkScript.DeliveryAddress,
changePkScript.DeliveryAddress},
)
require.NoError(t, err)
// Expected fee based on fee rate and weight.
expectedFee := feeRate.FeeForWeight(
expectedWeight,
)
require.Equal(t, fee, expectedFee)
// Should have change outputs (both regular
// and extra).
require.True(t, changeOuts.IsSome())
outputs := changeOuts.UnwrapOr([]SweepOutput{})
require.Equal(t, 2, len(outputs))
// Check if extra output is present.
hasExtra := false
for _, out := range outputs {
if out.IsExtra {
hasExtra = true
break
}
}
require.True(
t, hasExtra, "Should have extra output",
)
// Locktime should be None since no inputs
// require locktime.
require.True(t, locktime.IsNone())
},
},
{
name: "insufficient inputs",
inputs: []input.Input{},
changePkScript: changePkScript,
feeRate: feeRate,
currentHeight: currentHeight,
auxSweeper: fn.None[AuxSweeper](),
expectedErr: ErrNotEnoughInputs,
},
{
name: "high fee rate causes insufficient " +
"inputs",
inputs: []input.Input{&inp1},
changePkScript: changePkScript,
feeRate: chainfee.SatPerKWeight(10000000),
currentHeight: currentHeight,
auxSweeper: fn.None[AuxSweeper](),
expectedErr: ErrNotEnoughInputs,
},
{
name: "immature locktime",
inputs: []input.Input{
createTestInputWithLocktime(
1000000, input.WitnessKeyHash,
uint32(currentHeight+100),
),
},
changePkScript: changePkScript,
feeRate: feeRate,
currentHeight: currentHeight,
auxSweeper: fn.None[AuxSweeper](),
expectedErr: ErrLocktimeImmature,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
fee, changeOuts, locktime, err := prepareSweepTx(
tc.inputs,
tc.changePkScript,
tc.feeRate,
tc.currentHeight,
tc.auxSweeper,
)
// Check error expectations.
if tc.expectedErr != nil {
require.ErrorIs(t, err, tc.expectedErr)
return
}
// For successful cases, run additional checks.
require.NoError(t, err)
if tc.checkResults != nil {
tc.checkResults(t, fee, changeOuts, locktime)
}
})
}
}
// createTestInputWithLocktime creates a test input with a specific locktime
// requirement.
func createTestInputWithLocktime(value int64, witnessType input.WitnessType,
locktime uint32) *input.BaseInput {
// Create a unique test identifier based on input count.
hash := chainhash.Hash{}
hash[lntypes.HashSize-1] = byte(testInputCount.Add(1))
// Use NewCsvInputWithCltv to create an input with locktime requirement.
inp := input.NewCsvInputWithCltv(
&wire.OutPoint{
Hash: hash,
},
witnessType,
&input.SignDescriptor{
Output: &wire.TxOut{
Value: value,
},
KeyDesc: keychain.KeyDescriptor{
PubKey: testPubKey,
},
},
1, 0, locktime,
)
return inp
}