diff --git a/peer/brontide.go b/peer/brontide.go index 19a2d8eb3..4e99923d7 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -3136,32 +3136,37 @@ func (p *Brontide) retryRequestEnable(activeChans map[wire.OutPoint]struct{}) { // chooseDeliveryScript takes two optionally set shutdown scripts and returns // a suitable script to close out to. This may be nil if neither script is // set. If both scripts are set, this function will error if they do not match. -func chooseDeliveryScript(upfront, - requested lnwire.DeliveryAddress) (lnwire.DeliveryAddress, error) { +func chooseDeliveryScript(upfront, requested lnwire.DeliveryAddress, + genDeliveryScript func() ([]byte, error), +) (lnwire.DeliveryAddress, error) { + + switch { + // If no script was provided, then we'll generate a new delivery script. + case len(upfront) == 0 && len(requested) == 0: + return genDeliveryScript() // If no upfront shutdown script was provided, return the user // requested address (which may be nil). - if len(upfront) == 0 { + case len(upfront) == 0: return requested, nil - } // If an upfront shutdown script was provided, and the user did not // request a custom shutdown script, return the upfront address. - if len(requested) == 0 { + case len(requested) == 0: return upfront, nil - } // If both an upfront shutdown script and a custom close script were // provided, error if the user provided shutdown script does not match // the upfront shutdown script (because closing out to a different // script would violate upfront shutdown). - if !bytes.Equal(upfront, requested) { + case !bytes.Equal(upfront, requested): return nil, chancloser.ErrUpfrontShutdownScriptMismatch - } // The user requested script matches the upfront shutdown script, so we // can return it without error. - return upfront, nil + default: + return upfront, nil + } } // restartCoopClose checks whether we need to restart the cooperative close @@ -3340,6 +3345,7 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { // are set) and error if they are both set and do not match. deliveryScript, err := chooseDeliveryScript( channel.LocalUpfrontShutdownScript(), req.DeliveryScript, + p.genDeliveryScript, ) if err != nil { p.log.Errorf("cannot close channel %v: %v", req.ChanPoint, err) @@ -3347,16 +3353,6 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { return } - // If neither an upfront address or a user set address was - // provided, generate a fresh script. - if len(deliveryScript) == 0 { - deliveryScript, err = p.genDeliveryScript() - if err != nil { - p.log.Errorf(err.Error()) - req.Err <- err - return - } - } addr, err := p.addrWithInternalKey(deliveryScript) if err != nil { err = fmt.Errorf("unable to parse addr for channel "+ diff --git a/peer/brontide_test.go b/peer/brontide_test.go index eded65888..283b62671 100644 --- a/peer/brontide_test.go +++ b/peer/brontide_test.go @@ -689,15 +689,9 @@ func TestChooseDeliveryScript(t *testing.T) { userScript lnwire.DeliveryAddress shutdownScript lnwire.DeliveryAddress expectedScript lnwire.DeliveryAddress + newAddr func() ([]byte, error) expectedError error }{ - { - name: "Neither set", - userScript: nil, - shutdownScript: nil, - expectedScript: nil, - expectedError: nil, - }, { name: "Both set and equal", userScript: script1, @@ -726,6 +720,16 @@ func TestChooseDeliveryScript(t *testing.T) { expectedScript: script2, expectedError: nil, }, + { + name: "no script generate new one", + userScript: nil, + shutdownScript: nil, + expectedScript: script2, + newAddr: func() ([]byte, error) { + return script2, nil + }, + expectedError: nil, + }, } for _, test := range tests { @@ -734,13 +738,16 @@ func TestChooseDeliveryScript(t *testing.T) { t.Run(test.name, func(t *testing.T) { script, err := chooseDeliveryScript( test.shutdownScript, test.userScript, + test.newAddr, ) if err != test.expectedError { - t.Fatalf("Expected: %v, got: %v", test.expectedError, err) + t.Fatalf("Expected: %v, got: %v", + test.expectedError, err) } if !bytes.Equal(script, test.expectedScript) { - t.Fatalf("Expected: %x, got: %x", test.expectedScript, script) + t.Fatalf("Expected: %x, got: %x", + test.expectedScript, script) } }) }