mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 01:33:02 +01:00
peer: update chooseDeliveryScript to gen script if needed
In this commit, we update `chooseDeliveryScript` to generate a new script if needed. This allows us to fold in a few other lines that always followed this function into this expanded function. The tests have been updated accordingly.
This commit is contained in:
parent
8432e706d3
commit
baf0d83c37
@ -3136,32 +3136,37 @@ func (p *Brontide) retryRequestEnable(activeChans map[wire.OutPoint]struct{}) {
|
||||
// chooseDeliveryScript takes two optionally set shutdown scripts and returns
|
||||
// a suitable script to close out to. This may be nil if neither script is
|
||||
// set. If both scripts are set, this function will error if they do not match.
|
||||
func chooseDeliveryScript(upfront,
|
||||
requested lnwire.DeliveryAddress) (lnwire.DeliveryAddress, error) {
|
||||
func chooseDeliveryScript(upfront, requested lnwire.DeliveryAddress,
|
||||
genDeliveryScript func() ([]byte, error),
|
||||
) (lnwire.DeliveryAddress, error) {
|
||||
|
||||
switch {
|
||||
// If no script was provided, then we'll generate a new delivery script.
|
||||
case len(upfront) == 0 && len(requested) == 0:
|
||||
return genDeliveryScript()
|
||||
|
||||
// If no upfront shutdown script was provided, return the user
|
||||
// requested address (which may be nil).
|
||||
if len(upfront) == 0 {
|
||||
case len(upfront) == 0:
|
||||
return requested, nil
|
||||
}
|
||||
|
||||
// If an upfront shutdown script was provided, and the user did not
|
||||
// request a custom shutdown script, return the upfront address.
|
||||
if len(requested) == 0 {
|
||||
case len(requested) == 0:
|
||||
return upfront, nil
|
||||
}
|
||||
|
||||
// If both an upfront shutdown script and a custom close script were
|
||||
// provided, error if the user provided shutdown script does not match
|
||||
// the upfront shutdown script (because closing out to a different
|
||||
// script would violate upfront shutdown).
|
||||
if !bytes.Equal(upfront, requested) {
|
||||
case !bytes.Equal(upfront, requested):
|
||||
return nil, chancloser.ErrUpfrontShutdownScriptMismatch
|
||||
}
|
||||
|
||||
// The user requested script matches the upfront shutdown script, so we
|
||||
// can return it without error.
|
||||
return upfront, nil
|
||||
default:
|
||||
return upfront, nil
|
||||
}
|
||||
}
|
||||
|
||||
// restartCoopClose checks whether we need to restart the cooperative close
|
||||
@ -3340,6 +3345,7 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
|
||||
// are set) and error if they are both set and do not match.
|
||||
deliveryScript, err := chooseDeliveryScript(
|
||||
channel.LocalUpfrontShutdownScript(), req.DeliveryScript,
|
||||
p.genDeliveryScript,
|
||||
)
|
||||
if err != nil {
|
||||
p.log.Errorf("cannot close channel %v: %v", req.ChanPoint, err)
|
||||
@ -3347,16 +3353,6 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
|
||||
return
|
||||
}
|
||||
|
||||
// If neither an upfront address or a user set address was
|
||||
// provided, generate a fresh script.
|
||||
if len(deliveryScript) == 0 {
|
||||
deliveryScript, err = p.genDeliveryScript()
|
||||
if err != nil {
|
||||
p.log.Errorf(err.Error())
|
||||
req.Err <- err
|
||||
return
|
||||
}
|
||||
}
|
||||
addr, err := p.addrWithInternalKey(deliveryScript)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to parse addr for channel "+
|
||||
|
@ -689,15 +689,9 @@ func TestChooseDeliveryScript(t *testing.T) {
|
||||
userScript lnwire.DeliveryAddress
|
||||
shutdownScript lnwire.DeliveryAddress
|
||||
expectedScript lnwire.DeliveryAddress
|
||||
newAddr func() ([]byte, error)
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "Neither set",
|
||||
userScript: nil,
|
||||
shutdownScript: nil,
|
||||
expectedScript: nil,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "Both set and equal",
|
||||
userScript: script1,
|
||||
@ -726,6 +720,16 @@ func TestChooseDeliveryScript(t *testing.T) {
|
||||
expectedScript: script2,
|
||||
expectedError: nil,
|
||||
},
|
||||
{
|
||||
name: "no script generate new one",
|
||||
userScript: nil,
|
||||
shutdownScript: nil,
|
||||
expectedScript: script2,
|
||||
newAddr: func() ([]byte, error) {
|
||||
return script2, nil
|
||||
},
|
||||
expectedError: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
@ -734,13 +738,16 @@ func TestChooseDeliveryScript(t *testing.T) {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
script, err := chooseDeliveryScript(
|
||||
test.shutdownScript, test.userScript,
|
||||
test.newAddr,
|
||||
)
|
||||
if err != test.expectedError {
|
||||
t.Fatalf("Expected: %v, got: %v", test.expectedError, err)
|
||||
t.Fatalf("Expected: %v, got: %v",
|
||||
test.expectedError, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(script, test.expectedScript) {
|
||||
t.Fatalf("Expected: %x, got: %x", test.expectedScript, script)
|
||||
t.Fatalf("Expected: %x, got: %x",
|
||||
test.expectedScript, script)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user