htlcswitch: thread through packet's inbound wire records

For calculating the available auxiliary bandwidth of a channel, we need
access to the inbound custom wire records of the HTLC packet, which
might contain auxiliary information about the worth of the HTLC packet
apart from the BTC value being transported.
This commit is contained in:
Oliver Gugger 2024-12-04 12:02:24 +01:00
parent 117c6bc781
commit a2e78c3984
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
5 changed files with 55 additions and 48 deletions

View File

@ -272,10 +272,10 @@ type ChannelLink interface {
// in order to signal to the source of the HTLC, the policy consistency
// issue.
CheckHtlcForward(payHash [32]byte, incomingAmt lnwire.MilliSatoshi,
amtToForward lnwire.MilliSatoshi,
incomingTimeout, outgoingTimeout uint32,
inboundFee models.InboundFee,
heightNow uint32, scid lnwire.ShortChannelID) *LinkError
amtToForward lnwire.MilliSatoshi, incomingTimeout,
outgoingTimeout uint32, inboundFee models.InboundFee,
heightNow uint32, scid lnwire.ShortChannelID,
customRecords lnwire.CustomRecords) *LinkError
// CheckHtlcTransit should return a nil error if the passed HTLC details
// satisfy the current channel policy. Otherwise, a LinkError with a
@ -283,7 +283,8 @@ type ChannelLink interface {
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
CheckHtlcTransit(payHash [32]byte, amt lnwire.MilliSatoshi,
timeout uint32, heightNow uint32) *LinkError
timeout uint32, heightNow uint32,
customRecords lnwire.CustomRecords) *LinkError
// Stats return the statistics of channel link. Number of updates,
// total sent/received milli-satoshis.

View File

