htlcswitch: introducing interceptable switch.

In this commit we implement a wrapper arround the switch, called
InterceptableSwitch. This kind of wrapper behaves like a proxy which
intercepts forwarded packets and allows an external interceptor to
signal if it is interested to hold this forward and resolve it
manually later or let the switch execute its default behavior.
This infrastructure allows the RPC layer to expose interceptor
registration API to the user and by that enable the implementation
of custom routing behavior.
This commit is contained in:
Roei Erez
2020-05-19 12:56:58 +03:00
parent 1a6701122c
commit 0f50d8b2ed
5 changed files with 401 additions and 1 deletions

View File

@@ -11,6 +11,7 @@ import (
"time"
"github.com/btcsuite/btcutil"
"github.com/btcsuite/fastsha256"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/htlcswitch/hop"
@@ -1679,6 +1680,9 @@ func testSkipIneligibleLinksMultiHopForward(t *testing.T,
if err := s.ForwardPackets(nil, packet); err != nil {
t.Fatal(err)
}
// We select from all links and extract the error if exists.
// The packet must be selected but we don't always expect a link error.
var linkError *LinkError
select {
case p := <-aliceChannelLink.packets:
@@ -3111,3 +3115,186 @@ func getThreeHopEvents(channels *clusterChannels, htlcID uint64,
return aliceEvents, bobEvents, carolEvents
}
type mockForwardInterceptor struct {
intercepted InterceptedForward
}
func (m *mockForwardInterceptor) InterceptForwardHtlc(intercepted InterceptedForward) bool {
m.intercepted = intercepted
return true
}
func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error {
return m.intercepted.Settle(preimage)
}
func (m *mockForwardInterceptor) fail() error {
return m.intercepted.Fail()
}
func (m *mockForwardInterceptor) resume() error {
return m.intercepted.Resume()
}
func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
if s.circuits.NumPending() != pending {
t.Fatal("wrong amount of half circuits")
}
if s.circuits.NumOpen() != opened {
t.Fatal("wrong amount of circuits")
}
}
func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
expectReceive bool) {
// Pull packet from targetLink link.
select {
case packet := <-targetLink.packets:
if !expectReceive {
t.Fatal("forward was intercepted, shouldn't land at bob link")
} else if err := targetLink.completeCircuit(packet); err != nil {
t.Fatalf("unable to complete payment circuit: %v", err)
}
case <-time.After(time.Second):
if expectReceive {
t.Fatal("request was not propagated to destination")
}
}
}
func TestSwitchHoldForward(t *testing.T) {
t.Parallel()
chanID1, chanID2, aliceChanID, bobChanID := genIDs()
alicePeer, err := newMockServer(
t, "alice", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create alice server: %v", err)
}
bobPeer, err := newMockServer(
t, "bob", testStartingHeight, nil, testDefaultDelta,
)
if err != nil {
t.Fatalf("unable to create bob server: %v", err)
}
tempPath, err := ioutil.TempDir("", "circuitdb")
if err != nil {
t.Fatalf("unable to temporary path: %v", err)
}
cdb, err := channeldb.Open(tempPath)
if err != nil {
t.Fatalf("unable to open channeldb: %v", err)
}
s, err := initSwitchWithDB(testStartingHeight, cdb)
if err != nil {
t.Fatalf("unable to init switch: %v", err)
}
if err := s.Start(); err != nil {
t.Fatalf("unable to start switch: %v", err)
}
defer func() {
if err := s.Stop(); err != nil {
t.Fatalf(err.Error())
}
}()
aliceChannelLink := newMockChannelLink(
s, chanID1, aliceChanID, alicePeer, true,
)
bobChannelLink := newMockChannelLink(
s, chanID2, bobChanID, bobPeer, true,
)
if err := s.AddLink(aliceChannelLink); err != nil {
t.Fatalf("unable to add alice link: %v", err)
}
if err := s.AddLink(bobChannelLink); err != nil {
t.Fatalf("unable to add bob link: %v", err)
}
// Create request which should be forwarded from Alice channel link to
// bob channel link.
preimage := [sha256.Size]byte{1}
rhash := fastsha256.Sum256(preimage[:])
ogPacket := &htlcPacket{
incomingChanID: aliceChannelLink.ShortChanID(),
incomingHTLCID: 0,
outgoingChanID: bobChannelLink.ShortChanID(),
obfuscator: NewMockObfuscator(),
htlc: &lnwire.UpdateAddHTLC{
PaymentHash: rhash,
Amount: 1,
},
}
forwardInterceptor := &mockForwardInterceptor{}
switchForwardInterceptor := NewInterceptableSwitch(s)
switchForwardInterceptor.SetInterceptor(forwardInterceptor.InterceptForwardHtlc)
linkQuit := make(chan struct{})
// Test resume a hold forward
assertNumCircuits(t, s, 0, 0)
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.resume(); err != nil {
t.Fatalf("failed to resume forward")
}
assertOutgoingLinkReceive(t, bobChannelLink, true)
assertNumCircuits(t, s, 1, 1)
// settling the htlc to close the circuit.
settle := &htlcPacket{
outgoingChanID: bobChannelLink.ShortChanID(),
outgoingHTLCID: 0,
amount: 1,
htlc: &lnwire.UpdateFulfillHTLC{
PaymentPreimage: preimage,
},
}
if err := switchForwardInterceptor.ForwardPackets(linkQuit, settle); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Test failing a hold forward
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.fail(); err != nil {
t.Fatalf("failed to cancel forward %v", err)
}
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
// Test settling a hold forward
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
t.Fatalf("can't forward htlc packet: %v", err)
}
assertNumCircuits(t, s, 0, 0)
assertOutgoingLinkReceive(t, bobChannelLink, false)
if err := forwardInterceptor.settle(preimage); err != nil {
t.Fatal("failed to cancel forward")
}
assertOutgoingLinkReceive(t, bobChannelLink, false)
assertOutgoingLinkReceive(t, aliceChannelLink, true)
assertNumCircuits(t, s, 0, 0)
}