diff --git a/chanacceptor/acceptor_test.go b/chanacceptor/acceptor_test.go index 19f452289..5a6aaa012 100644 --- a/chanacceptor/acceptor_test.go +++ b/chanacceptor/acceptor_test.go @@ -55,7 +55,7 @@ func newChanAcceptorCtx(t *testing.T, acceptCallCount int, testCtx.acceptor = NewRPCAcceptor( testCtx.receiveResponse, testCtx.sendRequest, testTimeout*5, - &chaincfg.TestNet3Params, testCtx.quit, + &chaincfg.RegressionNetParams, testCtx.quit, ) return testCtx @@ -162,7 +162,7 @@ func (c *channelAcceptorCtx) queryAndAssert(queries map[*lnwire.OpenChannel]*Cha func TestMultipleAcceptClients(t *testing.T) { testAddr := "bcrt1qwrmq9uca0t3dy9t9wtuq5tm4405r7tfzyqn9pp" testUpfront, err := chancloser.ParseUpfrontShutdownAddress( - testAddr, &chaincfg.TestNet3Params, + testAddr, &chaincfg.RegressionNetParams, ) require.NoError(t, err) diff --git a/chanacceptor/rpcacceptor_test.go b/chanacceptor/rpcacceptor_test.go index b755924eb..de1f380c1 100644 --- a/chanacceptor/rpcacceptor_test.go +++ b/chanacceptor/rpcacceptor_test.go @@ -20,7 +20,7 @@ func TestValidateAcceptorResponse(t *testing.T) { customError = errors.New("custom error") validAddr = "bcrt1qwrmq9uca0t3dy9t9wtuq5tm4405r7tfzyqn9pp" addr, _ = chancloser.ParseUpfrontShutdownAddress( - validAddr, &chaincfg.TestNet3Params, + validAddr, &chaincfg.RegressionNetParams, ) ) @@ -124,7 +124,7 @@ func TestValidateAcceptorResponse(t *testing.T) { // Create an acceptor, everything can be nil because // we just need the params. acceptor := NewRPCAcceptor( - nil, nil, 0, &chaincfg.TestNet3Params, nil, + nil, nil, 0, &chaincfg.RegressionNetParams, nil, ) accept, acceptErr, shutdown, err := acceptor.validateAcceptorResponse( diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 2ddf47950..af71e7b92 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -918,5 +918,10 @@ func ParseUpfrontShutdownAddress(address string, return nil, fmt.Errorf("invalid address: %v", err) } + if !addr.IsForNet(params) { + return nil, fmt.Errorf("invalid address: %v is not a %s "+ + "address", addr, params.Name) + } + return txscript.PayToAddrScript(addr) } diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 046ccb515..6a9c8a2f0 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -322,3 +322,58 @@ func TestMaxFeeBailOut(t *testing.T) { }) } } + +// TestParseUpfrontShutdownAddress tests the we are able to parse the upfront +// shutdown address properly. +func TestParseUpfrontShutdownAddress(t *testing.T) { + t.Parallel() + + var ( + testnetAddress = "tb1qdfkmwwgdaa5dnezrlhtftvmj5qn2kwgp7n0z6r" + regtestAddress = "bcrt1q09crvvuj95x5nk64wsxf5n6ky0kr8358vpx4d8" + ) + + tests := []struct { + name string + address string + params chaincfg.Params + expectedErr string + }{ + { + name: "invalid closing address", + address: "non-valid-address", + params: chaincfg.RegressionNetParams, + expectedErr: "invalid address", + }, + { + name: "closing address from another net", + address: testnetAddress, + params: chaincfg.RegressionNetParams, + expectedErr: "not a regtest address", + }, + { + name: "valid p2wkh closing address", + address: regtestAddress, + params: chaincfg.RegressionNetParams, + }, + } + + for _, tc := range tests { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + _, err := ParseUpfrontShutdownAddress( + tc.address, &tc.params, + ) + + if tc.expectedErr != "" { + require.ErrorContains(t, err, tc.expectedErr) + return + } + + require.NoError(t, err) + }) + } +}