@ -3233,11 +3233,11 @@ func (l *channelLink) UpdateForwardingPolicy(
// issue.
//
// NOTE: Part of the ChannelLink interface.
func (l *channelLink) CheckHtlcForward(payHash [32]byte,
incomingHtlcAmt, amtToForward lnwire.MilliSatoshi,
incomingTimeout, outgoingTimeout uint32,
inboundFee models.InboundFee,
heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
func (l *channelLink) CheckHtlcForward(payHash [32]byte, incomingHtlcAmt,
amtToForward lnwire.MilliSatoshi, incomingTimeout,
outgoingTimeout uint32, inboundFee models.InboundFee,
heightNow uint32, originalScid lnwire.ShortChannelID,
customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
@ -3286,7 +3286,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte,
// Check whether the outgoing htlc satisfies the channel policy.
err := l.canSendHtlc(
policy, payHash, amtToForward, outgoingTimeout, heightNow,
originalScid,
originalScid, customRecords,
)
if err != nil {
return err
@ -3322,8 +3322,8 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte,
// the violation. This call is intended to be used for locally initiated
// payments for which there is no corresponding incoming htlc.
func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32,
heightNow uint32) *LinkError {
amt lnwire.MilliSatoshi, timeout uint32, heightNow uint32,
customRecords lnwire.CustomRecords) *LinkError {
l.RLock()
policy := l.cfg.FwrdingPolicy
@ -3334,6 +3334,7 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
// to occur.
return l.canSendHtlc(
policy, payHash, amt, timeout, heightNow, hop.Source,
customRecords,
)
}
@ -3341,7 +3342,8 @@ func (l *channelLink) CheckHtlcTransit(payHash [32]byte,
// the channel's amount and time lock constraints.
func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy,
payHash [32]byte, amt lnwire.MilliSatoshi, timeout uint32,
heightNow uint32, originalScid lnwire.ShortChannelID) *LinkError {
heightNow uint32, originalScid lnwire.ShortChannelID,
customRecords lnwire.CustomRecords) *LinkError {
// As our first sanity check, we'll ensure that the passed HTLC isn't
// too small for the next hop. If so, then we'll cancel the HTLC

View File

@ -6243,9 +6243,9 @@ func TestCheckHtlcForward(t *testing.T) {
var hash [32]byte
t.Run("satisfied", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1500, 1000, 200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if result != nil {
t.Fatalf("expected policy to be satisfied")
@ -6253,9 +6253,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("below minhtlc", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 100, 50,
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 100, 50, 200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailAmountBelowMinimum); !ok {
t.Fatalf("expected FailAmountBelowMinimum failure code")
@ -6263,9 +6263,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("above maxhtlc", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1200,
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1500, 1200, 200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailTemporaryChannelFailure); !ok {
t.Fatalf("expected FailTemporaryChannelFailure failure code")
@ -6273,9 +6273,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("insufficient fee", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1005, 1000,
200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1005, 1000, 200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailFeeInsufficient); !ok {
t.Fatalf("expected FailFeeInsufficient failure code")
@ -6288,17 +6288,17 @@ func TestCheckHtlcForward(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(
hash, 100005, 100000, 200,
150, models.InboundFee{}, 0, lnwire.ShortChannelID{},
hash, 100005, 100000, 200, 150, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
_, ok := result.WireMessage().(*lnwire.FailFeeInsufficient)
require.True(t, ok, "expected FailFeeInsufficient failure code")
})
t.Run("expiry too soon", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 150, models.InboundFee{}, 190,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1500, 1000, 200, 150, models.InboundFee{}, 190,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooSoon); !ok {
t.Fatalf("expected FailExpiryTooSoon failure code")
@ -6306,9 +6306,9 @@ func TestCheckHtlcForward(t *testing.T) {
})
t.Run("incorrect cltv expiry", func(t *testing.T) {
result := link.CheckHtlcForward(hash, 1500, 1000,
200, 190, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1500, 1000, 200, 190, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailIncorrectCltvExpiry); !ok {
t.Fatalf("expected FailIncorrectCltvExpiry failure code")
@ -6318,9 +6318,9 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("cltv expiry too far in the future", func(t *testing.T) {
// Check that expiry isn't too far in the future.
result := link.CheckHtlcForward(hash, 1500, 1000,
10200, 10100, models.InboundFee{}, 0,
lnwire.ShortChannelID{},
result := link.CheckHtlcForward(
hash, 1500, 1000, 10200, 10100, models.InboundFee{}, 0,
lnwire.ShortChannelID{}, nil,
)
if _, ok := result.WireMessage().(*lnwire.FailExpiryTooFar); !ok {
t.Fatalf("expected FailExpiryTooFar failure code")
@ -6330,9 +6330,11 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("inbound fee satisfied", func(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(hash, 1000+10-2-1, 1000,
200, 150, models.InboundFee{Base: -2, Rate: -1_000},
0, lnwire.ShortChannelID{})
result := link.CheckHtlcForward(
hash, 1000+10-2-1, 1000, 200, 150,
models.InboundFee{Base: -2, Rate: -1_000},
0, lnwire.ShortChannelID{}, nil,
)
if result != nil {
t.Fatalf("expected policy to be satisfied")
}
@ -6341,9 +6343,11 @@ func TestCheckHtlcForward(t *testing.T) {
t.Run("inbound fee insufficient", func(t *testing.T) {
t.Parallel()
result := link.CheckHtlcForward(hash, 1000+10-10-101-1, 1000,
result := link.CheckHtlcForward(
hash, 1000+10-10-101-1, 1000,
200, 150, models.InboundFee{Base: -10, Rate: -100_000},
0, lnwire.ShortChannelID{})
0, lnwire.ShortChannelID{}, nil,
)
msg := result.WireMessage()
if _, ok := msg.(*lnwire.FailFeeInsufficient); !ok {

View File

@ -846,14 +846,14 @@ func (f *mockChannelLink) UpdateForwardingPolicy(_ models.ForwardingPolicy) {
}
func (f *mockChannelLink) CheckHtlcForward([32]byte, lnwire.MilliSatoshi,
lnwire.MilliSatoshi, uint32, uint32, models.InboundFee, uint32,
lnwire.ShortChannelID) *LinkError {
lnwire.ShortChannelID, lnwire.CustomRecords) *LinkError {
return f.checkHtlcForwardResult
}
func (f *mockChannelLink) CheckHtlcTransit(payHash [32]byte,
amt lnwire.MilliSatoshi, timeout uint32,
heightNow uint32) *LinkError {
heightNow uint32, _ lnwire.CustomRecords) *LinkError {
return f.checkHtlcTransitResult
}

View File

@ -917,6 +917,7 @@ func (s *Switch) getLocalLink(pkt *htlcPacket, htlc *lnwire.UpdateAddHTLC) (
currentHeight := atomic.LoadUint32(&s.bestHeight)
htlcErr := link.CheckHtlcTransit(
htlc.PaymentHash, htlc.Amount, htlc.Expiry, currentHeight,
htlc.CustomRecords,
)
if htlcErr != nil {
log.Errorf("Link %v policy for local forward not "+
@ -2887,10 +2888,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket,
failure = link.CheckHtlcForward(
htlc.PaymentHash, packet.incomingAmount,
packet.amount, packet.incomingTimeout,
packet.outgoingTimeout,
packet.inboundFee,
currentHeight,
packet.originalOutgoingChanID,
packet.outgoingTimeout, packet.inboundFee,
currentHeight, packet.originalOutgoingChanID,
htlc.CustomRecords,
)
}