mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-21 14:10:35 +02:00
htlcswitch: accept failure reason for intercepted htlcs
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
@@ -15,6 +16,10 @@ var (
|
||||
// ErrFwdNotExists is an error returned when the caller tries to resolve
|
||||
// a forward that doesn't exist anymore.
|
||||
ErrFwdNotExists = errors.New("forward does not exist")
|
||||
|
||||
// ErrUnsupportedFailureCode when processing of an unsupported failure
|
||||
// code is attempted.
|
||||
ErrUnsupportedFailureCode = errors.New("unsupported failure code")
|
||||
)
|
||||
|
||||
// InterceptableSwitch is an implementation of ForwardingSwitch interface.
|
||||
@@ -137,21 +142,63 @@ func (f *interceptedForward) Resume() error {
|
||||
return f.htlcSwitch.ForwardPackets(f.linkQuit, f.packet)
|
||||
}
|
||||
|
||||
// Fail forward a failed packet to the switch.
|
||||
func (f *interceptedForward) Fail() error {
|
||||
update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate(
|
||||
f.packet.incomingChanID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
// Fail notifies the intention to fail an existing hold forward with an
|
||||
// encrypted failure reason.
|
||||
func (f *interceptedForward) Fail(reason []byte) error {
|
||||
obfuscatedReason := f.packet.obfuscator.IntermediateEncrypt(reason)
|
||||
|
||||
return f.resolve(&lnwire.UpdateFailHTLC{
|
||||
Reason: obfuscatedReason,
|
||||
})
|
||||
}
|
||||
|
||||
// FailWithCode notifies the intention to fail an existing hold forward with the
|
||||
// specified failure code.
|
||||
func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error {
|
||||
shaOnionBlob := func() [32]byte {
|
||||
return sha256.Sum256(f.htlc.OnionBlob[:])
|
||||
}
|
||||
|
||||
reason, err := f.packet.obfuscator.EncryptFirstHop(
|
||||
lnwire.NewTemporaryChannelFailure(update),
|
||||
)
|
||||
// Create a local failure.
|
||||
var failureMsg lnwire.FailureMessage
|
||||
|
||||
switch code {
|
||||
case lnwire.CodeInvalidOnionVersion:
|
||||
failureMsg = &lnwire.FailInvalidOnionVersion{
|
||||
OnionSHA256: shaOnionBlob(),
|
||||
}
|
||||
|
||||
case lnwire.CodeInvalidOnionHmac:
|
||||
failureMsg = &lnwire.FailInvalidOnionHmac{
|
||||
OnionSHA256: shaOnionBlob(),
|
||||
}
|
||||
|
||||
case lnwire.CodeInvalidOnionKey:
|
||||
failureMsg = &lnwire.FailInvalidOnionKey{
|
||||
OnionSHA256: shaOnionBlob(),
|
||||
}
|
||||
|
||||
case lnwire.CodeTemporaryChannelFailure:
|
||||
update, err := f.htlcSwitch.cfg.FetchLastChannelUpdate(
|
||||
f.packet.incomingChanID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
failureMsg = lnwire.NewTemporaryChannelFailure(update)
|
||||
|
||||
default:
|
||||
return ErrUnsupportedFailureCode
|
||||
}
|
||||
|
||||
// Encrypt the failure for the first hop. This node will be the origin
|
||||
// of the failure.
|
||||
reason, err := f.packet.obfuscator.EncryptFirstHop(failureMsg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to encrypt failure reason %v", err)
|
||||
}
|
||||
|
||||
return f.resolve(&lnwire.UpdateFailHTLC{
|
||||
Reason: reason,
|
||||
})
|
||||
|
@@ -297,8 +297,13 @@ type InterceptedForward interface {
|
||||
// forward with a given preimage.
|
||||
Settle(lntypes.Preimage) error
|
||||
|
||||
// Fails notifies the intention to fail an existing hold forward
|
||||
Fail() error
|
||||
// Fail notifies the intention to fail an existing hold forward with an
|
||||
// encrypted failure reason.
|
||||
Fail(reason []byte) error
|
||||
|
||||
// FailWithCode notifies the intention to fail an existing hold forward
|
||||
// with the specified failure code.
|
||||
FailWithCode(code lnwire.FailCode) error
|
||||
}
|
||||
|
||||
// htlcNotifier is an interface which represents the input side of the
|
||||
|
@@ -1,6 +1,7 @@
|
||||
package htlcswitch
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"fmt"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/ticker"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -3142,7 +3144,8 @@ type mockForwardInterceptor struct {
|
||||
intercepted InterceptedForward
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) InterceptForwardHtlc(intercepted InterceptedForward) bool {
|
||||
func (m *mockForwardInterceptor) InterceptForwardHtlc(
|
||||
intercepted InterceptedForward) bool {
|
||||
|
||||
m.intercepted = intercepted
|
||||
return true
|
||||
@@ -3152,8 +3155,14 @@ func (m *mockForwardInterceptor) settle(preimage lntypes.Preimage) error {
|
||||
return m.intercepted.Settle(preimage)
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) fail() error {
|
||||
return m.intercepted.Fail()
|
||||
func (m *mockForwardInterceptor) fail(reason []byte) error {
|
||||
return m.intercepted.Fail(reason)
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) failWithCode(
|
||||
code lnwire.FailCode) error {
|
||||
|
||||
return m.intercepted.FailWithCode(code)
|
||||
}
|
||||
|
||||
func (m *mockForwardInterceptor) resume() error {
|
||||
@@ -3170,7 +3179,7 @@ func assertNumCircuits(t *testing.T, s *Switch, pending, opened int) {
|
||||
}
|
||||
|
||||
func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
|
||||
expectReceive bool) {
|
||||
expectReceive bool) *htlcPacket {
|
||||
|
||||
// Pull packet from targetLink link.
|
||||
select {
|
||||
@@ -3181,11 +3190,15 @@ func assertOutgoingLinkReceive(t *testing.T, targetLink *mockChannelLink,
|
||||
t.Fatalf("unable to complete payment circuit: %v", err)
|
||||
}
|
||||
|
||||
return packet
|
||||
|
||||
case <-time.After(time.Second):
|
||||
if expectReceive {
|
||||
t.Fatal("request was not propagated to destination")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestSwitchHoldForward(t *testing.T) {
|
||||
@@ -3247,6 +3260,7 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
// bob channel link.
|
||||
preimage := [sha256.Size]byte{1}
|
||||
rhash := sha256.Sum256(preimage[:])
|
||||
onionBlob := [1366]byte{4, 5, 6}
|
||||
ogPacket := &htlcPacket{
|
||||
incomingChanID: aliceChannelLink.ShortChanID(),
|
||||
incomingHTLCID: 0,
|
||||
@@ -3255,6 +3269,7 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
htlc: &lnwire.UpdateAddHTLC{
|
||||
PaymentHash: rhash,
|
||||
Amount: 1,
|
||||
OnionBlob: onionBlob,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -3299,13 +3314,55 @@ func TestSwitchHoldForward(t *testing.T) {
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
if err := forwardInterceptor.fail(); err != nil {
|
||||
if err := forwardInterceptor.fail(nil); err != nil {
|
||||
t.Fatalf("failed to cancel forward %v", err)
|
||||
}
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test failing a hold forward with a failure message.
|
||||
require.NoError(t,
|
||||
switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket),
|
||||
)
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
reason := lnwire.OpaqueReason([]byte{1, 2, 3})
|
||||
require.NoError(t, forwardInterceptor.fail(reason))
|
||||
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
packet := assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
|
||||
require.Equal(t, reason, packet.htlc.(*lnwire.UpdateFailHTLC).Reason)
|
||||
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test failing a hold forward with a malformed htlc failure.
|
||||
err = switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket)
|
||||
require.NoError(t, err)
|
||||
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
|
||||
code := lnwire.CodeInvalidOnionKey
|
||||
require.NoError(t, forwardInterceptor.failWithCode(code))
|
||||
|
||||
assertOutgoingLinkReceive(t, bobChannelLink, false)
|
||||
packet = assertOutgoingLinkReceive(t, aliceChannelLink, true)
|
||||
failPacket := packet.htlc.(*lnwire.UpdateFailHTLC)
|
||||
|
||||
shaOnionBlob := sha256.Sum256(onionBlob[:])
|
||||
expectedFailure := &lnwire.FailInvalidOnionKey{
|
||||
OnionSHA256: shaOnionBlob,
|
||||
}
|
||||
var b bytes.Buffer
|
||||
require.NoError(t, lnwire.EncodeFailure(&b, expectedFailure, 0))
|
||||
|
||||
assert.Equal(t, lnwire.OpaqueReason(b.Bytes()), failPacket.Reason)
|
||||
|
||||
assertNumCircuits(t, s, 0, 0)
|
||||
|
||||
// Test settling a hold forward
|
||||
if err := switchForwardInterceptor.ForwardPackets(linkQuit, ogPacket); err != nil {
|
||||
t.Fatalf("can't forward htlc packet: %v", err)
|
||||
|
Reference in New Issue
Block a user