mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-03-29 03:01:52 +01:00
htlcswitch+lntest: modify Switch to persist resolution messages
Include unit tests for the Switch, and integration tests that exercise the persistence logic.
This commit is contained in:
parent
bfed7a088f
commit
f7b3da4bb2
@ -1092,8 +1092,6 @@ func (c *ChannelArbitrator) stateStep(
|
|||||||
if len(pktsToSend) != 0 {
|
if len(pktsToSend) != 0 {
|
||||||
err := c.cfg.DeliverResolutionMsg(pktsToSend...)
|
err := c.cfg.DeliverResolutionMsg(pktsToSend...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO(roasbeef): make sure packet sends are
|
|
||||||
// idempotent
|
|
||||||
log.Errorf("unable to send pkts: %v", err)
|
log.Errorf("unable to send pkts: %v", err)
|
||||||
return StateError, closeTx, err
|
return StateError, closeTx, err
|
||||||
}
|
}
|
||||||
|
@ -213,6 +213,10 @@ type CircuitMapConfig struct {
|
|||||||
// ExtractErrorEncrypter derives the shared secret used to encrypt
|
// ExtractErrorEncrypter derives the shared secret used to encrypt
|
||||||
// errors from the obfuscator's ephemeral public key.
|
// errors from the obfuscator's ephemeral public key.
|
||||||
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
|
ExtractErrorEncrypter hop.ErrorEncrypterExtracter
|
||||||
|
|
||||||
|
// CheckResolutionMsg checks whether a given resolution message exists
|
||||||
|
// for the passed CircuitKey.
|
||||||
|
CheckResolutionMsg func(outKey *CircuitKey) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewCircuitMap creates a new instance of the circuitMap.
|
// NewCircuitMap creates a new instance of the circuitMap.
|
||||||
@ -400,7 +404,19 @@ func (cm *circuitMap) cleanClosedChannels() error {
|
|||||||
// Check if the outgoing channel ID can be found in the
|
// Check if the outgoing channel ID can be found in the
|
||||||
// closed channel ID map. Notice that we need to store
|
// closed channel ID map. Notice that we need to store
|
||||||
// the outgoing key because it's used for db query.
|
// the outgoing key because it's used for db query.
|
||||||
|
//
|
||||||
|
// NOTE: We skip this if a resolution message can be
|
||||||
|
// found under the outKey. This means that there is an
|
||||||
|
// existing resolution message(s) that need to get to
|
||||||
|
// the incoming links.
|
||||||
if isClosedChannel(outKey.ChanID) {
|
if isClosedChannel(outKey.ChanID) {
|
||||||
|
// Check the resolution message store. A return
|
||||||
|
// value of nil means we need to skip deleting
|
||||||
|
// these circuits.
|
||||||
|
if cm.cfg.CheckResolutionMsg(&outKey) == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
keystoneKeySet[outKey] = struct{}{}
|
keystoneKeySet[outKey] = struct{}{}
|
||||||
|
|
||||||
// Also update circuitKeySet to mark the
|
// Also update circuitKeySet to mark the
|
||||||
|
@ -66,6 +66,11 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
|||||||
chanParams []closeChannelParams
|
chanParams []closeChannelParams
|
||||||
deleted []htlcswitch.Keystone
|
deleted []htlcswitch.Keystone
|
||||||
untouched []htlcswitch.Keystone
|
untouched []htlcswitch.Keystone
|
||||||
|
|
||||||
|
// If resMsg is true, then closed channels will not delete
|
||||||
|
// circuits if the channel was the keystone / outgoing key in
|
||||||
|
// the open circuit.
|
||||||
|
resMsg bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "no deletion if there are no closed channels",
|
name: "no deletion if there are no closed channels",
|
||||||
@ -120,7 +125,7 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
|||||||
{InKey: inKey20, OutKey: outKey20},
|
{InKey: inKey20, OutKey: outKey20},
|
||||||
},
|
},
|
||||||
deleted: []htlcswitch.Keystone{
|
deleted: []htlcswitch.Keystone{
|
||||||
{InKey: inKey00}, {InKey: inKey11},
|
{InKey: inKey10}, {InKey: inKey11},
|
||||||
},
|
},
|
||||||
untouched: []htlcswitch.Keystone{
|
untouched: []htlcswitch.Keystone{
|
||||||
{InKey: inKey20, OutKey: outKey20},
|
{InKey: inKey20, OutKey: outKey20},
|
||||||
@ -214,13 +219,33 @@ func TestCircuitMapCleanClosedChannels(t *testing.T) {
|
|||||||
{InKey: inKey22, OutKey: outKey20},
|
{InKey: inKey22, OutKey: outKey20},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "don't delete circuits for outgoing",
|
||||||
|
chanParams: []closeChannelParams{
|
||||||
|
// Creates a close channel with chanID1.
|
||||||
|
{chanID: chanID1, isPending: false},
|
||||||
|
},
|
||||||
|
keystones: []htlcswitch.Keystone{
|
||||||
|
// Creates a circuit and a keystone
|
||||||
|
{InKey: inKey10, OutKey: outKey10},
|
||||||
|
// Creates a circuit and a keystone
|
||||||
|
{InKey: inKey11, OutKey: outKey20},
|
||||||
|
// Creates a circuit and a keystone
|
||||||
|
{InKey: inKey00, OutKey: outKey11},
|
||||||
|
},
|
||||||
|
deleted: []htlcswitch.Keystone{
|
||||||
|
{InKey: inKey10, OutKey: outKey10},
|
||||||
|
{InKey: inKey11, OutKey: outKey20},
|
||||||
|
},
|
||||||
|
resMsg: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tt := range testParams {
|
for _, tt := range testParams {
|
||||||
test := tt
|
test := tt
|
||||||
|
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, test.resMsg)
|
||||||
|
|
||||||
// create test circuits
|
// create test circuits
|
||||||
for _, ks := range test.keystones {
|
for _, ks := range test.keystones {
|
||||||
|
@ -2,6 +2,7 @@ package htlcswitch_test
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
@ -97,8 +98,9 @@ func newOnionProcessor(t *testing.T) *hop.OnionProcessor {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// newCircuitMap creates a new htlcswitch.CircuitMap using a temp db and a
|
// newCircuitMap creates a new htlcswitch.CircuitMap using a temp db and a
|
||||||
// fresh sphinx router.
|
// fresh sphinx router. When resMsg is set to true, CheckResolutionMsg will
|
||||||
func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
// always return nil. Otherwise it will always return an error.
|
||||||
|
func newCircuitMap(t *testing.T, resMsg bool) (*htlcswitch.CircuitMapConfig,
|
||||||
htlcswitch.CircuitMap) {
|
htlcswitch.CircuitMap) {
|
||||||
|
|
||||||
onionProcessor := newOnionProcessor(t)
|
onionProcessor := newOnionProcessor(t)
|
||||||
@ -111,6 +113,18 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
|||||||
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
|
ExtractErrorEncrypter: onionProcessor.ExtractErrorEncrypter,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if resMsg {
|
||||||
|
checkRes := func(out *htlcswitch.CircuitKey) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
circuitMapCfg.CheckResolutionMsg = checkRes
|
||||||
|
} else {
|
||||||
|
checkRes := func(out *htlcswitch.CircuitKey) error {
|
||||||
|
return fmt.Errorf("not found")
|
||||||
|
}
|
||||||
|
circuitMapCfg.CheckResolutionMsg = checkRes
|
||||||
|
}
|
||||||
|
|
||||||
circuitMap, err := htlcswitch.NewCircuitMap(circuitMapCfg)
|
circuitMap, err := htlcswitch.NewCircuitMap(circuitMapCfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("unable to create persistent circuit map: %v", err)
|
t.Fatalf("unable to create persistent circuit map: %v", err)
|
||||||
@ -124,7 +138,7 @@ func newCircuitMap(t *testing.T) (*htlcswitch.CircuitMapConfig,
|
|||||||
func TestCircuitMapInit(t *testing.T) {
|
func TestCircuitMapInit(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
cfg, _ := newCircuitMap(t)
|
cfg, _ := newCircuitMap(t, false)
|
||||||
restartCircuitMap(t, cfg)
|
restartCircuitMap(t, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -231,7 +245,7 @@ func TestCircuitMapPersistence(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{
|
circuit := circuitMap.LookupCircuit(htlcswitch.CircuitKey{
|
||||||
ChanID: chan1,
|
ChanID: chan1,
|
||||||
@ -649,6 +663,7 @@ func restartCircuitMap(t *testing.T, cfg *htlcswitch.CircuitMapConfig) (
|
|||||||
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
|
FetchAllOpenChannels: db.ChannelStateDB().FetchAllOpenChannels,
|
||||||
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
|
FetchClosedChannels: db.ChannelStateDB().FetchClosedChannels,
|
||||||
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
||||||
|
CheckResolutionMsg: cfg.CheckResolutionMsg,
|
||||||
}
|
}
|
||||||
cm2, err := htlcswitch.NewCircuitMap(cfg2)
|
cm2, err := htlcswitch.NewCircuitMap(cfg2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -671,7 +686,7 @@ func TestCircuitMapCommitCircuits(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
@ -767,7 +782,7 @@ func TestCircuitMapOpenCircuits(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
@ -973,7 +988,7 @@ func TestCircuitMapTrimOpenCircuits(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
const nCircuits = 10
|
const nCircuits = 10
|
||||||
const firstTrimIndex = 7
|
const firstTrimIndex = 7
|
||||||
@ -1122,7 +1137,7 @@ func TestCircuitMapCloseOpenCircuits(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
@ -1215,7 +1230,7 @@ func TestCircuitMapCloseUnopenedCircuit(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
@ -1272,7 +1287,7 @@ func TestCircuitMapDeleteUnopenedCircuit(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
@ -1331,7 +1346,7 @@ func TestCircuitMapDeleteOpenCircuit(t *testing.T) {
|
|||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg, circuitMap := newCircuitMap(t)
|
cfg, circuitMap := newCircuitMap(t, false)
|
||||||
|
|
||||||
circuit := &htlcswitch.PaymentCircuit{
|
circuit := &htlcswitch.PaymentCircuit{
|
||||||
Incoming: htlcswitch.CircuitKey{
|
Incoming: htlcswitch.CircuitKey{
|
||||||
|
@ -297,15 +297,23 @@ type Switch struct {
|
|||||||
// ack in the forwarding package of the outgoing link. This was added to
|
// ack in the forwarding package of the outgoing link. This was added to
|
||||||
// make pipelining settles more efficient.
|
// make pipelining settles more efficient.
|
||||||
pendingSettleFails []channeldb.SettleFailRef
|
pendingSettleFails []channeldb.SettleFailRef
|
||||||
|
|
||||||
|
// resMsgStore is used to store the set of ResolutionMsg that come from
|
||||||
|
// contractcourt. This is used so the Switch can properly forward them,
|
||||||
|
// even on restarts.
|
||||||
|
resMsgStore *resolutionStore
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates the new instance of htlc switch.
|
// New creates the new instance of htlc switch.
|
||||||
func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
||||||
|
resStore := newResolutionStore(cfg.DB)
|
||||||
|
|
||||||
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
|
circuitMap, err := NewCircuitMap(&CircuitMapConfig{
|
||||||
DB: cfg.DB,
|
DB: cfg.DB,
|
||||||
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
|
FetchAllOpenChannels: cfg.FetchAllOpenChannels,
|
||||||
FetchClosedChannels: cfg.FetchClosedChannels,
|
FetchClosedChannels: cfg.FetchClosedChannels,
|
||||||
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
ExtractErrorEncrypter: cfg.ExtractErrorEncrypter,
|
||||||
|
CheckResolutionMsg: resStore.checkResolutionMsg,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -323,6 +331,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
|||||||
htlcPlex: make(chan *plexPacket),
|
htlcPlex: make(chan *plexPacket),
|
||||||
chanCloseRequests: make(chan *ChanClose),
|
chanCloseRequests: make(chan *ChanClose),
|
||||||
resolutionMsgs: make(chan *resolutionMsg),
|
resolutionMsgs: make(chan *resolutionMsg),
|
||||||
|
resMsgStore: resStore,
|
||||||
quit: make(chan struct{}),
|
quit: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -342,7 +351,7 @@ func New(cfg Config, currentHeight uint32) (*Switch, error) {
|
|||||||
type resolutionMsg struct {
|
type resolutionMsg struct {
|
||||||
contractcourt.ResolutionMsg
|
contractcourt.ResolutionMsg
|
||||||
|
|
||||||
doneChan chan struct{}
|
errChan chan error
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProcessContractResolution is called by active contract resolvers once a
|
// ProcessContractResolution is called by active contract resolvers once a
|
||||||
@ -351,25 +360,23 @@ type resolutionMsg struct {
|
|||||||
// didn't need to go to the chain in order to fulfill a contract. We'll process
|
// didn't need to go to the chain in order to fulfill a contract. We'll process
|
||||||
// this message just as if it came from an active outgoing channel.
|
// this message just as if it came from an active outgoing channel.
|
||||||
func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) error {
|
func (s *Switch) ProcessContractResolution(msg contractcourt.ResolutionMsg) error {
|
||||||
|
errChan := make(chan error, 1)
|
||||||
done := make(chan struct{})
|
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case s.resolutionMsgs <- &resolutionMsg{
|
case s.resolutionMsgs <- &resolutionMsg{
|
||||||
ResolutionMsg: msg,
|
ResolutionMsg: msg,
|
||||||
doneChan: done,
|
errChan: errChan,
|
||||||
}:
|
}:
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
return ErrSwitchExiting
|
return ErrSwitchExiting
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-done:
|
case err := <-errChan:
|
||||||
|
return err
|
||||||
case <-s.quit:
|
case <-s.quit:
|
||||||
return ErrSwitchExiting
|
return ErrSwitchExiting
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPaymentResult returns the the result of the payment attempt with the
|
// GetPaymentResult returns the the result of the payment attempt with the
|
||||||
@ -1678,6 +1685,28 @@ out:
|
|||||||
go s.cfg.LocalChannelClose(peerPub[:], req)
|
go s.cfg.LocalChannelClose(peerPub[:], req)
|
||||||
|
|
||||||
case resolutionMsg := <-s.resolutionMsgs:
|
case resolutionMsg := <-s.resolutionMsgs:
|
||||||
|
// We'll persist the resolution message to the Switch's
|
||||||
|
// resolution store.
|
||||||
|
resMsg := resolutionMsg.ResolutionMsg
|
||||||
|
err := s.resMsgStore.addResolutionMsg(&resMsg)
|
||||||
|
if err != nil {
|
||||||
|
// This will only fail if there is a database
|
||||||
|
// error or a serialization error. Sending the
|
||||||
|
// error prevents the contractcourt from being
|
||||||
|
// in a state where it believes the send was
|
||||||
|
// successful, when it wasn't.
|
||||||
|
log.Errorf("unable to add resolution msg: %v",
|
||||||
|
err)
|
||||||
|
resolutionMsg.errChan <- err
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// At this point, the resolution message has been
|
||||||
|
// persisted. It is safe to signal success by sending
|
||||||
|
// a nil error since the Switch will re-deliver the
|
||||||
|
// resolution message on restart.
|
||||||
|
resolutionMsg.errChan <- nil
|
||||||
|
|
||||||
pkt := &htlcPacket{
|
pkt := &htlcPacket{
|
||||||
outgoingChanID: resolutionMsg.SourceChan,
|
outgoingChanID: resolutionMsg.SourceChan,
|
||||||
outgoingHTLCID: resolutionMsg.HtlcIndex,
|
outgoingHTLCID: resolutionMsg.HtlcIndex,
|
||||||
@ -1703,14 +1732,11 @@ out:
|
|||||||
// encounter is due to the circuit already being
|
// encounter is due to the circuit already being
|
||||||
// closed. This is fine, as processing this message is
|
// closed. This is fine, as processing this message is
|
||||||
// meant to be idempotent.
|
// meant to be idempotent.
|
||||||
err := s.handlePacketForward(pkt)
|
err = s.handlePacketForward(pkt)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Unable to forward resolution msg: %v", err)
|
log.Errorf("Unable to forward resolution msg: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// With the message processed, we'll now close out
|
|
||||||
close(resolutionMsg.doneChan)
|
|
||||||
|
|
||||||
// A new packet has arrived for forwarding, we'll interpret the
|
// A new packet has arrived for forwarding, we'll interpret the
|
||||||
// packet concretely, then either forward it along, or
|
// packet concretely, then either forward it along, or
|
||||||
// interpret a return packet to a locally initialized one.
|
// interpret a return packet to a locally initialized one.
|
||||||
@ -1863,6 +1889,72 @@ func (s *Switch) Start() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.reforwardResolutions(); err != nil {
|
||||||
|
// We are already stopping so we can ignore the error.
|
||||||
|
_ = s.Stop()
|
||||||
|
log.Errorf("unable to reforward resolutions: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// reforwardResolutions fetches the set of resolution messages stored on-disk
|
||||||
|
// and reforwards them if their circuits are still open. If the circuits have
|
||||||
|
// been deleted, then we will delete the resolution message from the database.
|
||||||
|
func (s *Switch) reforwardResolutions() error {
|
||||||
|
// Fetch all stored resolution messages, deleting the ones that are
|
||||||
|
// resolved.
|
||||||
|
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switchPackets := make([]*htlcPacket, 0, len(resMsgs))
|
||||||
|
for _, resMsg := range resMsgs {
|
||||||
|
// If the open circuit no longer exists, then we can remove the
|
||||||
|
// message from the store.
|
||||||
|
outKey := CircuitKey{
|
||||||
|
ChanID: resMsg.SourceChan,
|
||||||
|
HtlcID: resMsg.HtlcIndex,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.circuits.LookupOpenCircuit(outKey) == nil {
|
||||||
|
// The open circuit doesn't exist.
|
||||||
|
err := s.resMsgStore.deleteResolutionMsg(&outKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// The circuit is still open, so we can assume that the link or
|
||||||
|
// switch (if we are the source) hasn't cleaned it up yet.
|
||||||
|
resPkt := &htlcPacket{
|
||||||
|
outgoingChanID: resMsg.SourceChan,
|
||||||
|
outgoingHTLCID: resMsg.HtlcIndex,
|
||||||
|
isResolution: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
if resMsg.Failure != nil {
|
||||||
|
resPkt.htlc = &lnwire.UpdateFailHTLC{}
|
||||||
|
} else {
|
||||||
|
resPkt.htlc = &lnwire.UpdateFulfillHTLC{
|
||||||
|
PaymentPreimage: *resMsg.PreImage,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switchPackets = append(switchPackets, resPkt)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We'll now dispatch the set of resolution messages to the proper
|
||||||
|
// destination. An error is only encountered here if the switch is
|
||||||
|
// shutting down.
|
||||||
|
if err := s.ForwardPackets(nil, switchPackets...); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
"github.com/lightningnetwork/lnd/channeldb"
|
"github.com/lightningnetwork/lnd/channeldb"
|
||||||
|
"github.com/lightningnetwork/lnd/contractcourt"
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
|
"github.com/lightningnetwork/lnd/htlcswitch/hodl"
|
||||||
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
"github.com/lightningnetwork/lnd/htlcswitch/hop"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
@ -3923,3 +3924,144 @@ func TestSwitchMailboxDust(t *testing.T) {
|
|||||||
t.Fatal("no timely reply from switch")
|
t.Fatal("no timely reply from switch")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestSwitchResolution checks the ability of the switch to persist and handle
|
||||||
|
// resolution messages.
|
||||||
|
func TestSwitchResolution(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
alicePeer, err := newMockServer(
|
||||||
|
t, "alice", testStartingHeight, nil, testDefaultDelta,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
bobPeer, err := newMockServer(
|
||||||
|
t, "bob", testStartingHeight, nil, testDefaultDelta,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s, err := initSwitchWithDB(testStartingHeight, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
|
||||||
|
|
||||||
|
aliceChannelLink := newMockChannelLink(
|
||||||
|
s, chanID1, aliceChanID, alicePeer, true,
|
||||||
|
)
|
||||||
|
bobChannelLink := newMockChannelLink(
|
||||||
|
s, chanID2, bobChanID, bobPeer, true,
|
||||||
|
)
|
||||||
|
err = s.AddLink(aliceChannelLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = s.AddLink(bobChannelLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create an add htlcPacket that Alice will send to Bob.
|
||||||
|
preimage, err := genPreimage()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rhash := sha256.Sum256(preimage[:])
|
||||||
|
packet := &htlcPacket{
|
||||||
|
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||||
|
incomingHTLCID: 0,
|
||||||
|
outgoingChanID: bobChannelLink.ShortChanID(),
|
||||||
|
obfuscator: NewMockObfuscator(),
|
||||||
|
htlc: &lnwire.UpdateAddHTLC{
|
||||||
|
PaymentHash: rhash,
|
||||||
|
Amount: 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = s.ForwardPackets(nil, packet)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Bob will receive the packet and open the circuit.
|
||||||
|
select {
|
||||||
|
case <-bobChannelLink.packets:
|
||||||
|
err = bobChannelLink.completeCircuit(packet)
|
||||||
|
require.NoError(t, err)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("request was not propagated to destination")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that only one circuit is open.
|
||||||
|
require.Equal(t, 1, s.circuits.NumOpen())
|
||||||
|
|
||||||
|
// We'll send a settle resolution to Switch that should go to Alice.
|
||||||
|
settleResMsg := contractcourt.ResolutionMsg{
|
||||||
|
SourceChan: bobChanID,
|
||||||
|
HtlcIndex: 0,
|
||||||
|
PreImage: &preimage,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Before the resolution is sent, remove alice's link so we can assert
|
||||||
|
// that the resolution is actually stored. Otherwise, it would be
|
||||||
|
// deleted shortly after being sent.
|
||||||
|
s.RemoveLink(chanID1)
|
||||||
|
|
||||||
|
// Send the resolution message.
|
||||||
|
err = s.ProcessContractResolution(settleResMsg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Assert that the resolution store contains the settle reoslution.
|
||||||
|
resMsgs, err := s.resMsgStore.fetchAllResolutionMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, 1, len(resMsgs))
|
||||||
|
require.Equal(t, settleResMsg.SourceChan, resMsgs[0].SourceChan)
|
||||||
|
require.Equal(t, settleResMsg.HtlcIndex, resMsgs[0].HtlcIndex)
|
||||||
|
require.Nil(t, resMsgs[0].Failure)
|
||||||
|
require.Equal(t, preimage, *resMsgs[0].PreImage)
|
||||||
|
|
||||||
|
// Now we'll restart Alice's link and delete the circuit.
|
||||||
|
err = s.AddLink(aliceChannelLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Alice will receive the packet and open the circuit.
|
||||||
|
select {
|
||||||
|
case alicePkt := <-aliceChannelLink.packets:
|
||||||
|
err = aliceChannelLink.completeCircuit(alicePkt)
|
||||||
|
require.NoError(t, err)
|
||||||
|
case <-time.After(time.Second):
|
||||||
|
t.Fatal("request was not propagated to destination")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assert that there are no more circuits.
|
||||||
|
require.Equal(t, 0, s.circuits.NumOpen())
|
||||||
|
|
||||||
|
// We'll restart the Switch and assert that Alice does not receive
|
||||||
|
// another packet.
|
||||||
|
switchDB := s.cfg.DB.(*channeldb.DB)
|
||||||
|
err = s.Stop()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
s, err = initSwitchWithDB(testStartingHeight, switchDB)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = s.Start()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
_ = s.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
err = s.AddLink(aliceChannelLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
err = s.AddLink(bobChannelLink)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Alice should not receive a packet since the Switch should have
|
||||||
|
// deleted the resolution message since the circuit was closed.
|
||||||
|
select {
|
||||||
|
case alicePkt := <-aliceChannelLink.packets:
|
||||||
|
t.Fatalf("received erroneous packet: %v", alicePkt)
|
||||||
|
case <-time.After(time.Second * 5):
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that the resolution message no longer exists in the store.
|
||||||
|
resMsgs, err = s.resMsgStore.fetchAllResolutionMsg()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.Equal(t, 0, len(resMsgs))
|
||||||
|
}
|
||||||
|
200
lntest/itest/lnd_res_handoff_test.go
Normal file
200
lntest/itest/lnd_res_handoff_test.go
Normal file
@ -0,0 +1,200 @@
|
|||||||
|
package itest
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/btcutil"
|
||||||
|
"github.com/lightningnetwork/lnd/lnrpc"
|
||||||
|
"github.com/lightningnetwork/lnd/lntest"
|
||||||
|
"github.com/lightningnetwork/lnd/lntest/wait"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/net/context"
|
||||||
|
)
|
||||||
|
|
||||||
|
// testResHandoff tests that the contractcourt is able to properly hand-off
|
||||||
|
// resolution messages to the switch.
|
||||||
|
func testResHandoff(net *lntest.NetworkHarness, t *harnessTest) {
|
||||||
|
const (
|
||||||
|
chanAmt = btcutil.Amount(1000000)
|
||||||
|
paymentAmt = 50000
|
||||||
|
)
|
||||||
|
|
||||||
|
ctxb := context.Background()
|
||||||
|
|
||||||
|
// First we'll create a channel between Alice and Bob.
|
||||||
|
net.EnsureConnected(t.t, net.Alice, net.Bob)
|
||||||
|
|
||||||
|
chanPointAlice := openChannelAndAssert(
|
||||||
|
t, net, net.Alice, net.Bob,
|
||||||
|
lntest.OpenChannelParams{
|
||||||
|
Amt: chanAmt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
defer closeChannelAndAssert(t, net, net.Alice, chanPointAlice, false)
|
||||||
|
|
||||||
|
// Wait for Alice and Bob to receive the channel edge from the funding
|
||||||
|
// manager.
|
||||||
|
err := net.Alice.WaitForNetworkChannelOpen(chanPointAlice)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
err = net.Bob.WaitForNetworkChannelOpen(chanPointAlice)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Create a new node Carol that will be in hodl mode. This is used to
|
||||||
|
// trigger the behavior of checkRemoteDanglingActions in the
|
||||||
|
// contractcourt. This will cause Bob to fail the HTLC back to Alice.
|
||||||
|
carol := net.NewNode(t.t, "Carol", []string{"--hodl.commit"})
|
||||||
|
defer shutdownAndAssert(net, t, carol)
|
||||||
|
|
||||||
|
// Connect Bob to Carol.
|
||||||
|
net.ConnectNodes(t.t, net.Bob, carol)
|
||||||
|
|
||||||
|
// Open a channel between Bob and Carol.
|
||||||
|
chanPointCarol := openChannelAndAssert(
|
||||||
|
t, net, net.Bob, carol,
|
||||||
|
lntest.OpenChannelParams{
|
||||||
|
Amt: chanAmt,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
// Wait for Bob and Carol to receive the channel edge from the funding
|
||||||
|
// manager.
|
||||||
|
err = net.Bob.WaitForNetworkChannelOpen(chanPointCarol)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
err = carol.WaitForNetworkChannelOpen(chanPointCarol)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Wait for Alice to see the channel edge in the graph.
|
||||||
|
err = net.Alice.WaitForNetworkChannelOpen(chanPointCarol)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// We'll create an invoice for Carol that Alice will attempt to pay.
|
||||||
|
// Since Carol is in hodl.commit mode, she won't send back any commit
|
||||||
|
// sigs.
|
||||||
|
carolPayReqs, _, _, err := createPayReqs(
|
||||||
|
carol, paymentAmt, 1,
|
||||||
|
)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Alice will now attempt to fulfill the invoice.
|
||||||
|
err = completePaymentRequests(
|
||||||
|
net.Alice, net.Alice.RouterClient, carolPayReqs, false,
|
||||||
|
)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Wait until Carol has received the Add, CommitSig from Bob, and has
|
||||||
|
// responded with a RevokeAndAck. We expect NumUpdates to be 1 meaning
|
||||||
|
// Carol's CommitHeight is 1.
|
||||||
|
err = wait.Predicate(func() bool {
|
||||||
|
carolInfo, err := getChanInfo(carol)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return carolInfo.NumUpdates == 1
|
||||||
|
}, defaultTimeout)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Before we shutdown Alice, we'll assert that she only has 1 update.
|
||||||
|
err = wait.Predicate(func() bool {
|
||||||
|
aliceInfo, err := getChanInfo(net.Alice)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return aliceInfo.NumUpdates == 1
|
||||||
|
}, defaultTimeout)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// We'll shutdown Alice so that Bob can't connect to her.
|
||||||
|
restartAlice, err := net.SuspendNode(net.Alice)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Bob will now force close his channel with Carol such that resolution
|
||||||
|
// messages are created and forwarded backwards to Alice.
|
||||||
|
_, _, err = net.CloseChannel(net.Bob, chanPointCarol, true)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// The channel should be listed in the PendingChannels result.
|
||||||
|
ctxt, cancel := context.WithTimeout(ctxb, defaultTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
pendingChansRequest := &lnrpc.PendingChannelsRequest{}
|
||||||
|
pendingChanResp, err := net.Bob.PendingChannels(
|
||||||
|
ctxt, pendingChansRequest,
|
||||||
|
)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
require.NoError(t.t, checkNumWaitingCloseChannels(pendingChanResp, 1))
|
||||||
|
|
||||||
|
// We'll mine a block to confirm the force close transaction and to
|
||||||
|
// advance Bob's contract state with Carol to StateContractClosed.
|
||||||
|
mineBlocks(t, net, 1, 1)
|
||||||
|
|
||||||
|
// We sleep here so we can be sure that the hand-off has occurred from
|
||||||
|
// Bob's contractcourt to Bob's htlcswitch. This sleep could be removed
|
||||||
|
// if there was some feedback (i.e. API in switch) that allowed for
|
||||||
|
// querying the state of resolution messages.
|
||||||
|
time.Sleep(10 * time.Second)
|
||||||
|
|
||||||
|
// Mine blocks until Bob has no waiting close channels. This tests
|
||||||
|
// that the circuit-deletion logic is skipped if a resolution message
|
||||||
|
// exists.
|
||||||
|
for {
|
||||||
|
_, err = net.Miner.Client.Generate(1)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
pendingChanResp, err = net.Bob.PendingChannels(
|
||||||
|
ctxt, pendingChansRequest,
|
||||||
|
)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
isErr := checkNumForceClosedChannels(pendingChanResp, 0)
|
||||||
|
if isErr == nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
time.Sleep(150 * time.Millisecond)
|
||||||
|
}
|
||||||
|
|
||||||
|
// We will now restart Bob so that we can test whether the resolution
|
||||||
|
// messages are re-forwarded on start-up.
|
||||||
|
restartBob, err := net.SuspendNode(net.Bob)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
err = restartBob()
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// We'll now also restart Alice and connect her with Bob.
|
||||||
|
err = restartAlice()
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
net.EnsureConnected(t.t, net.Alice, net.Bob)
|
||||||
|
|
||||||
|
// We'll assert that Alice has received the failure resolution
|
||||||
|
// message.
|
||||||
|
err = wait.Predicate(func() bool {
|
||||||
|
aliceInfo, err := getChanInfo(net.Alice)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return aliceInfo.NumUpdates == 2
|
||||||
|
}, defaultTimeout)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
|
||||||
|
// Assert that Alice's payment failed.
|
||||||
|
ctxt, _ = context.WithTimeout(ctxb, defaultTimeout)
|
||||||
|
paymentsResp, err := net.Alice.ListPayments(
|
||||||
|
ctxt, &lnrpc.ListPaymentsRequest{
|
||||||
|
IncludeIncomplete: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NoError(t.t, err)
|
||||||
|
require.Equal(t.t, 1, len(paymentsResp.Payments))
|
||||||
|
|
||||||
|
htlcs := paymentsResp.Payments[0].Htlcs
|
||||||
|
|
||||||
|
require.Equal(t.t, 1, len(htlcs))
|
||||||
|
require.Equal(t.t, lnrpc.HTLCAttempt_FAILED, htlcs[0].Status)
|
||||||
|
}
|
@ -399,4 +399,8 @@ var allTestCases = []*testCase{
|
|||||||
name: "addpeer config",
|
name: "addpeer config",
|
||||||
test: testAddPeerConfig,
|
test: testAddPeerConfig,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "resolution handoff",
|
||||||
|
test: testResHandoff,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user