Merge pull request #8464 from ellemouton/resend-shutdown-2

multi: resend shutdown on reestablish
This commit is contained in:
Elle 2024-02-21 14:10:05 +02:00 committed by GitHub
commit 16279765eb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 634 additions and 179 deletions

View File

@ -20,6 +20,7 @@ import (
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btcwallet/walletdb" "github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
@ -121,6 +122,12 @@ var (
// broadcasted when moving the channel to state CoopBroadcasted. // broadcasted when moving the channel to state CoopBroadcasted.
coopCloseTxKey = []byte("coop-closing-tx-key") coopCloseTxKey = []byte("coop-closing-tx-key")
// shutdownInfoKey points to the serialised shutdown info that has been
// persisted for a channel. The existence of this info means that we
// have sent the Shutdown message before and so should re-initiate the
// shutdown on re-establish.
shutdownInfoKey = []byte("shutdown-info-key")
// commitDiffKey stores the current pending commitment state we've // commitDiffKey stores the current pending commitment state we've
// extended to the remote party (if any). Each time we propose a new // extended to the remote party (if any). Each time we propose a new
// state, we store the information necessary to reconstruct this state // state, we store the information necessary to reconstruct this state
@ -188,6 +195,10 @@ var (
// in the state CommitBroadcasted. // in the state CommitBroadcasted.
ErrNoCloseTx = fmt.Errorf("no closing tx found") ErrNoCloseTx = fmt.Errorf("no closing tx found")
// ErrNoShutdownInfo is returned when no shutdown info has been
// persisted for a channel.
ErrNoShutdownInfo = errors.New("no shutdown info")
// ErrNoRestoredChannelMutation is returned when a caller attempts to // ErrNoRestoredChannelMutation is returned when a caller attempts to
// mutate a channel that's been recovered. // mutate a channel that's been recovered.
ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " + ErrNoRestoredChannelMutation = fmt.Errorf("cannot mutate restored " +
@ -1575,6 +1586,79 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) {
}, nil }, nil
} }
// MarkShutdownSent serialises and persist the given ShutdownInfo for this
// channel. Persisting this info represents the fact that we have sent the
// Shutdown message to the remote side and hence that we should re-transmit the
// same Shutdown message on re-establish.
func (c *OpenChannel) MarkShutdownSent(info *ShutdownInfo) error {
c.Lock()
defer c.Unlock()
return c.storeShutdownInfo(info)
}
// storeShutdownInfo serialises the ShutdownInfo and persists it under the
// shutdownInfoKey.
func (c *OpenChannel) storeShutdownInfo(info *ShutdownInfo) error {
var b bytes.Buffer
err := info.encode(&b)
if err != nil {
return err
}
return kvdb.Update(c.Db.backend, func(tx kvdb.RwTx) error {
chanBucket, err := fetchChanBucketRw(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
if err != nil {
return err
}
return chanBucket.Put(shutdownInfoKey, b.Bytes())
}, func() {})
}
// ShutdownInfo decodes the shutdown info stored for this channel and returns
// the result. If no shutdown info has been persisted for this channel then the
// ErrNoShutdownInfo error is returned.
func (c *OpenChannel) ShutdownInfo() (fn.Option[ShutdownInfo], error) {
c.RLock()
defer c.RUnlock()
var shutdownInfo *ShutdownInfo
err := kvdb.View(c.Db.backend, func(tx kvdb.RTx) error {
chanBucket, err := fetchChanBucket(
tx, c.IdentityPub, &c.FundingOutpoint, c.ChainHash,
)
switch {
case err == nil:
case errors.Is(err, ErrNoChanDBExists),
errors.Is(err, ErrNoActiveChannels),
errors.Is(err, ErrChannelNotFound):
return ErrNoShutdownInfo
default:
return err
}
shutdownInfoBytes := chanBucket.Get(shutdownInfoKey)
if shutdownInfoBytes == nil {
return ErrNoShutdownInfo
}
shutdownInfo, err = decodeShutdownInfo(shutdownInfoBytes)
return err
}, func() {
shutdownInfo = nil
})
if err != nil {
return fn.None[ShutdownInfo](), err
}
return fn.Some[ShutdownInfo](*shutdownInfo), nil
}
// isBorked returns true if the channel has been marked as borked in the // isBorked returns true if the channel has been marked as borked in the
// database. This requires an existing database transaction to already be // database. This requires an existing database transaction to already be
// active. // active.
@ -4294,3 +4378,59 @@ func MakeScidRecord(typ tlv.Type, scid *lnwire.ShortChannelID) tlv.Record {
typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID, typ, scid, 8, lnwire.EShortChannelID, lnwire.DShortChannelID,
) )
} }
// ShutdownInfo contains various info about the shutdown initiation of a
// channel.
type ShutdownInfo struct {
// DeliveryScript is the address that we have included in any previous
// Shutdown message for a particular channel and so should include in
// any future re-sends of the Shutdown message.
DeliveryScript tlv.RecordT[tlv.TlvType0, lnwire.DeliveryAddress]
// LocalInitiator is true if we sent a Shutdown message before ever
// receiving a Shutdown message from the remote peer.
LocalInitiator tlv.RecordT[tlv.TlvType1, bool]
}
// NewShutdownInfo constructs a new ShutdownInfo object.
func NewShutdownInfo(deliveryScript lnwire.DeliveryAddress,
locallyInitiated bool) *ShutdownInfo {
return &ShutdownInfo{
DeliveryScript: tlv.NewRecordT[tlv.TlvType0](deliveryScript),
LocalInitiator: tlv.NewPrimitiveRecord[tlv.TlvType1](
locallyInitiated,
),
}
}
// encode serialises the ShutdownInfo to the given io.Writer.
func (s *ShutdownInfo) encode(w io.Writer) error {
records := []tlv.Record{
s.DeliveryScript.Record(),
s.LocalInitiator.Record(),
}
stream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return stream.Encode(w)
}
// decodeShutdownInfo constructs a ShutdownInfo struct by decoding the given
// byte slice.
func decodeShutdownInfo(b []byte) (*ShutdownInfo, error) {
tlvStream := lnwire.ExtraOpaqueData(b)
var info ShutdownInfo
records := []tlv.RecordProducer{
&info.DeliveryScript,
&info.LocalInitiator,
}
_, err := tlvStream.ExtractRecords(records...)
return &info, err
}

View File

