From 2204cbfd30736aec11bad485b97fdf379c3ec318 Mon Sep 17 00:00:00 2001 From: positiveblue Date: Fri, 23 Dec 2022 03:26:55 -0800 Subject: [PATCH] rpc: validate closing channel address in open channel requests Our OpenChannelRPC was accepting invalid values for the closing address field. If we were able to decode the address we would use it in the script even if the address is for another bitcoin net. --- chanacceptor/acceptor_test.go | 4 +- chanacceptor/rpcacceptor_test.go | 4 +- lnwallet/chancloser/chancloser.go | 5 +++ lnwallet/chancloser/chancloser_test.go | 55 ++++++++++++++++++++++++++ 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/chanacceptor/acceptor_test.go b/chanacceptor/acceptor_test.go index 19f452289..5a6aaa012 100644 --- a/chanacceptor/acceptor_test.go +++ b/chanacceptor/acceptor_test.go @@ -55,7 +55,7 @@ func newChanAcceptorCtx(t *testing.T, acceptCallCount int, testCtx.acceptor = NewRPCAcceptor( testCtx.receiveResponse, testCtx.sendRequest, testTimeout*5, - &chaincfg.TestNet3Params, testCtx.quit, + &chaincfg.RegressionNetParams, testCtx.quit, ) return testCtx @@ -162,7 +162,7 @@ func (c *channelAcceptorCtx) queryAndAssert(queries map[*lnwire.OpenChannel]*Cha func TestMultipleAcceptClients(t *testing.T) { testAddr := "bcrt1qwrmq9uca0t3dy9t9wtuq5tm4405r7tfzyqn9pp" testUpfront, err := chancloser.ParseUpfrontShutdownAddress( - testAddr, &chaincfg.TestNet3Params, + testAddr, &chaincfg.RegressionNetParams, ) require.NoError(t, err) diff --git a/chanacceptor/rpcacceptor_test.go b/chanacceptor/rpcacceptor_test.go index b755924eb..de1f380c1 100644 --- a/chanacceptor/rpcacceptor_test.go +++ b/chanacceptor/rpcacceptor_test.go @@ -20,7 +20,7 @@ func TestValidateAcceptorResponse(t *testing.T) { customError = errors.New("custom error") validAddr = "bcrt1qwrmq9uca0t3dy9t9wtuq5tm4405r7tfzyqn9pp" addr, _ = chancloser.ParseUpfrontShutdownAddress( - validAddr, &chaincfg.TestNet3Params, + validAddr, &chaincfg.RegressionNetParams, ) ) @@ -124,7 +124,7 @@ func TestValidateAcceptorResponse(t *testing.T) { // Create an acceptor, everything can be nil because // we just need the params. acceptor := NewRPCAcceptor( - nil, nil, 0, &chaincfg.TestNet3Params, nil, + nil, nil, 0, &chaincfg.RegressionNetParams, nil, ) accept, acceptErr, shutdown, err := acceptor.validateAcceptorResponse( diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 2ddf47950..af71e7b92 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -918,5 +918,10 @@ func ParseUpfrontShutdownAddress(address string, return nil, fmt.Errorf("invalid address: %v", err) } + if !addr.IsForNet(params) { + return nil, fmt.Errorf("invalid address: %v is not a %s "+ + "address", addr, params.Name) + } + return txscript.PayToAddrScript(addr) } diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 046ccb515..6a9c8a2f0 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -322,3 +322,58 @@ func TestMaxFeeBailOut(t *testing.T) { }) } } + +// TestParseUpfrontShutdownAddress tests the we are able to parse the upfront +// shutdown address properly. +func TestParseUpfrontShutdownAddress(t *testing.T) { + t.Parallel() + + var ( + testnetAddress = "tb1qdfkmwwgdaa5dnezrlhtftvmj5qn2kwgp7n0z6r" + regtestAddress = "bcrt1q09crvvuj95x5nk64wsxf5n6ky0kr8358vpx4d8" + ) + + tests := []struct { + name string + address string + params chaincfg.Params + expectedErr string + }{ + { + name: "invalid closing address", + address: "non-valid-address", + params: chaincfg.RegressionNetParams, + expectedErr: "invalid address", + }, + { + name: "closing address from another net", + address: testnetAddress, + params: chaincfg.RegressionNetParams, + expectedErr: "not a regtest address", + }, + { + name: "valid p2wkh closing address", + address: regtestAddress, + params: chaincfg.RegressionNetParams, + }, + } + + for _, tc := range tests { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := ParseUpfrontShutdownAddress( + tc.address, &tc.params, + ) + + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + return + } + + require.NoError(t, err) + }) + } +}