lnwallet/chanclose: update ProcessCloseMsg to check co-op close addrs

We only want to allow p2wkh, p2tr, and p2wsh addresses, so we'll utilize
the newly public wallet function to restrict this.
This commit is contained in:
Olaoluwa Osuntokun
2022-06-10 11:17:20 -07:00
parent c79ffc07ce
commit a61b6c25b3
3 changed files with 109 additions and 37 deletions

View File

@@ -1,12 +1,14 @@
package chancloser
import (
"crypto/rand"
"bytes"
"fmt"
"testing"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet/chainfee"
@@ -14,48 +16,58 @@ import (
"github.com/stretchr/testify/require"
)
// randDeliveryAddress generates a random delivery address for testing.
func randDeliveryAddress(t *testing.T) lnwire.DeliveryAddress {
// Generate an address of maximum length.
da := lnwire.DeliveryAddress(make([]byte, 34))
_, err := rand.Read(da)
require.NoError(t, err, "cannot generate random address")
return da
}
// TestMaybeMatchScript tests that the maybeMatchScript errors appropriately
// when an upfront shutdown script is set and the script provided does not
// match, and does not error in any other case.
func TestMaybeMatchScript(t *testing.T) {
t.Parallel()
addr1 := randDeliveryAddress(t)
addr2 := randDeliveryAddress(t)
pubHash := bytes.Repeat([]byte{0x0}, 20)
scriptHash := bytes.Repeat([]byte{0x0}, 32)
tests := []struct {
p2wkh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
AddData(pubHash).Script()
require.NoError(t, err)
p2wsh, err := txscript.NewScriptBuilder().AddOp(txscript.OP_0).
AddData(scriptHash).Script()
require.NoError(t, err)
p2tr, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
AddData(scriptHash).Script()
require.NoError(t, err)
p2OtherV1, err := txscript.NewScriptBuilder().AddOp(txscript.OP_1).
AddData(pubHash).Script()
require.NoError(t, err)
invalidFork, err := txscript.NewScriptBuilder().AddOp(txscript.OP_NOP).
AddData(scriptHash).Script()
require.NoError(t, err)
type testCase struct {
name string
shutdownScript lnwire.DeliveryAddress
upfrontScript lnwire.DeliveryAddress
expectedErr error
}{
}
tests := []testCase{
{
name: "no upfront shutdown set, script ok",
shutdownScript: addr1,
shutdownScript: p2wkh,
upfrontScript: []byte{},
expectedErr: nil,
},
{
name: "upfront shutdown set, script ok",
shutdownScript: addr1,
upfrontScript: addr1,
shutdownScript: p2wkh,
upfrontScript: p2wkh,
expectedErr: nil,
},
{
name: "upfront shutdown set, script not ok",
shutdownScript: addr1,
upfrontScript: addr2,
shutdownScript: p2wkh,
upfrontScript: p2wsh,
expectedErr: ErrUpfrontShutdownScriptMismatch,
},
{
@@ -64,6 +76,40 @@ func TestMaybeMatchScript(t *testing.T) {
upfrontScript: []byte{},
expectedErr: nil,
},
{
name: "p2tr is ok",
shutdownScript: p2tr,
},
{
name: "segwit v1 is ok",
shutdownScript: p2OtherV1,
},
{
name: "invalid script not allowed",
shutdownScript: invalidFork,
expectedErr: ErrInvalidShutdownScript,
},
}
// All future segwit softforks should also be ok.
futureForks := []byte{
txscript.OP_1, txscript.OP_2, txscript.OP_3, txscript.OP_4,
txscript.OP_5, txscript.OP_6, txscript.OP_7, txscript.OP_8,
txscript.OP_9, txscript.OP_10, txscript.OP_11, txscript.OP_12,
txscript.OP_13, txscript.OP_14, txscript.OP_15, txscript.OP_16,
}
for _, witnessVersion := range futureForks {
p2FutureFork, err := txscript.NewScriptBuilder().AddOp(witnessVersion).
AddData(scriptHash).Script()
require.NoError(t, err)
opString, err := txscript.DisasmString([]byte{witnessVersion})
require.NoError(t, err)
tests = append(tests, testCase{
name: fmt.Sprintf("witness_version=%v", opString),
shutdownScript: p2FutureFork,
})
}
for _, test := range tests {
@@ -72,9 +118,9 @@ func TestMaybeMatchScript(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
t.Parallel()
err := maybeMatchScript(
err := validateShutdownScript(
func() error { return nil }, test.upfrontScript,
test.shutdownScript,
test.shutdownScript, &chaincfg.SimNetParams,
)
if err != test.expectedErr {