@ -1158,6 +1158,70 @@ func TestFetchWaitingCloseChannels(t *testing.T) {
} }
} }
// TestShutdownInfo tests that a channel's shutdown info can correctly be
// persisted and retrieved.
func TestShutdownInfo(t *testing.T) {
t.Parallel()
tests := []struct {
name string
localInit bool
}{
{
name: "local node initiated",
localInit: true,
},
{
name: "remote node initiated",
localInit: false,
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
testShutdownInfo(t, test.localInit)
})
}
}
func testShutdownInfo(t *testing.T, locallyInitiated bool) {
fullDB, err := MakeTestDB(t)
require.NoError(t, err, "unable to make test database")
cdb := fullDB.ChannelStateDB()
// First a test channel.
channel := createTestChannel(t, cdb)
// We haven't persisted any shutdown info for this channel yet.
_, err = channel.ShutdownInfo()
require.Error(t, err, ErrNoShutdownInfo)
// Construct a new delivery script and create a new ShutdownInfo object.
script := []byte{1, 3, 4, 5}
// Create a ShutdownInfo struct.
shutdownInfo := NewShutdownInfo(script, locallyInitiated)
// Persist the shutdown info.
require.NoError(t, channel.MarkShutdownSent(shutdownInfo))
// We should now be able to retrieve the shutdown info.
info, err := channel.ShutdownInfo()
require.NoError(t, err)
require.True(t, info.IsSome())
// Assert that the decoded values of the shutdown info are correct.
info.WhenSome(func(info ShutdownInfo) {
require.EqualValues(t, script, info.DeliveryScript.Val)
require.Equal(t, locallyInitiated, info.LocalInitiator.Val)
})
}
// TestRefresh asserts that Refresh updates the in-memory state of another // TestRefresh asserts that Refresh updates the in-memory state of another
// OpenChannel to reflect a preceding call to MarkOpen on a different // OpenChannel to reflect a preceding call to MarkOpen on a different
// OpenChannel. // OpenChannel.

View File

@ -73,6 +73,11 @@
a `shutdown` message if there were currently HTLCs on the channel. After this a `shutdown` message if there were currently HTLCs on the channel. After this
change, the shutdown procedure should be compliant with BOLT2 requirements. change, the shutdown procedure should be compliant with BOLT2 requirements.
* If HTLCs are in-flight at the same time that a `shutdown` is sent and then
a re-connect happens before the coop-close is completed we now [ensure that
we re-init the `shutdown`
exchange](https://github.com/lightningnetwork/lnd/pull/8464)
* The AMP struct in payment hops will [now be populated](https://github.com/lightningnetwork/lnd/pull/7976) when the AMP TLV is set. * The AMP struct in payment hops will [now be populated](https://github.com/lightningnetwork/lnd/pull/7976) when the AMP TLV is set.
* [Add Taproot witness types * [Add Taproot witness types

View File

@ -135,14 +135,16 @@ type ChannelUpdateHandler interface {
MayAddOutgoingHtlc(lnwire.MilliSatoshi) error MayAddOutgoingHtlc(lnwire.MilliSatoshi) error
// EnableAdds sets the ChannelUpdateHandler state to allow // EnableAdds sets the ChannelUpdateHandler state to allow
// UpdateAddHtlc's in the specified direction. It returns an error if // UpdateAddHtlc's in the specified direction. It returns true if the
// the state already allowed those adds. // state was changed and false if the desired state was already set
EnableAdds(direction LinkDirection) error // before the method was called.
EnableAdds(direction LinkDirection) bool
// DiableAdds sets the ChannelUpdateHandler state to allow // DisableAdds sets the ChannelUpdateHandler state to allow
// UpdateAddHtlc's in the specified direction. It returns an error if // UpdateAddHtlc's in the specified direction. It returns true if the
// the state already disallowed those adds. // state was changed and false if the desired state was already set
DisableAdds(direction LinkDirection) error // before the method was called.
DisableAdds(direction LinkDirection) bool
// IsFlushing returns true when UpdateAddHtlc's are disabled in the // IsFlushing returns true when UpdateAddHtlc's are disabled in the
// direction of the argument. // direction of the argument.

View File

@ -19,6 +19,7 @@ import (
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
"github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/htlcswitch/hop"
"github.com/lightningnetwork/lnd/invoices" "github.com/lightningnetwork/lnd/invoices"
@ -271,6 +272,14 @@ type ChannelLinkConfig struct {
// GetAliases is used by the link and switch to fetch the set of // GetAliases is used by the link and switch to fetch the set of
// aliases for a given link. // aliases for a given link.
GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID GetAliases func(base lnwire.ShortChannelID) []lnwire.ShortChannelID
// PreviouslySentShutdown is an optional value that is set if, at the
// time of the link being started, persisted shutdown info was found for
// the channel. This value being set means that we previously sent a
// Shutdown message to our peer, and so we should do so again on
// re-establish and should not allow anymore HTLC adds on the outgoing
// direction of the link.
PreviouslySentShutdown fn.Option[lnwire.Shutdown]
} }
// channelLink is the service which drives a channel's commitment update // channelLink is the service which drives a channel's commitment update
@ -618,41 +627,25 @@ func (l *channelLink) EligibleToUpdate() bool {
} }
// EnableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in // EnableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in
// the specified direction. It returns an error if the state already allowed // the specified direction. It returns true if the state was changed and false
// those adds. // if the desired state was already set before the method was called.
func (l *channelLink) EnableAdds(linkDirection LinkDirection) error { func (l *channelLink) EnableAdds(linkDirection LinkDirection) bool {
if linkDirection == Outgoing { if linkDirection == Outgoing {
if !l.isOutgoingAddBlocked.Swap(false) { return l.isOutgoingAddBlocked.Swap(false)
return errors.New("outgoing adds already enabled")
}
} }
if linkDirection == Incoming { return l.isIncomingAddBlocked.Swap(false)
if !l.isIncomingAddBlocked.Swap(false) {
return errors.New("incoming adds already enabled")
}
}
return nil
} }
// DiableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in // DisableAdds sets the ChannelUpdateHandler state to allow UpdateAddHtlc's in
// the specified direction. It returns an error if the state already disallowed // the specified direction. It returns true if the state was changed and false
// those adds. // if the desired state was already set before the method was called.
func (l *channelLink) DisableAdds(linkDirection LinkDirection) error { func (l *channelLink) DisableAdds(linkDirection LinkDirection) bool {
if linkDirection == Outgoing { if linkDirection == Outgoing {
if l.isOutgoingAddBlocked.Swap(true) { return !l.isOutgoingAddBlocked.Swap(true)
return errors.New("outgoing adds already disabled")
}
} }
if linkDirection == Incoming { return !l.isIncomingAddBlocked.Swap(true)
if l.isIncomingAddBlocked.Swap(true) {
return errors.New("incoming adds already disabled")
}
}
return nil
} }
// IsFlushing returns true when UpdateAddHtlc's are disabled in the direction of // IsFlushing returns true when UpdateAddHtlc's are disabled in the direction of
@ -1206,6 +1199,25 @@ func (l *channelLink) htlcManager() {
} }
} }
// If a shutdown message has previously been sent on this link, then we
// need to make sure that we have disabled any HTLC adds on the outgoing
// direction of the link and that we re-resend the same shutdown message
// that we previously sent.
l.cfg.PreviouslySentShutdown.WhenSome(func(shutdown lnwire.Shutdown) {
// Immediately disallow any new outgoing HTLCs.
if !l.DisableAdds(Outgoing) {
l.log.Warnf("Outgoing link adds already disabled")
}
// Re-send the shutdown message the peer. Since syncChanStates
// would have sent any outstanding CommitSig, it is fine for us
// to immediately queue the shutdown message now.
err := l.cfg.Peer.SendMessage(false, &shutdown)
if err != nil {
l.log.Warnf("Error sending shutdown message: %v", err)
}
})
// We've successfully reestablished the channel, mark it as such to // We've successfully reestablished the channel, mark it as such to
// allow the switch to forward HTLCs in the outbound direction. // allow the switch to forward HTLCs in the outbound direction.
l.markReestablished() l.markReestablished()

View File

@ -6969,27 +6969,22 @@ func TestLinkFlushApiDirectionIsolation(t *testing.T) {
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if prand.Uint64()%2 == 0 { if prand.Uint64()%2 == 0 {
//nolint:errcheck
aliceLink.EnableAdds(Outgoing) aliceLink.EnableAdds(Outgoing)
require.False(t, aliceLink.IsFlushing(Outgoing)) require.False(t, aliceLink.IsFlushing(Outgoing))
} else { } else {
//nolint:errcheck
aliceLink.DisableAdds(Outgoing) aliceLink.DisableAdds(Outgoing)
require.True(t, aliceLink.IsFlushing(Outgoing)) require.True(t, aliceLink.IsFlushing(Outgoing))
} }
require.False(t, aliceLink.IsFlushing(Incoming)) require.False(t, aliceLink.IsFlushing(Incoming))
} }
//nolint:errcheck
aliceLink.EnableAdds(Outgoing) aliceLink.EnableAdds(Outgoing)
for i := 0; i < 10; i++ { for i := 0; i < 10; i++ {
if prand.Uint64()%2 == 0 { if prand.Uint64()%2 == 0 {
//nolint:errcheck
aliceLink.EnableAdds(Incoming) aliceLink.EnableAdds(Incoming)
require.False(t, aliceLink.IsFlushing(Incoming)) require.False(t, aliceLink.IsFlushing(Incoming))
} else { } else {
//nolint:errcheck
aliceLink.DisableAdds(Incoming) aliceLink.DisableAdds(Incoming)
require.True(t, aliceLink.IsFlushing(Incoming)) require.True(t, aliceLink.IsFlushing(Incoming))
} }
@ -7010,16 +7005,16 @@ func TestLinkFlushApiGateStateIdempotence(t *testing.T) {
) )
for _, dir := range []LinkDirection{Incoming, Outgoing} { for _, dir := range []LinkDirection{Incoming, Outgoing} {
require.Nil(t, aliceLink.DisableAdds(dir)) require.True(t, aliceLink.DisableAdds(dir))
require.True(t, aliceLink.IsFlushing(dir)) require.True(t, aliceLink.IsFlushing(dir))
require.NotNil(t, aliceLink.DisableAdds(dir)) require.False(t, aliceLink.DisableAdds(dir))
require.True(t, aliceLink.IsFlushing(dir)) require.True(t, aliceLink.IsFlushing(dir))
require.Nil(t, aliceLink.EnableAdds(dir)) require.True(t, aliceLink.EnableAdds(dir))
require.False(t, aliceLink.IsFlushing(dir)) require.False(t, aliceLink.IsFlushing(dir))
require.NotNil(t, aliceLink.EnableAdds(dir)) require.False(t, aliceLink.EnableAdds(dir))
require.False(t, aliceLink.IsFlushing(dir)) require.False(t, aliceLink.IsFlushing(dir))
} }
} }

