htlcswitch+channeldb: add htlcidx to fwding log

In this commit we add htlcindex field to the forwardingevent
struct, which is persisted alongside the other event fields.
This commit is contained in:
Abdullahi Yunus
2025-05-15 12:15:00 +01:00
committed by Olaoluwa Osuntokun
parent b27f401ccc
commit 43f8bf288f
3 changed files with 180 additions and 3 deletions

View File

@@ -2,11 +2,13 @@ package channeldb
import (
"bytes"
"errors"
"io"
"sort"
"time"
"github.com/btcsuite/btcwallet/walletdb"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
)
@@ -25,11 +27,12 @@ const (
// is as follows:
//
// * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in
// || 8 byte value out
// || 8 byte value out || 8 byte incoming htlc id || 8 byte
// outgoing htlc id
//
// From the value in and value out, callers can easily compute the
// total fee extract from a forwarding event.
forwardingEventSize = 32
forwardingEventSize = 48
// MaxResponseEvents is the max number of forwarding events that will
// be returned by a single query response. This size was selected to
@@ -78,14 +81,44 @@ type ForwardingEvent struct {
// AmtOut is the amount of the outgoing HTLC. Subtracting the incoming
// amount from this gives the total fees for this payment circuit.
AmtOut lnwire.MilliSatoshi
// IncomingHtlcID is the ID of the incoming HTLC in the payment circuit.
// If this is not set, the value will be nil. This field is added in
// v0.20 and is made optional to make it backward compatible with
// existing forwarding events created before it's introduction.
IncomingHtlcID fn.Option[uint64]
// OutgoingHtlcID is the ID of the outgoing HTLC in the payment circuit.
// If this is not set, the value will be nil. This field is added in
// v0.20 and is made optional to make it backward compatible with
// existing forwarding events created before it's introduction.
OutgoingHtlcID fn.Option[uint64]
}
// encodeForwardingEvent writes out the target forwarding event to the passed
// io.Writer, using the expected DB format. Note that the timestamp isn't
// serialized as this will be the key value within the bucket.
func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
// We check for the HTLC IDs if they are set. If they are not,
// from v0.20 upward, we return an error to make it clear they are
// required.
incomingID, err := f.IncomingHtlcID.UnwrapOrErr(
errors.New("incoming HTLC ID must be set"),
)
if err != nil {
return err
}
outgoingID, err := f.OutgoingHtlcID.UnwrapOrErr(
errors.New("outgoing HTLC ID must be set"),
)
if err != nil {
return err
}
return WriteElements(
w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut,
incomingID, outgoingID,
)
}
@@ -94,9 +127,32 @@ func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error {
// won't be decoded, as the caller is expected to set this due to the bucket
// structure of the forwarding log.
func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error {
return ReadElements(
// Decode the original fields of the forwarding event.
err := ReadElements(
r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut,
)
if err != nil {
return err
}
// Decode the incoming and outgoing htlc IDs. For backward compatibility
// with older records that don't have these fields, we handle EOF by
// setting the ID to nil. Any other error is treated as a read failure.
var incomingHtlcID, outgoingHtlcID uint64
err = ReadElements(r, &incomingHtlcID, &outgoingHtlcID)
switch {
case err == nil:
f.IncomingHtlcID = fn.Some(incomingHtlcID)
f.OutgoingHtlcID = fn.Some(outgoingHtlcID)
return nil
case errors.Is(err, io.EOF):
return nil
default:
return err
}
}
// AddForwardingEvents adds a series of forwarding events to the database.

View File

@@ -1,12 +1,15 @@
package channeldb
import (
"bytes"
"math/rand"
"reflect"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -41,6 +44,8 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) {
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
IncomingHtlcID: fn.Some(uint64(i)),
OutgoingHtlcID: fn.Some(uint64(i)),
}
timestamp = timestamp.Add(time.Minute * 10)
@@ -109,6 +114,8 @@ func TestForwardingLogQueryOptions(t *testing.T) {
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
IncomingHtlcID: fn.Some(uint64(i)),
OutgoingHtlcID: fn.Some(uint64(i)),
}
endTime = endTime.Add(time.Minute * 10)
@@ -208,6 +215,8 @@ func TestForwardingLogQueryLimit(t *testing.T) {
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
IncomingHtlcID: fn.Some(uint64(i)),
OutgoingHtlcID: fn.Some(uint64(i)),
}
endTime = endTime.Add(time.Minute * 10)
@@ -317,6 +326,8 @@ func TestForwardingLogStoreEvent(t *testing.T) {
OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
IncomingHtlcID: fn.Some(uint64(i)),
OutgoingHtlcID: fn.Some(uint64(i)),
}
}
@@ -360,3 +371,107 @@ func TestForwardingLogStoreEvent(t *testing.T) {
}
}
}
// TestForwardingLogDecodeForwardingEvent tests that we're able to decode
// forwarding events that don't have the new incoming and outgoing htlc
// indices.
func TestForwardingLogDecodeForwardingEvent(t *testing.T) {
t.Parallel()
// First, we'll set up a test database, and use that to instantiate the
// forwarding event log that we'll be using for the duration of the
// test.
db, err := MakeTestDB(t)
require.NoError(t, err)
log := ForwardingLog{
db: db,
}
initialTime := time.Unix(1234, 0)
endTime := time.Unix(1234, 0)
// We'll create forwarding events that don't have the incoming and
// outgoing htlc indices.
numEvents := 10
events := make([]ForwardingEvent, numEvents)
for i := range numEvents {
events[i] = ForwardingEvent{
Timestamp: endTime,
IncomingChanID: lnwire.NewShortChanIDFromInt(
uint64(rand.Int63()),
),
OutgoingChanID: lnwire.NewShortChanIDFromInt(
uint64(rand.Int63()),
),
AmtIn: lnwire.MilliSatoshi(rand.Int63()),
AmtOut: lnwire.MilliSatoshi(rand.Int63()),
}
endTime = endTime.Add(time.Minute * 10)
}
// Now that all of our events are constructed, we'll add them to the
// database.
err = writeOldFormatEvents(db, events)
require.NoError(t, err)
// With all of our events added, we'll now query for them and ensure
// that the incoming and outgoing htlc indices are set to 0 (default
// value) for all events.
eventQuery := ForwardingEventQuery{
StartTime: initialTime,
EndTime: endTime,
IndexOffset: 0,
NumMaxEvents: uint32(numEvents * 3),
}
timeSlice, err := log.Query(eventQuery)
require.NoError(t, err)
require.Equal(t, numEvents, len(timeSlice.ForwardingEvents))
for _, event := range timeSlice.ForwardingEvents {
require.Equal(t, fn.None[uint64](), event.IncomingHtlcID)
require.Equal(t, fn.None[uint64](), event.OutgoingHtlcID)
}
}
// writeOldFormatEvents writes forwarding events to the database in the old
// format (without incoming and outgoing htlc indices). This is used to test
// backward compatibility.
func writeOldFormatEvents(db *DB, events []ForwardingEvent) error {
return kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error {
bucket, err := tx.CreateTopLevelBucket(forwardingLogBucket)
if err != nil {
return err
}
for _, event := range events {
var timestamp [8]byte
byteOrder.PutUint64(timestamp[:], uint64(
event.Timestamp.UnixNano(),
))
// Use the old event size (32 bytes) for writing old
// format events.
var eventBytes [32]byte
eventBuf := bytes.NewBuffer(eventBytes[0:0:32])
// Write only the original fields without incoming and
// outgoing htlc indices.
if err := WriteElements(
eventBuf, event.IncomingChanID,
event.OutgoingChanID, event.AmtIn, event.AmtOut,
); err != nil {
return err
}
if err := bucket.Put(
timestamp[:], eventBuf.Bytes(),
); err != nil {
return err
}
}
return nil
})
}

View File

@@ -3074,6 +3074,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error {
OutgoingChanID: circuit.Outgoing.ChanID,
AmtIn: circuit.IncomingAmount,
AmtOut: circuit.OutgoingAmount,
IncomingHtlcID: fn.Some(
circuit.Incoming.HtlcID,
),
OutgoingHtlcID: fn.Some(
circuit.Outgoing.HtlcID,
),
},
)
s.fwdEventMtx.Unlock()