peer: update chooseDeliveryScript to gen script if needed

In this commit, we update `chooseDeliveryScript` to generate a new
script if needed. This allows us to fold in a few other lines that
always followed this function into this expanded function.

The tests have been updated accordingly.
This commit is contained in:
Olaoluwa Osuntokun 2025-02-25 18:55:06 -08:00
parent 8432e706d3
commit baf0d83c37
2 changed files with 31 additions and 28 deletions

View File

@ -3136,32 +3136,37 @@ func (p *Brontide) retryRequestEnable(activeChans map[wire.OutPoint]struct{}) {
// chooseDeliveryScript takes two optionally set shutdown scripts and returns
// a suitable script to close out to. This may be nil if neither script is
// set. If both scripts are set, this function will error if they do not match.
func chooseDeliveryScript(upfront,
requested lnwire.DeliveryAddress) (lnwire.DeliveryAddress, error) {
func chooseDeliveryScript(upfront, requested lnwire.DeliveryAddress,
genDeliveryScript func() ([]byte, error),
) (lnwire.DeliveryAddress, error) {
switch {
// If no script was provided, then we'll generate a new delivery script.
case len(upfront) == 0 && len(requested) == 0:
return genDeliveryScript()
// If no upfront shutdown script was provided, return the user
// requested address (which may be nil).
if len(upfront) == 0 {
case len(upfront) == 0:
return requested, nil
}
// If an upfront shutdown script was provided, and the user did not
// request a custom shutdown script, return the upfront address.
if len(requested) == 0 {
case len(requested) == 0:
return upfront, nil
}
// If both an upfront shutdown script and a custom close script were
// provided, error if the user provided shutdown script does not match
// the upfront shutdown script (because closing out to a different
// script would violate upfront shutdown).
if !bytes.Equal(upfront, requested) {
case !bytes.Equal(upfront, requested):
return nil, chancloser.ErrUpfrontShutdownScriptMismatch
}
// The user requested script matches the upfront shutdown script, so we
// can return it without error.
return upfront, nil
default:
return upfront, nil
}
}
// restartCoopClose checks whether we need to restart the cooperative close
@ -3340,6 +3345,7 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
// are set) and error if they are both set and do not match.
deliveryScript, err := chooseDeliveryScript(
channel.LocalUpfrontShutdownScript(), req.DeliveryScript,
p.genDeliveryScript,
)
if err != nil {
p.log.Errorf("cannot close channel %v: %v", req.ChanPoint, err)
@ -3347,16 +3353,6 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
return
}
// If neither an upfront address or a user set address was
// provided, generate a fresh script.
if len(deliveryScript) == 0 {
deliveryScript, err = p.genDeliveryScript()
if err != nil {
p.log.Errorf(err.Error())
req.Err <- err
return
}
}
addr, err := p.addrWithInternalKey(deliveryScript)
if err != nil {
err = fmt.Errorf("unable to parse addr for channel "+

View File

@ -689,15 +689,9 @@ func TestChooseDeliveryScript(t *testing.T) {
userScript lnwire.DeliveryAddress
shutdownScript lnwire.DeliveryAddress
expectedScript lnwire.DeliveryAddress
newAddr func() ([]byte, error)
expectedError error
}{
{
name: "Neither set",
userScript: nil,
shutdownScript: nil,
expectedScript: nil,
expectedError: nil,
},
{
name: "Both set and equal",
userScript: script1,
@ -726,6 +720,16 @@ func TestChooseDeliveryScript(t *testing.T) {
expectedScript: script2,
expectedError: nil,
},
{
name: "no script generate new one",
userScript: nil,
shutdownScript: nil,
expectedScript: script2,
newAddr: func() ([]byte, error) {
return script2, nil
},
expectedError: nil,
},
}
for _, test := range tests {
@ -734,13 +738,16 @@ func TestChooseDeliveryScript(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
script, err := chooseDeliveryScript(
test.shutdownScript, test.userScript,
test.newAddr,
)
if err != test.expectedError {
t.Fatalf("Expected: %v, got: %v", test.expectedError, err)
t.Fatalf("Expected: %v, got: %v",
test.expectedError, err)
}
if !bytes.Equal(script, test.expectedScript) {
t.Fatalf("Expected: %x, got: %x", test.expectedScript, script)
t.Fatalf("Expected: %x, got: %x",
test.expectedScript, script)
}
})
}