View File

@ -906,13 +906,14 @@ func (f *mockChannelLink) UpdateShortChanID() (lnwire.ShortChannelID, error) {
return f.shortChanID, nil return f.shortChanID, nil
} }
func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) error { func (f *mockChannelLink) EnableAdds(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement // TODO(proofofkeags): Implement
return nil return true
} }
func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) error {
func (f *mockChannelLink) DisableAdds(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement // TODO(proofofkeags): Implement
return nil return true
} }
func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool { func (f *mockChannelLink) IsFlushing(linkDirection LinkDirection) bool {
// TODO(proofofkeags): Implement // TODO(proofofkeags): Implement

View File

@ -1,24 +1,44 @@
package itest package itest
import ( import (
"testing"
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
"github.com/lightningnetwork/lnd/lnrpc/invoicesrpc" "github.com/lightningnetwork/lnd/lnrpc/invoicesrpc"
"github.com/lightningnetwork/lnd/lnrpc/routerrpc" "github.com/lightningnetwork/lnd/lnrpc/routerrpc"
"github.com/lightningnetwork/lnd/lnrpc/walletrpc"
"github.com/lightningnetwork/lnd/lntest" "github.com/lightningnetwork/lnd/lntest"
"github.com/lightningnetwork/lnd/lntest/wait"
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
// testCoopCloseWithHtlcs tests whether or not we can successfully issue a coop // testCoopCloseWithHtlcs tests whether we can successfully issue a coop close
// close request whilt there are still active htlcs on the link. Here we will // request while there are still active htlcs on the link. In all the tests, we
// set up an HODL invoice to suspend settlement. Then we will attempt to close // will set up an HODL invoice to suspend settlement. Then we will attempt to
// the channel which should appear as a noop for the time being. Then we will // close the channel which should appear as a noop for the time being. Then we
// have the receiver settle the invoice and observe that the channel gets torn // will have the receiver settle the invoice and observe that the channel gets
// down after settlement. // torn down after settlement.
func testCoopCloseWithHtlcs(ht *lntest.HarnessTest) { func testCoopCloseWithHtlcs(ht *lntest.HarnessTest) {
ht.Run("no restart", func(t *testing.T) {
tt := ht.Subtest(t)
coopCloseWithHTLCs(tt)
})
ht.Run("with restart", func(t *testing.T) {
tt := ht.Subtest(t)
coopCloseWithHTLCsWithRestart(tt)
})
}
// coopCloseWithHTLCs tests the basic coop close scenario which occurs when one
// channel party initiates a channel shutdown while an HTLC is still pending on
// the channel.
func coopCloseWithHTLCs(ht *lntest.HarnessTest) {
alice, bob := ht.Alice, ht.Bob alice, bob := ht.Alice, ht.Bob
ht.ConnectNodes(alice, bob)
// Here we set up a channel between Alice and Bob, beginning with a // Here we set up a channel between Alice and Bob, beginning with a
// balance on Bob's side. // balance on Bob's side.
@ -101,3 +121,123 @@ func testCoopCloseWithHtlcs(ht *lntest.HarnessTest) {
// Wait for it to get mined and finish tearing down. // Wait for it to get mined and finish tearing down.
ht.AssertStreamChannelCoopClosed(alice, chanPoint, false, closeClient) ht.AssertStreamChannelCoopClosed(alice, chanPoint, false, closeClient)
} }
// coopCloseWithHTLCsWithRestart also tests the coop close flow when an HTLC
// is still pending on the channel but this time it ensures that the shutdown
// process continues as expected even if a channel re-establish happens after
// one party has already initiated the shutdown.
func coopCloseWithHTLCsWithRestart(ht *lntest.HarnessTest) {
alice, bob := ht.Alice, ht.Bob
ht.ConnectNodes(alice, bob)
// Open a channel between Alice and Bob with the balance split equally.
// We do this to ensure that the close transaction will have 2 outputs
// so that we can assert that the correct delivery address gets used by
// the channel close initiator.
chanPoint := ht.OpenChannel(bob, alice, lntest.OpenChannelParams{
Amt: btcutil.Amount(1000000),
PushAmt: btcutil.Amount(1000000 / 2),
})
// Wait for Bob to understand that the channel is ready to use.
ht.AssertTopologyChannelOpen(bob, chanPoint)
// Set up a HODL invoice so that we can be sure that an HTLC is pending
// on the channel at the time that shutdown is requested.
var preimage lntypes.Preimage
copy(preimage[:], ht.Random32Bytes())
payHash := preimage.Hash()
invoiceReq := &invoicesrpc.AddHoldInvoiceRequest{
Memo: "testing close",
Value: 400,
Hash: payHash[:],
}
resp := alice.RPC.AddHoldInvoice(invoiceReq)
invoiceStream := alice.RPC.SubscribeSingleInvoice(payHash[:])
// Wait for the invoice to be ready and payable.
ht.AssertInvoiceState(invoiceStream, lnrpc.Invoice_OPEN)
// Now that the invoice is ready to be paid, let's have Bob open an HTLC
// for it.
req := &routerrpc.SendPaymentRequest{
PaymentRequest: resp.PaymentRequest,
TimeoutSeconds: 60,
FeeLimitSat: 1000000,
}
ht.SendPaymentAndAssertStatus(bob, req, lnrpc.Payment_IN_FLIGHT)
ht.AssertNumActiveHtlcs(bob, 1)
// Assert at this point that the HTLC is open but not yet settled.
ht.AssertInvoiceState(invoiceStream, lnrpc.Invoice_ACCEPTED)
// We will now let Alice initiate the closure of the channel. We will
// also let her specify a specific delivery address to be used since we
// want to test that this same address is used in the Shutdown message
// on reconnection.
newAddr := alice.RPC.NewAddress(&lnrpc.NewAddressRequest{
Type: AddrTypeWitnessPubkeyHash,
})
_ = alice.RPC.CloseChannel(&lnrpc.CloseChannelRequest{
ChannelPoint: chanPoint,
NoWait: true,
DeliveryAddress: newAddr.Address,
})
// Assert that both nodes see the channel as waiting for close.
ht.AssertChannelInactive(bob, chanPoint)
ht.AssertChannelInactive(alice, chanPoint)
// Now restart Alice and Bob.
ht.RestartNode(alice)
ht.RestartNode(bob)
ht.AssertConnected(alice, bob)
// Show that both nodes still see the channel as waiting for close after
// the restart.
ht.AssertChannelInactive(bob, chanPoint)
ht.AssertChannelInactive(alice, chanPoint)
// Settle the invoice.
alice.RPC.SettleInvoice(preimage[:])
// Wait for the channel to appear in the waiting closed list.
err := wait.Predicate(func() bool {
pendingChansResp := alice.RPC.PendingChannels()
waitingClosed := pendingChansResp.WaitingCloseChannels
return len(waitingClosed) == 1
}, defaultTimeout)
require.NoError(ht, err)
// Wait for the close tx to be in the Mempool and then mine 6 blocks
// to confirm the close.
closingTx := ht.AssertClosingTxInMempool(
chanPoint, lnrpc.CommitmentType_LEGACY,
)
ht.MineBlocksAndAssertNumTxes(6, 1)
// Finally, we inspect the closing transaction here to show that the
// delivery address that Alice specified in her original close request
// is the one that ended up being used in the final closing transaction.
tx := alice.RPC.GetTransaction(&walletrpc.GetTransactionRequest{
Txid: closingTx.TxHash().String(),
})
require.Len(ht, tx.OutputDetails, 2)
// Find Alice's output in the coop-close transaction.
var outputDetail *lnrpc.OutputDetail
for _, output := range tx.OutputDetails {
if output.IsOurAddress {
outputDetail = output
break
}
}
require.NotNil(ht, outputDetail)
// Show that the address used is the one she requested.
require.Equal(ht, outputDetail.Address, newAddr.Address)
}

View File

@ -21,13 +21,13 @@ import (
) )
var ( var (
// ErrChanAlreadyClosing is returned when a channel shutdown is attempted // ErrChanAlreadyClosing is returned when a channel shutdown is
// more than once. // attempted more than once.
ErrChanAlreadyClosing = fmt.Errorf("channel shutdown already initiated") ErrChanAlreadyClosing = fmt.Errorf("channel shutdown already initiated")
// ErrChanCloseNotFinished is returned when a caller attempts to access // ErrChanCloseNotFinished is returned when a caller attempts to access
// a field or function that is contingent on the channel closure negotiation // a field or function that is contingent on the channel closure
// already being completed. // negotiation already being completed.
ErrChanCloseNotFinished = fmt.Errorf("close negotiation not finished") ErrChanCloseNotFinished = fmt.Errorf("close negotiation not finished")
// ErrInvalidState is returned when the closing state machine receives a // ErrInvalidState is returned when the closing state machine receives a
@ -79,16 +79,16 @@ const (
// closeFeeNegotiation is the third, and most persistent state. Both // closeFeeNegotiation is the third, and most persistent state. Both
// parties enter this state after they've sent and received a shutdown // parties enter this state after they've sent and received a shutdown
// message. During this phase, both sides will send monotonically // message. During this phase, both sides will send monotonically
// increasing fee requests until one side accepts the last fee rate offered // increasing fee requests until one side accepts the last fee rate
// by the other party. In this case, the party will broadcast the closing // offered by the other party. In this case, the party will broadcast
// transaction, and send the accepted fee to the remote party. This then // the closing transaction, and send the accepted fee to the remote
// causes a shift into the closeFinished state. // party. This then causes a shift into the closeFinished state.
closeFeeNegotiation closeFeeNegotiation
// closeFinished is the final state of the state machine. In this state, a // closeFinished is the final state of the state machine. In this state,
// side has accepted a fee offer and has broadcast the valid closing // a side has accepted a fee offer and has broadcast the valid closing
// transaction to the network. During this phase, the closing transaction // transaction to the network. During this phase, the closing
// becomes available for examination. // transaction becomes available for examination.
closeFinished closeFinished
) )
@ -156,8 +156,9 @@ type ChanCloser struct {
// negotiationHeight is the height that the fee negotiation begun at. // negotiationHeight is the height that the fee negotiation begun at.
negotiationHeight uint32 negotiationHeight uint32
// closingTx is the final, fully signed closing transaction. This will only // closingTx is the final, fully signed closing transaction. This will
// be populated once the state machine shifts to the closeFinished state. // only be populated once the state machine shifts to the closeFinished
// state.
closingTx *wire.MsgTx closingTx *wire.MsgTx
// idealFeeSat is the ideal fee that the state machine should initially // idealFeeSat is the ideal fee that the state machine should initially
@ -173,22 +174,22 @@ type ChanCloser struct {
idealFeeRate chainfee.SatPerKWeight idealFeeRate chainfee.SatPerKWeight
// lastFeeProposal is the last fee that we proposed to the remote party. // lastFeeProposal is the last fee that we proposed to the remote party.
// We'll use this as a pivot point to ratchet our next offer up, down, or // We'll use this as a pivot point to ratchet our next offer up, down,
// simply accept the remote party's prior offer. // or simply accept the remote party's prior offer.
lastFeeProposal btcutil.Amount lastFeeProposal btcutil.Amount
// priorFeeOffers is a map that keeps track of all the proposed fees that // priorFeeOffers is a map that keeps track of all the proposed fees
// we've offered during the fee negotiation. We use this map to cut the // that we've offered during the fee negotiation. We use this map to cut
// negotiation early if the remote party ever sends an offer that we've // the negotiation early if the remote party ever sends an offer that
// sent in the past. Once negotiation terminates, we can extract the prior // we've sent in the past. Once negotiation terminates, we can extract
// signature of our accepted offer from this map. // the prior signature of our accepted offer from this map.
// //
// TODO(roasbeef): need to ensure if they broadcast w/ any of our prior // TODO(roasbeef): need to ensure if they broadcast w/ any of our prior
// sigs, we are aware of // sigs, we are aware of
priorFeeOffers map[btcutil.Amount]*lnwire.ClosingSigned priorFeeOffers map[btcutil.Amount]*lnwire.ClosingSigned
// closeReq is the initial closing request. This will only be populated if // closeReq is the initial closing request. This will only be populated
// we're the initiator of this closing negotiation. // if we're the initiator of this closing negotiation.
// //
// TODO(roasbeef): abstract away // TODO(roasbeef): abstract away
closeReq *htlcswitch.ChanClose closeReq *htlcswitch.ChanClose
@ -273,8 +274,10 @@ func NewChanCloser(cfg ChanCloseCfg, deliveryScript []byte,
negotiationHeight: negotiationHeight, negotiationHeight: negotiationHeight,
idealFeeRate: idealFeePerKw, idealFeeRate: idealFeePerKw,
localDeliveryScript: deliveryScript, localDeliveryScript: deliveryScript,
priorFeeOffers: make(map[btcutil.Amount]*lnwire.ClosingSigned), priorFeeOffers: make(
locallyInitiated: locallyInitiated, map[btcutil.Amount]*lnwire.ClosingSigned,
),
locallyInitiated: locallyInitiated,
} }
} }
@ -321,9 +324,9 @@ func (c *ChanCloser) initFeeBaseline() {
// initChanShutdown begins the shutdown process by un-registering the channel, // initChanShutdown begins the shutdown process by un-registering the channel,
// and creating a valid shutdown message to our target delivery address. // and creating a valid shutdown message to our target delivery address.
func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) {
// With both items constructed we'll now send the shutdown message for this // With both items constructed we'll now send the shutdown message for
// particular channel, advertising a shutdown request to our desired // this particular channel, advertising a shutdown request to our
// closing script. // desired closing script.
shutdown := lnwire.NewShutdown(c.cid, c.localDeliveryScript) shutdown := lnwire.NewShutdown(c.cid, c.localDeliveryScript)
// If this is a taproot channel, then we'll need to also generate a // If this is a taproot channel, then we'll need to also generate a
@ -353,6 +356,17 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) {
chancloserLog.Infof("ChannelPoint(%v): sending shutdown message", chancloserLog.Infof("ChannelPoint(%v): sending shutdown message",
c.chanPoint) c.chanPoint)
// At this point, we persist any relevant info regarding the Shutdown
// message we are about to send in order to ensure that if a
// re-establish occurs then we will re-send the same Shutdown message.
shutdownInfo := channeldb.NewShutdownInfo(
c.localDeliveryScript, c.locallyInitiated,
)
err := c.cfg.Channel.MarkShutdownSent(shutdownInfo)
if err != nil {
return nil, err
}
return shutdown, nil return shutdown, nil
} }
@ -375,12 +389,12 @@ func (c *ChanCloser) ShutdownChan() (*lnwire.Shutdown, error) {
} }
// With the opening steps complete, we'll transition into the // With the opening steps complete, we'll transition into the
// closeShutdownInitiated state. In this state, we'll wait until the other // closeShutdownInitiated state. In this state, we'll wait until the
// party sends their version of the shutdown message. // other party sends their version of the shutdown message.
c.state = closeShutdownInitiated c.state = closeShutdownInitiated
// Finally, we'll return the shutdown message to the caller so it can send // Finally, we'll return the shutdown message to the caller so it can
// it to the remote peer. // send it to the remote peer.
return shutdownMsg, nil return shutdownMsg, nil
} }
@ -476,9 +490,8 @@ func validateShutdownScript(disconnect func() error, upfrontScript,
// If appropriate, it will also generate a Shutdown message of its own to send // If appropriate, it will also generate a Shutdown message of its own to send
// out to the peer. It is possible for this method to return None when no error // out to the peer. It is possible for this method to return None when no error
// occurred. // occurred.
func (c *ChanCloser) ReceiveShutdown( func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) (
msg lnwire.Shutdown, fn.Option[lnwire.Shutdown], error) {
) (fn.Option[lnwire.Shutdown], error) {
noShutdown := fn.None[lnwire.Shutdown]() noShutdown := fn.None[lnwire.Shutdown]()
@ -610,9 +623,8 @@ func (c *ChanCloser) ReceiveShutdown(
// it will not. In either case it will transition the ChanCloser state machine // it will not. In either case it will transition the ChanCloser state machine
// to the negotiation phase wherein ClosingSigned messages are exchanged until // to the negotiation phase wherein ClosingSigned messages are exchanged until
// a mutually agreeable result is achieved. // a mutually agreeable result is achieved.
func (c *ChanCloser) BeginNegotiation() ( func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned],
fn.Option[lnwire.ClosingSigned], error, error) {
) {
noClosingSigned := fn.None[lnwire.ClosingSigned]() noClosingSigned := fn.None[lnwire.ClosingSigned]()
@ -673,11 +685,8 @@ func (c *ChanCloser) BeginNegotiation() (
// ReceiveClosingSigned is a method that should be called whenever we receive a // ReceiveClosingSigned is a method that should be called whenever we receive a
// ClosingSigned message from the wire. It may or may not return a ClosingSigned // ClosingSigned message from the wire. It may or may not return a ClosingSigned
// of our own to send back to the remote. // of our own to send back to the remote.
// func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) (
//nolint:funlen fn.Option[lnwire.ClosingSigned], error) {
func (c *ChanCloser) ReceiveClosingSigned(
msg lnwire.ClosingSigned,
) (fn.Option[lnwire.ClosingSigned], error) {
noClosing := fn.None[lnwire.ClosingSigned]() noClosing := fn.None[lnwire.ClosingSigned]()
@ -882,7 +891,9 @@ func (c *ChanCloser) ReceiveClosingSigned(
// proposeCloseSigned attempts to propose a new signature for the closing // proposeCloseSigned attempts to propose a new signature for the closing
// transaction for a channel based on the prior fee negotiations and our current // transaction for a channel based on the prior fee negotiations and our current
// compromise fee. // compromise fee.
func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (*lnwire.ClosingSigned, error) { func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (
*lnwire.ClosingSigned, error) {
var ( var (
closeOpts []lnwallet.ChanCloseOpt closeOpts []lnwallet.ChanCloseOpt
err error err error
@ -956,8 +967,8 @@ func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (*lnwire.ClosingSign
// compromise and to ensure that the fee negotiation has a stopping point. We // compromise and to ensure that the fee negotiation has a stopping point. We
// consider their fee acceptable if it's within 30% of our fee. // consider their fee acceptable if it's within 30% of our fee.
func feeInAcceptableRange(localFee, remoteFee btcutil.Amount) bool { func feeInAcceptableRange(localFee, remoteFee btcutil.Amount) bool {
// If our offer is lower than theirs, then we'll accept their offer if it's // If our offer is lower than theirs, then we'll accept their offer if
// no more than 30% *greater* than our current offer. // it's no more than 30% *greater* than our current offer.
if localFee < remoteFee { if localFee < remoteFee {
acceptableRange := localFee + ((localFee * 3) / 10) acceptableRange := localFee + ((localFee * 3) / 10)
return remoteFee <= acceptableRange return remoteFee <= acceptableRange
@ -991,51 +1002,59 @@ func calcCompromiseFee(chanPoint wire.OutPoint, ourIdealFee, lastSentFee,
// TODO(roasbeef): take in number of rounds as well? // TODO(roasbeef): take in number of rounds as well?
chancloserLog.Infof("ChannelPoint(%v): computing fee compromise, ideal="+ chancloserLog.Infof("ChannelPoint(%v): computing fee compromise, "+
"%v, last_sent=%v, remote_offer=%v", chanPoint, int64(ourIdealFee), "ideal=%v, last_sent=%v, remote_offer=%v", chanPoint,
int64(lastSentFee), int64(remoteFee)) int64(ourIdealFee), int64(lastSentFee), int64(remoteFee))
// Otherwise, we'll need to attempt to make a fee compromise if this is the // Otherwise, we'll need to attempt to make a fee compromise if this is
// second round, and neither side has agreed on fees. // the second round, and neither side has agreed on fees.
switch { switch {
// If their proposed fee is identical to our ideal fee, then we'll go with // If their proposed fee is identical to our ideal fee, then we'll go
// that as we can short circuit the fee negotiation. Similarly, if we // with that as we can short circuit the fee negotiation. Similarly, if
// haven't sent an offer yet, we'll default to our ideal fee. // we haven't sent an offer yet, we'll default to our ideal fee.
case ourIdealFee == remoteFee || lastSentFee == 0: case ourIdealFee == remoteFee || lastSentFee == 0:
return ourIdealFee return ourIdealFee
// If the last fee we sent, is equal to the fee the remote party is // If the last fee we sent, is equal to the fee the remote party is
// offering, then we can simply return this fee as the negotiation is over. // offering, then we can simply return this fee as the negotiation is
// over.
case remoteFee == lastSentFee: case remoteFee == lastSentFee:
return lastSentFee return lastSentFee
// If the fee the remote party is offering is less than the last one we // If the fee the remote party is offering is less than the last one we
// sent, then we'll need to ratchet down in order to move our offer closer // sent, then we'll need to ratchet down in order to move our offer
// to theirs. // closer to theirs.
case remoteFee < lastSentFee: case remoteFee < lastSentFee:
// If the fee is lower, but still acceptable, then we'll just return // If the fee is lower, but still acceptable, then we'll just
// this fee and end the negotiation. // return this fee and end the negotiation.
if feeInAcceptableRange(lastSentFee, remoteFee) { if feeInAcceptableRange(lastSentFee, remoteFee) {
chancloserLog.Infof("ChannelPoint(%v): proposed remote fee is "+ chancloserLog.Infof("ChannelPoint(%v): proposed "+
"close enough, capitulating", chanPoint) "remote fee is close enough, capitulating",
chanPoint)
return remoteFee return remoteFee
} }
// Otherwise, we'll ratchet the fee *down* using our current algorithm. // Otherwise, we'll ratchet the fee *down* using our current
// algorithm.
return ratchetFee(lastSentFee, false) return ratchetFee(lastSentFee, false)
// If the fee the remote party is offering is greater than the last one we // If the fee the remote party is offering is greater than the last one
// sent, then we'll ratchet up in order to ensure we terminate eventually. // we sent, then we'll ratchet up in order to ensure we terminate
// eventually.
case remoteFee > lastSentFee: case remoteFee > lastSentFee:
// If the fee is greater, but still acceptable, then we'll just return // If the fee is greater, but still acceptable, then we'll just
// this fee in order to put an end to the negotiation. // return this fee in order to put an end to the negotiation.
if feeInAcceptableRange(lastSentFee, remoteFee) { if feeInAcceptableRange(lastSentFee, remoteFee) {
chancloserLog.Infof("ChannelPoint(%v): proposed remote fee is "+ chancloserLog.Infof("ChannelPoint(%v): proposed "+
"close enough, capitulating", chanPoint) "remote fee is close enough, capitulating",
chanPoint)
return remoteFee return remoteFee
} }
// Otherwise, we'll ratchet the fee up using our current algorithm. // Otherwise, we'll ratchet the fee up using our current
// algorithm.
return ratchetFee(lastSentFee, true) return ratchetFee(lastSentFee, true)
default: default:

View File

@ -154,6 +154,10 @@ func (m *mockChannel) MarkCoopBroadcasted(*wire.MsgTx, bool) error {
return nil return nil
} }
func (m *mockChannel) MarkShutdownSent(*channeldb.ShutdownInfo) error {
return nil
}
func (m *mockChannel) IsInitiator() bool { func (m *mockChannel) IsInitiator() bool {
return m.initiator return m.initiator
} }

View File

@ -35,6 +35,11 @@ type Channel interface { //nolint:interfacebloat
// transaction has been broadcast. // transaction has been broadcast.
MarkCoopBroadcasted(*wire.MsgTx, bool) error MarkCoopBroadcasted(*wire.MsgTx, bool) error
// MarkShutdownSent persists the given ShutdownInfo. The existence of
// the ShutdownInfo represents the fact that the Shutdown message has
// been sent by us and so should be re-sent on re-establish.
MarkShutdownSent(info *channeldb.ShutdownInfo) error
// IsInitiator returns true we are the initiator of the channel. // IsInitiator returns true we are the initiator of the channel.
IsInitiator() bool IsInitiator() bool

View File

@ -8823,6 +8823,18 @@ func (lc *LightningChannel) MarkCoopBroadcasted(tx *wire.MsgTx,
return lc.channelState.MarkCoopBroadcasted(tx, localInitiated) return lc.channelState.MarkCoopBroadcasted(tx, localInitiated)
} }
// MarkShutdownSent persists the given ShutdownInfo. The existence of the
// ShutdownInfo represents the fact that the Shutdown message has been sent by
// us and so should be re-sent on re-establish.
func (lc *LightningChannel) MarkShutdownSent(
info *channeldb.ShutdownInfo) error {
lc.Lock()
defer lc.Unlock()
return lc.channelState.MarkShutdownSent(info)
}
// MarkDataLoss marks sets the channel status to LocalDataLoss and stores the // MarkDataLoss marks sets the channel status to LocalDataLoss and stores the
// passed commitPoint for use to retrieve funds in case the remote force closes // passed commitPoint for use to retrieve funds in case the remote force closes
// the channel. // the channel.

View File

@ -27,6 +27,7 @@ import (
"github.com/lightningnetwork/lnd/contractcourt" "github.com/lightningnetwork/lnd/contractcourt"
"github.com/lightningnetwork/lnd/discovery" "github.com/lightningnetwork/lnd/discovery"
"github.com/lightningnetwork/lnd/feature" "github.com/lightningnetwork/lnd/feature"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/funding" "github.com/lightningnetwork/lnd/funding"
"github.com/lightningnetwork/lnd/htlcswitch" "github.com/lightningnetwork/lnd/htlcswitch"
"github.com/lightningnetwork/lnd/htlcswitch/hodl" "github.com/lightningnetwork/lnd/htlcswitch/hodl"
@ -975,17 +976,70 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
spew.Sdump(forwardingPolicy)) spew.Sdump(forwardingPolicy))
// If the channel is pending, set the value to nil in the // If the channel is pending, set the value to nil in the
// activeChannels map. This is done to signify that the channel is // activeChannels map. This is done to signify that the channel
// pending. We don't add the link to the switch here - it's the funding // is pending. We don't add the link to the switch here - it's
// manager's responsibility to spin up pending channels. Adding them // the funding manager's responsibility to spin up pending
// here would just be extra work as we'll tear them down when creating // channels. Adding them here would just be extra work as we'll
// + adding the final link. // tear them down when creating + adding the final link.
if lnChan.IsPending() { if lnChan.IsPending() {
p.activeChannels.Store(chanID, nil) p.activeChannels.Store(chanID, nil)
continue continue
} }
shutdownInfo, err := lnChan.State().ShutdownInfo()
if err != nil && !errors.Is(err, channeldb.ErrNoShutdownInfo) {
return nil, err
}
var (
shutdownMsg fn.Option[lnwire.Shutdown]
shutdownInfoErr error
)
shutdownInfo.WhenSome(func(info channeldb.ShutdownInfo) {
// Compute an ideal fee.
feePerKw, err := p.cfg.FeeEstimator.EstimateFeePerKW(
p.cfg.CoopCloseTargetConfs,
)
if err != nil {
shutdownInfoErr = fmt.Errorf("unable to "+
"estimate fee: %w", err)
return
}
chanCloser, err := p.createChanCloser(
lnChan, info.DeliveryScript.Val, feePerKw, nil,
info.LocalInitiator.Val,
)
if err != nil {
shutdownInfoErr = fmt.Errorf("unable to "+
"create chan closer: %w", err)
return
}
chanID := lnwire.NewChanIDFromOutPoint(
&lnChan.State().FundingOutpoint,
)
p.activeChanCloses[chanID] = chanCloser
// Create the Shutdown message.
shutdown, err := chanCloser.ShutdownChan()
if err != nil {
delete(p.activeChanCloses, chanID)
shutdownInfoErr = err
return
}
shutdownMsg = fn.Some[lnwire.Shutdown](*shutdown)
})
if shutdownInfoErr != nil {
return nil, shutdownInfoErr
}
// Subscribe to the set of on-chain events for this channel. // Subscribe to the set of on-chain events for this channel.
chainEvents, err := p.cfg.ChainArb.SubscribeChannelEvents( chainEvents, err := p.cfg.ChainArb.SubscribeChannelEvents(
*chanPoint, *chanPoint,
@ -996,7 +1050,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) (
err = p.addLink( err = p.addLink(
chanPoint, lnChan, forwardingPolicy, chainEvents, chanPoint, lnChan, forwardingPolicy, chainEvents,
true, true, shutdownMsg,
) )
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to add link %v to "+ return nil, fmt.Errorf("unable to add link %v to "+
@ -1014,7 +1068,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
lnChan *lnwallet.LightningChannel, lnChan *lnwallet.LightningChannel,
forwardingPolicy *models.ForwardingPolicy, forwardingPolicy *models.ForwardingPolicy,
chainEvents *contractcourt.ChainEventSubscription, chainEvents *contractcourt.ChainEventSubscription,
syncStates bool) error { syncStates bool, shutdownMsg fn.Option[lnwire.Shutdown]) error {
// onChannelFailure will be called by the link in case the channel // onChannelFailure will be called by the link in case the channel
// fails for some reason. // fails for some reason.
@ -1083,6 +1137,7 @@ func (p *Brontide) addLink(chanPoint *wire.OutPoint,
NotifyInactiveLinkEvent: p.cfg.ChannelNotifier.NotifyInactiveLinkEvent, NotifyInactiveLinkEvent: p.cfg.ChannelNotifier.NotifyInactiveLinkEvent,
HtlcNotifier: p.cfg.HtlcNotifier, HtlcNotifier: p.cfg.HtlcNotifier,
GetAliases: p.cfg.GetAliases, GetAliases: p.cfg.GetAliases,
PreviouslySentShutdown: shutdownMsg,
} }
// Before adding our new link, purge the switch of any pending or live // Before adding our new link, purge the switch of any pending or live
@ -2802,15 +2857,32 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) (
return nil, nil return nil, nil
} }
// As mentioned above, we don't re-create the delivery script. var deliveryScript []byte
deliveryScript := c.LocalShutdownScript
if len(deliveryScript) == 0 { shutdownInfo, err := c.ShutdownInfo()
var err error switch {
deliveryScript, err = p.genDeliveryScript() // We have previously stored the delivery script that we need to use
if err != nil { // in the shutdown message. Re-use this script.
p.log.Errorf("unable to gen delivery script: %v", case err == nil:
err) shutdownInfo.WhenSome(func(info channeldb.ShutdownInfo) {
return nil, fmt.Errorf("close addr unavailable") deliveryScript = info.DeliveryScript.Val
})
// An error other than ErrNoShutdownInfo was returned
case err != nil && !errors.Is(err, channeldb.ErrNoShutdownInfo):
return nil, err
case errors.Is(err, channeldb.ErrNoShutdownInfo):
deliveryScript = c.LocalShutdownScript
if len(deliveryScript) == 0 {
var err error
deliveryScript, err = p.genDeliveryScript()
if err != nil {
p.log.Errorf("unable to gen delivery script: "+
"%v", err)
return nil, fmt.Errorf("close addr unavailable")
}
} }
} }
@ -2990,13 +3062,12 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) {
return return
} }
link.OnCommitOnce(htlcswitch.Outgoing, func() { if !link.DisableAdds(htlcswitch.Outgoing) {
err := link.DisableAdds(htlcswitch.Outgoing) p.log.Warnf("Outgoing link adds already "+
if err != nil { "disabled: %v", link.ChanID())
p.log.Warnf("outgoing link adds already "+ }
"disabled: %v", link.ChanID())
}
link.OnCommitOnce(htlcswitch.Outgoing, func() {
p.queueMsg(shutdownMsg, nil) p.queueMsg(shutdownMsg, nil)
}) })
@ -3619,12 +3690,9 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) {
switch typed := msg.msg.(type) { switch typed := msg.msg.(type) {
case *lnwire.Shutdown: case *lnwire.Shutdown:
// Disable incoming adds immediately. // Disable incoming adds immediately.
if link != nil { if link != nil && !link.DisableAdds(htlcswitch.Incoming) {
err := link.DisableAdds(htlcswitch.Incoming) p.log.Warnf("Incoming link adds already disabled: %v",
if err != nil { link.ChanID())
p.log.Warnf("incoming link adds already "+
"disabled: %v", link.ChanID())
}
} }
oShutdown, err := chanCloser.ReceiveShutdown(*typed) oShutdown, err := chanCloser.ReceiveShutdown(*typed)
@ -3634,7 +3702,7 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) {
} }
oShutdown.WhenSome(func(msg lnwire.Shutdown) { oShutdown.WhenSome(func(msg lnwire.Shutdown) {
// if the link is nil it means we can immediately queue // If the link is nil it means we can immediately queue
// the Shutdown message since we don't have to wait for // the Shutdown message since we don't have to wait for
// commitment transaction synchronization. // commitment transaction synchronization.
if link == nil { if link == nil {
@ -3642,14 +3710,17 @@ func (p *Brontide) handleCloseMsg(msg *closeMsg) {
return return
} }
// Immediately disallow any new HTLC's from being added
// in the outgoing direction.
if !link.DisableAdds(htlcswitch.Outgoing) {
p.log.Warnf("Outgoing link adds already "+
"disabled: %v", link.ChanID())
}
// When we have a Shutdown to send, we defer it till the // When we have a Shutdown to send, we defer it till the
// next time we send a CommitSig to remain spec // next time we send a CommitSig to remain spec
// compliant. // compliant.
link.OnCommitOnce(htlcswitch.Outgoing, func() { link.OnCommitOnce(htlcswitch.Outgoing, func() {
err := link.DisableAdds(htlcswitch.Outgoing)
if err != nil {
p.log.Warn(err.Error())
}
p.queueMsg(&msg, nil) p.queueMsg(&msg, nil)
}) })
}) })
@ -3906,7 +3977,7 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error {
// Create the link and add it to the switch. // Create the link and add it to the switch.
err = p.addLink( err = p.addLink(
chanPoint, lnChan, initialPolicy, chainEvents, chanPoint, lnChan, initialPolicy, chainEvents,
shouldReestablish, shouldReestablish, fn.None[lnwire.Shutdown](),
) )
if err != nil { if err != nil {
return fmt.Errorf("can't register new channel link(%v) with "+ return fmt.Errorf("can't register new channel link(%v) with "+

View File

@ -4,7 +4,6 @@ import (
"bytes" "bytes"
crand "crypto/rand" crand "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt"
"io" "io"
"math/rand" "math/rand"
"net" "net"
@ -510,34 +509,20 @@ type mockMessageConn struct {
readRaceDetectingCounter int readRaceDetectingCounter int
} }
func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) error { func (m *mockUpdateHandler) EnableAdds(dir htlcswitch.LinkDirection) bool {
switch dir { if dir == htlcswitch.Outgoing {
case htlcswitch.Outgoing: return m.isOutgoingAddBlocked.Swap(false)
if !m.isOutgoingAddBlocked.Swap(false) {
return fmt.Errorf("%v adds already enabled", dir)
}
case htlcswitch.Incoming:
if !m.isIncomingAddBlocked.Swap(false) {
return fmt.Errorf("%v adds already enabled", dir)
}
} }
return nil return m.isIncomingAddBlocked.Swap(false)
} }
func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) error { func (m *mockUpdateHandler) DisableAdds(dir htlcswitch.LinkDirection) bool {
switch dir { if dir == htlcswitch.Outgoing {
case htlcswitch.Outgoing: return !m.isOutgoingAddBlocked.Swap(true)
if m.isOutgoingAddBlocked.Swap(true) {
return fmt.Errorf("%v adds already disabled", dir)
}
case htlcswitch.Incoming:
if m.isIncomingAddBlocked.Swap(true) {
return fmt.Errorf("%v adds already disabled", dir)
}
} }
return nil return !m.isIncomingAddBlocked.Swap(true)
} }
func (m *mockUpdateHandler) IsFlushing(dir htlcswitch.LinkDirection) bool { func (m *mockUpdateHandler) IsFlushing(dir htlcswitch.LinkDirection) bool {