mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-26 01:33:02 +01:00
htlcswitch+lntest: modify Switch to persist resolution messages
Include unit tests for the Switch, and integration tests that exercise the persistence logic.
This commit is contained in:
parent
bfed7a088f
commit
f7b3da4bb2
@ -1092,8 +1092,6 @@ func (c *ChannelArbitrator) stateStep(
|
||||
if len(pktsToSend) != 0 {
|
||||
err := c.cfg.DeliverResolutionMsg(pktsToSend...)
|
||||
if err != nil {
|
||||
// TODO(roasbeef): make sure packet sends are
|
||||
// idempotent
|
||||
log.Errorf("unable to send pkts: %v", err)
|
||||
return StateError, closeTx, err
|
||||
}
|
||||
|
@ -213,6 +213,10 @@ type CircuitMapConfig struct {
|
||||
// ExtractErrorEncrypter derives the shared secret used to encrypt
|
||||
// errors from the obfuscator's ephemeral public key.
|
||||
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
|
||||
|
||||
// CheckResolutionMsg checks whether a given resolution message exists
|
||||
// for the passed CircuitKey.
|
||||
CheckResolutionMsg func(outKey *CircuitKey) error
|
||||
}
|
||||
|
||||
// NewCircuitMap creates a new instance of the circuitMap.
|
||||
@ -400,7 +404,19 @@ func (cm *circuitMap) cleanClosedChannels() error {
|
||||
// Check if the outgoing channel ID can be found in the
|
||||
// closed channel ID map. Notice that we need to store
|
||||
// the outgoing key because it's used for db query.
|
||||
//
|
||||
// NOTE: We skip this if a resolution message can be
|
||||
// found under the outKey. This means that there is an
|
||||
// existing resolution message(s) that need to get to
|
||||
// the incoming links.
|
||||
if isClosedChannel(outKey.ChanID) {
|
||||
// Check the resolution message store. A return
|
||||
// value of nil means we need to skip deleting
|
||||
// these circuits.
|
||||
if cm.cfg.CheckResolutionMsg(&outKey) == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
keystoneKeySet[outKey] = struct{}{}
|
||||
|
||||
// Also update circuitKeySet to mark the
|
||||
|
@ -66,6 +66,11 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
||||
chanParams []closeChannelParams
|
||||
deleted []htlcswitch.Keystone
|
||||
untouched []htlcswitch.Keystone
|
||||
|
||||
// If resMsg is true, then closed channels will not delete
|
||||
// circuits if the channel was the keystone / outgoing key in
|
||||
// the open circuit.
|
||||
resMsg bool
|
||||
}{
|
||||
{
|
||||
name: "no deletion if there are no closed channels",
|
||||
@ -120,7 +125,7 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
||||
{InKey: inKey20, OutKey: outKey20},
|
||||
},
|
||||
deleted: []htlcswitch.Keystone{
|
||||
{InKey: inKey00}, {InKey: inKey11},
|
||||
{InKey: inKey10}, {InKey: inKey11},
|
||||
},
|
||||
untouched: []htlcswitch.Keystone{
|
||||
{InKey: inKey20, OutKey: outKey20},
|
||||
@ -214,13 +219,33 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
||||
{InKey: inKey22, OutKey: outKey20},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "don't delete circuits for outgoing",
|
||||
chanParams: []closeChannelParams{
|
||||
// Creates a close channel with chanID1.
|
||||
{chanID: chanID1, isPending: false},
|
||||
},
|
||||
keystones: []htlcswitch.Keystone{
|
||||
// Creates a circuit and a keystone
|
||||
{InKey: inKey10, OutKey: outKey10},
|
||||
// Creates a circuit and a keystone
|
||||
{InKey: inKey11, OutKey: outKey20},
|
||||
// Creates a circuit and a keystone
|
||||
{InKey: inKey00, OutKey: outKey11},
|
||||
},
|
||||
deleted: []htlcswitch.Keystone{
|
||||
{InKey: inKey10, OutKey: outKey10},
|
||||
{InKey: inKey11, OutKey: outKey20},
|
||||
},
|
||||
resMsg: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range testParams {
|
||||
test := tt
|
||||
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, test.resMsg)
|
||||
|
||||
// create test circuits
|
||||
for _, ks := range test.keystones {
|
||||
|
@ -2,6 +2,7 @@ package htlcswitch_test
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
@ -97,8 +98,9 @@ func newOnionProcessor(t *testing.T) *hop.OnionProcessor {
|
||||
}
|
||||
|
||||
// newCircuitMap creates a new htlcswitch.CircuitMap using a temp db and a
|
||||
// fresh sphinx router.
|
||||
func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
||||
// fresh sphinx router. When resMsg is set to true, CheckResolutionMsg will
|
||||
// always return nil. Otherwise it will always return an error.
|
||||
func newCircuitMap(t *testing.T, resMsg bool) (*htlcswitch.CircuitMapConfig,
|
||||
htlcswitch.CircuitMap) {
|
||||
|
||||
onionProcessor := newOnionProcessor(t)
|
||||
@ -111,6 +113,18 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
||||
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
|
||||
}
|
||||
|
||||
if resMsg {
|
||||
checkRes := func(out *htlcswitch.CircuitKey) error {
|
||||
return nil
|
||||
}
|
||||
circuitMapCfg.CheckResolutionMsg = checkRes
|
||||
} else {
|
||||
checkRes := func(out *htlcswitch.CircuitKey) error {
|
||||
return fmt.Errorf("not found")
|
||||
}
|
||||
circuitMapCfg.CheckResolutionMsg = checkRes
|
||||
}
|
||||
|
||||
circuitMap, err := htlcswitch.NewCircuitMap(circuitMapCfg)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to create persistent circuit map: %v", err)
|
||||
@ -124,7 +138,7 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
||||
func TestCircuitMapInit(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cfg, _ := newCircuitMap(t)
|
||||
cfg, _ := newCircuitMap(t, false)
|
||||
restartCircuitMap(t, cfg)
|
||||
}
|
||||
|
||||
@ -231,7 +245,7 @@ func TestCircuitMapPersistence(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{
|
||||
ChanID: chan1,
|
||||
@ -649,6 +663,7 @@ func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
|
||||
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
|
||||
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
|
||||
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
||||
CheckResolutionMsg: cfg.CheckResolutionMsg,
|
||||
}
|
||||
cm2, err := htlcswitch.NewCircuitMap(cfg2)
|
||||
if err != nil {
|
||||
@ -671,7 +686,7 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
@ -767,7 +782,7 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
@ -973,7 +988,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
const nCircuits = 10
|
||||
const firstTrimIndex = 7
|
||||
@ -1122,7 +1137,7 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
@ -1215,7 +1230,7 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
@ -1272,7 +1287,7 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
@ -1331,7 +1346,7 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
|
||||
err error
|
||||
)
|
||||
|
||||
cfg, circuitMap := newCircuitMap(t)
|
||||
cfg, circuitMap := newCircuitMap(t, false)
|
||||
|
||||
circuit := &htlcswitch.PaymentCircuit{
|
||||
Incoming: htlcswitch.CircuitKey{
|
||||
|
@ -297,15 +297,23 @@ type Switch struct {
|
||||
// ack in the forwarding package of the outgoing link. This was added to
|
||||
// make pipelining settles more efficient.
|
||||
pendingSettleFails []channeldb.SettleFailRef
|
||||
|
||||
// resMsgStore is used to store the set of ResolutionMsg that come from
|
||||
// contractcourt. This is used so the Switch can properly forward them,
|
||||
// even on restarts.
|
||||
resMsgStore *resolutionStore
|
||||
}
|
||||
|
||||
// New creates the new instance of htlc switch.
|
||||
func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
||||
resStore := newResolutionStore(cfg.DB)
|
||||
|
||||
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
|
||||
DB: cfg.DB,
|
||||
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
|
||||
FetchClosedChannels: cfg.FetchClosedChannels,
|
||||
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
||||
CheckResolutionMsg: resStore.checkResolutionMsg,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -323,6 +331,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
||||
htlcPlex: make(chan *plexPacket),
|
||||
chanCloseRequests: make(chan *ChanClose),
|
||||
resolutionMsgs: make(chan *resolutionMsg),
|
||||
resMsgStore: resStore,
|
||||
quit: make(chan struct{}),
|
||||
}
|
||||
|
||||
@ -342,7 +351,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
||||
type resolutionMsg struct {
|
||||
contractcourt.ResolutionMsg
|
||||
|
||||
doneChan chan struct{}
|
||||
errChan chan error
|
||||
}
|
||||
|
||||
// ProcessContractResolution is called by active contract resolvers once a
|
||||
@ -351,25 +360,23 @@ type resolutionMsg struct {
|
||||
// didn't need to go to the chain in order to fulfill a contract. We'll process
|
||||
// this message just as if it came from an active outgoing channel.
|
||||
func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) error {
|
||||
|
||||
done := make(chan struct{})
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
select {
|
||||
case s.resolutionMsgs <- &resolutionMsg{
|
||||
ResolutionMsg: msg,
|
||||
doneChan: done,
|
||||
errChan: errChan,
|
||||
}:
|
||||
case <-s.quit:
|
||||
return ErrSwitchExiting
|
||||
}
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
case err := <-errChan:
|
||||
return err
|
||||
case <-s.quit:
|
||||
return ErrSwitchExiting
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetPaymentResult returns the the result of the payment attempt with the
|
||||
@ -1678,6 +1685,28 @@ out:
|
||||
go s.cfg.LocalChannelClose(peerPub[:], req)
|
||||
|
||||
case resolutionMsg := <-s.resolutionMsgs:
|
||||
// We'll persist the resolution message to the Switch's
|
||||
// resolution store.
|
||||
resMsg := resolutionMsg.ResolutionMsg
|
||||
err := s.resMsgStore.addResolutionMsg(&resMsg)
|
||||
if err != nil {
|
||||
// This will only fail if there is a database
|
||||
// error or a serialization error. Sending the
|
||||
// error prevents the contractcourt from being
|
||||
// in a state where it believes the send was
|
||||
// successful, when it wasn't.
|
||||
log.Errorf("unable to add resolution msg: %v",
|
||||
err)
|
||||
resolutionMsg.errChan <- err
|
||||
continue
|
||||
}
|
||||
|
||||
// At this point, the resolution message has been
|
||||
// persisted. It is safe to signal success by sending
|
||||
// a nil error since the Switch will re-deliver the
|
||||
// resolution message on restart.
|
||||
resolutionMsg.errChan <- nil
|
||||
|
||||
pkt := &htlcPacket{
|
||||
outgoingChanID: resolutionMsg.SourceChan,
|
||||
outgoingHTLCID: resolutionMsg.HtlcIndex,
|
||||
@ -1703,14 +1732,11 @@ out:
|
||||
// encounter is due to the circuit already being
|
||||
// closed. This is fine, as processing this message is
|
||||
// meant to be idempotent.
|
||||
err := s.handlePacketForward(pkt)
|
||||
err = s.handlePacketForward(pkt)
|
||||
if err != nil {
|
||||
log.Errorf("Unable to forward resolution msg: %v", err)
|
||||
}
|
||||
|
||||
// With the message processed, we'll now close out
|
||||
close(resolutionMsg.doneChan)
|
||||
|
||||
// A new packet has arrived for forwarding, we'll interpret the
|
||||
// packet concretely, then either forward it along, or
|
||||
// interpret a return packet to a locally initialized one.
|
||||
@ -1863,6 +1889,72 @@ func (s *Switch) Start() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := s.reforwardResolutions(); err != nil {
|
||||
// We are already stopping so we can ignore the error.
|
||||
_ = s.Stop()
|
||||
log.Errorf("unable to reforward resolutions: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// reforwardResolutions fetches the set of resolution messages stored on-disk
|
||||
// and reforwards them if their circuits are still open. If the circuits have
|
||||
// been deleted, then we will delete the resolution message from the database.
|
||||
func (s *Switch) reforwardResolutions() error {
|
||||
// Fetch all stored resolution messages, deleting the ones that are
|
||||
// resolved.
|
||||
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switchPackets := make([]*htlcPacket, 0, len(resMsgs))
|
||||
for _, resMsg := range resMsgs {
|
||||
// If the open circuit no longer exists, then we can remove the
|
||||
// message from the store.
|
||||
outKey := CircuitKey{
|
||||
ChanID: resMsg.SourceChan,
|
||||
HtlcID: resMsg.HtlcIndex,
|
||||
}
|
||||
|
||||
if s.circuits.LookupOpenCircuit(outKey) == nil {
|
||||
// The open circuit doesn't exist.
|
||||
err := s.resMsgStore.deleteResolutionMsg(&outKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// The circuit is still open, so we can assume that the link or
|
||||
// switch (if we are the source) hasn't cleaned it up yet.
|
||||
resPkt := &htlcPacket{
|
||||
outgoingChanID: resMsg.SourceChan,
|
||||
outgoingHTLCID: resMsg.HtlcIndex,
|
||||
isResolution: true,
|
||||
}
|
||||
|
||||
if resMsg.Failure != nil {
|
||||
resPkt.htlc = &lnwire.UpdateFailHTLC{}
|
||||
} else {
|
||||
resPkt.htlc = &lnwire.UpdateFulfillHTLC{
|
||||
PaymentPreimage: *resMsg.PreImage,
|
||||
}
|
||||
}
|
||||
|
||||
switchPackets = append(switchPackets, resPkt)
|
||||
}
|
||||
|
||||
// We'll now dispatch the set of resolution messages to the proper
|
||||
// destination. An error is only encountered here if the switch is
|
||||
// shutting down.
|
||||
if err := s.ForwardPackets(nil, switchPackets...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -15,6 +15,7 @@ import (
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
"github.com/lightningnetwork/lnd/channeldb"
|
||||
"github.com/lightningnetwork/lnd/contractcourt"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
|
||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
@ -3923,3 +3924,144 @@ func TestSwitchMailboxDust(t *testing.T) {
|
||||
t.Fatal("no timely reply from switch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSwitchResolution checks the ability of the switch to persist and handle
|
||||
// resolution messages.
|
||||
func TestSwitchResolution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
alicePeer, err := newMockServer(
|
||||
t, "alice", testStartingHeight, nil, testDefaultDelta,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
bobPeer, err := newMockServer(
|
||||
t, "bob", testStartingHeight, nil, testDefaultDelta,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err := initSwitchWithDB(testStartingHeight, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
|
||||
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
|
||||
|
||||
aliceChannelLink := newMockChannelLink(
|
||||
s, chanID1, aliceChanID, alicePeer, true,
|
||||
)
|
||||
bobChannelLink := newMockChannelLink(
|
||||
s, chanID2, bobChanID, bobPeer, true,
|
||||
)
|
||||
err = s.AddLink(aliceChannelLink)
|
||||
require.NoError(t, err)
|
||||
err = s.AddLink(bobChannelLink)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create an add htlcPacket that Alice will send to Bob.
|
||||
preimage, err := genPreimage()
|
||||
require.NoError(t, err)
|
||||
|
||||
rhash := sha256.Sum256(preimage[:])
|
||||
packet := &htlcPacket{
|
||||
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||
incomingHTLCID: 0,
|
||||
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||
obfuscator: NewMockObfuscator(),
|
||||
htlc: &lnwire.UpdateAddHTLC{
|
||||
PaymentHash: rhash,
|
||||
Amount: 1,
|
||||
},
|
||||
}
|
||||
|
||||
err = s.ForwardPackets(nil, packet)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Bob will receive the packet and open the circuit.
|
||||
select {
|
||||
case <-bobChannelLink.packets:
|
||||
err = bobChannelLink.completeCircuit(packet)
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request was not propagated to destination")
|
||||
}
|
||||
|
||||
// Check that only one circuit is open.
|
||||
require.Equal(t, 1, s.circuits.NumOpen())
|
||||
|
||||
// We'll send a settle resolution to Switch that should go to Alice.
|
||||
settleResMsg := contractcourt.ResolutionMsg{
|
||||
SourceChan: bobChanID,
|
||||
HtlcIndex: 0,
|
||||
PreImage: &preimage,
|
||||
}
|
||||
|
||||
// Before the resolution is sent, remove alice's link so we can assert
|
||||
// that the resolution is actually stored. Otherwise, it would be
|
||||
// deleted shortly after being sent.
|
||||
s.RemoveLink(chanID1)
|
||||
|
||||
// Send the resolution message.
|
||||
err = s.ProcessContractResolution(settleResMsg)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Assert that the resolution store contains the settle reoslution.
|
||||
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, 1, len(resMsgs))
|
||||
require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
|
||||
require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
|
||||
require.Nil(t, resMsgs[0].Failure)
|
||||
require.Equal(t, preimage, *resMsgs[0].PreImage)
|
||||
|
||||
// Now we'll restart Alice's link and delete the circuit.
|
||||
err = s.AddLink(aliceChannelLink)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Alice will receive the packet and open the circuit.
|
||||
select {
|
||||
case alicePkt := <-aliceChannelLink.packets:
|
||||
err = aliceChannelLink.completeCircuit(alicePkt)
|
||||
require.NoError(t, err)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("request was not propagated to destination")
|
||||
}
|
||||
|
||||
// Assert that there are no more circuits.
|
||||
require.Equal(t, 0, s.circuits.NumOpen())
|
||||
|
||||
// We'll restart the Switch and assert that Alice does not receive
|
||||
// another packet.
|
||||
switchDB := s.cfg.DB.(*channeldb.DB)
|
||||
err = s.Stop()
|
||||
require.NoError(t, err)
|
||||
|
||||
s, err = initSwitchWithDB(testStartingHeight, switchDB)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = s.Start()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
_ = s.Stop()
|
||||
}()
|
||||
|
||||
err = s.AddLink(aliceChannelLink)
|
||||
require.NoError(t, err)
|
||||
err = s.AddLink(bobChannelLink)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Alice should not receive a packet since the Switch should have
|
||||
// deleted the resolution message since the circuit was closed.
|
||||
select {
|
||||
case alicePkt := <-aliceChannelLink.packets:
|
||||
t.Fatalf("received erroneous packet: %v", alicePkt)
|
||||
case <-time.After(time.Second * 5):
|
||||
}
|
||||
|
||||
// Check that the resolution message no longer exists in the store.
|
||||
resMsgs, err = s.resMsgStore.fetchAllResolutionMsg()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(resMsgs))
|
||||
}
|
||||
|
200
lntest/itest/lnd_res_handoff_test.go
Normal file
200
lntest/itest/lnd_res_handoff_test.go
Normal file
@ -0,0 +1,200 @@
|
||||
package itest
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/btcsuite/btcd/btcutil"
|
||||
"github.com/lightningnetwork/lnd/lnrpc"
|
||||
"github.com/lightningnetwork/lnd/lntest"
|
||||
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/net/context"
|
||||
)
|
||||
|
||||
// testResHandoff tests that the contractcourt is able to properly hand-off
|
||||
// resolution messages to the switch.
|
||||
func testResHandoff(net *lntest.NetworkHarness, t *harnessTest) {
|
||||
const (
|
||||
chanAmt = btcutil.Amount(1000000)
|
||||
paymentAmt = 50000
|
||||
)
|
||||
|
||||
ctxb := context.Background()
|
||||
|
||||
// First we'll create a channel between Alice and Bob.
|
||||
net.EnsureConnected(t.t, net.Alice, net.Bob)
|
||||
|
||||
chanPointAlice := openChannelAndAssert(
|
||||
t, net, net.Alice, net.Bob,
|
||||
lntest.OpenChannelParams{
|
||||
Amt: chanAmt,
|
||||
},
|
||||
)
|
||||
defer closeChannelAndAssert(t, net, net.Alice, chanPointAlice, false)
|
||||
|
||||
// Wait for Alice and Bob to receive the channel edge from the funding
|
||||
// manager.
|
||||
err := net.Alice.WaitForNetworkChannelOpen(chanPointAlice)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
err = net.Bob.WaitForNetworkChannelOpen(chanPointAlice)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Create a new node Carol that will be in hodl mode. This is used to
|
||||
// trigger the behavior of checkRemoteDanglingActions in the
|
||||
// contractcourt. This will cause Bob to fail the HTLC back to Alice.
|
||||
carol := net.NewNode(t.t, "Carol", []string{"--hodl.commit"})
|
||||
defer shutdownAndAssert(net, t, carol)
|
||||
|
||||
// Connect Bob to Carol.
|
||||
net.ConnectNodes(t.t, net.Bob, carol)
|
||||
|
||||
// Open a channel between Bob and Carol.
|
||||
chanPointCarol := openChannelAndAssert(
|
||||
t, net, net.Bob, carol,
|
||||
lntest.OpenChannelParams{
|
||||
Amt: chanAmt,
|
||||
},
|
||||
)
|
||||
|
||||
// Wait for Bob and Carol to receive the channel edge from the funding
|
||||
// manager.
|
||||
err = net.Bob.WaitForNetworkChannelOpen(chanPointCarol)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
err = carol.WaitForNetworkChannelOpen(chanPointCarol)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Wait for Alice to see the channel edge in the graph.
|
||||
err = net.Alice.WaitForNetworkChannelOpen(chanPointCarol)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// We'll create an invoice for Carol that Alice will attempt to pay.
|
||||
// Since Carol is in hodl.commit mode, she won't send back any commit
|
||||
// sigs.
|
||||
carolPayReqs, _, _, err := createPayReqs(
|
||||
carol, paymentAmt, 1,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Alice will now attempt to fulfill the invoice.
|
||||
err = completePaymentRequests(
|
||||
net.Alice, net.Alice.RouterClient, carolPayReqs, false,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Wait until Carol has received the Add, CommitSig from Bob, and has
|
||||
// responded with a RevokeAndAck. We expect NumUpdates to be 1 meaning
|
||||
// Carol's CommitHeight is 1.
|
||||
err = wait.Predicate(func() bool {
|
||||
carolInfo, err := getChanInfo(carol)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return carolInfo.NumUpdates == 1
|
||||
}, defaultTimeout)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Before we shutdown Alice, we'll assert that she only has 1 update.
|
||||
err = wait.Predicate(func() bool {
|
||||
aliceInfo, err := getChanInfo(net.Alice)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return aliceInfo.NumUpdates == 1
|
||||
}, defaultTimeout)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// We'll shutdown Alice so that Bob can't connect to her.
|
||||
restartAlice, err := net.SuspendNode(net.Alice)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Bob will now force close his channel with Carol such that resolution
|
||||
// messages are created and forwarded backwards to Alice.
|
||||
_, _, err = net.CloseChannel(net.Bob, chanPointCarol, true)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// The channel should be listed in the PendingChannels result.
|
||||
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
|
||||
defer cancel()
|
||||
|
||||
pendingChansRequest := &lnrpc.PendingChannelsRequest{}
|
||||
pendingChanResp, err := net.Bob.PendingChannels(
|
||||
ctxt, pendingChansRequest,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
require.NoError(t.t, checkNumWaitingCloseChannels(pendingChanResp, 1))
|
||||
|
||||
// We'll mine a block to confirm the force close transaction and to
|
||||
// advance Bob's contract state with Carol to StateContractClosed.
|
||||
mineBlocks(t, net, 1, 1)
|
||||
|
||||
// We sleep here so we can be sure that the hand-off has occurred from
|
||||
// Bob's contractcourt to Bob's htlcswitch. This sleep could be removed
|
||||
// if there was some feedback (i.e. API in switch) that allowed for
|
||||
// querying the state of resolution messages.
|
||||
time.Sleep(10 * time.Second)
|
||||
|
||||
// Mine blocks until Bob has no waiting close channels. This tests
|
||||
// that the circuit-deletion logic is skipped if a resolution message
|
||||
// exists.
|
||||
for {
|
||||
_, err = net.Miner.Client.Generate(1)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
pendingChanResp, err = net.Bob.PendingChannels(
|
||||
ctxt, pendingChansRequest,
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
isErr := checkNumForceClosedChannels(pendingChanResp, 0)
|
||||
if isErr == nil {
|
||||
break
|
||||
}
|
||||
|
||||
time.Sleep(150 * time.Millisecond)
|
||||
}
|
||||
|
||||
// We will now restart Bob so that we can test whether the resolution
|
||||
// messages are re-forwarded on start-up.
|
||||
restartBob, err := net.SuspendNode(net.Bob)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
err = restartBob()
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// We'll now also restart Alice and connect her with Bob.
|
||||
err = restartAlice()
|
||||
require.NoError(t.t, err)
|
||||
|
||||
net.EnsureConnected(t.t, net.Alice, net.Bob)
|
||||
|
||||
// We'll assert that Alice has received the failure resolution
|
||||
// message.
|
||||
err = wait.Predicate(func() bool {
|
||||
aliceInfo, err := getChanInfo(net.Alice)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return aliceInfo.NumUpdates == 2
|
||||
}, defaultTimeout)
|
||||
require.NoError(t.t, err)
|
||||
|
||||
// Assert that Alice's payment failed.
|
||||
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||
paymentsResp, err := net.Alice.ListPayments(
|
||||
ctxt, &lnrpc.ListPaymentsRequest{
|
||||
IncludeIncomplete: true,
|
||||
},
|
||||
)
|
||||
require.NoError(t.t, err)
|
||||
require.Equal(t.t, 1, len(paymentsResp.Payments))
|
||||
|
||||
htlcs := paymentsResp.Payments[0].Htlcs
|
||||
|
||||
require.Equal(t.t, 1, len(htlcs))
|
||||
require.Equal(t.t, lnrpc.HTLCAttempt_FAILED, htlcs[0].Status)
|
||||
}
|
@ -399,4 +399,8 @@ var allTestCases = []*testCase{
|
||||
name: "addpeer config",
|
||||
test: testAddPeerConfig,
|
||||
},
|
||||
{
|
||||
name: "resolution handoff",
|
||||
test: testResHandoff,
|
||||
},
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user