From 43f8bf288f9f48cd775dbe8c3520ba44970e2d48 Mon Sep 17 00:00:00 2001 From: Abdullahi Yunus Date: Thu, 15 May 2025 12:15:00 +0100 Subject: [PATCH] htlcswitch+channeldb: add htlcidx to fwding log In this commit we add htlcindex field to the forwardingevent struct, which is persisted alongside the other event fields. --- channeldb/forwarding_log.go | 62 ++++++++++++++++- channeldb/forwarding_log_test.go | 115 +++++++++++++++++++++++++++++++ htlcswitch/switch.go | 6 ++ 3 files changed, 180 insertions(+), 3 deletions(-) diff --git a/channeldb/forwarding_log.go b/channeldb/forwarding_log.go index f527b78dd..18cc6cf32 100644 --- a/channeldb/forwarding_log.go +++ b/channeldb/forwarding_log.go @@ -2,11 +2,13 @@ package channeldb import ( "bytes" + "errors" "io" "sort" "time" "github.com/btcsuite/btcwallet/walletdb" + "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" ) @@ -25,11 +27,12 @@ const ( // is as follows: // // * 8 byte incoming chan ID || 8 byte outgoing chan ID || 8 byte value in - // || 8 byte value out + // || 8 byte value out || 8 byte incoming htlc id || 8 byte + // outgoing htlc id // // From the value in and value out, callers can easily compute the // total fee extract from a forwarding event. - forwardingEventSize = 32 + forwardingEventSize = 48 // MaxResponseEvents is the max number of forwarding events that will // be returned by a single query response. This size was selected to @@ -78,14 +81,44 @@ type ForwardingEvent struct { // AmtOut is the amount of the outgoing HTLC. Subtracting the incoming // amount from this gives the total fees for this payment circuit. AmtOut lnwire.MilliSatoshi + + // IncomingHtlcID is the ID of the incoming HTLC in the payment circuit. + // If this is not set, the value will be nil. This field is added in + // v0.20 and is made optional to make it backward compatible with + // existing forwarding events created before it's introduction. + IncomingHtlcID fn.Option[uint64] + + // OutgoingHtlcID is the ID of the outgoing HTLC in the payment circuit. + // If this is not set, the value will be nil. This field is added in + // v0.20 and is made optional to make it backward compatible with + // existing forwarding events created before it's introduction. + OutgoingHtlcID fn.Option[uint64] } // encodeForwardingEvent writes out the target forwarding event to the passed // io.Writer, using the expected DB format. Note that the timestamp isn't // serialized as this will be the key value within the bucket. func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { + // We check for the HTLC IDs if they are set. If they are not, + // from v0.20 upward, we return an error to make it clear they are + // required. + incomingID, err := f.IncomingHtlcID.UnwrapOrErr( + errors.New("incoming HTLC ID must be set"), + ) + if err != nil { + return err + } + + outgoingID, err := f.OutgoingHtlcID.UnwrapOrErr( + errors.New("outgoing HTLC ID must be set"), + ) + if err != nil { + return err + } + return WriteElements( w, f.IncomingChanID, f.OutgoingChanID, f.AmtIn, f.AmtOut, + incomingID, outgoingID, ) } @@ -94,9 +127,32 @@ func encodeForwardingEvent(w io.Writer, f *ForwardingEvent) error { // won't be decoded, as the caller is expected to set this due to the bucket // structure of the forwarding log. func decodeForwardingEvent(r io.Reader, f *ForwardingEvent) error { - return ReadElements( + // Decode the original fields of the forwarding event. + err := ReadElements( r, &f.IncomingChanID, &f.OutgoingChanID, &f.AmtIn, &f.AmtOut, ) + if err != nil { + return err + } + + // Decode the incoming and outgoing htlc IDs. For backward compatibility + // with older records that don't have these fields, we handle EOF by + // setting the ID to nil. Any other error is treated as a read failure. + var incomingHtlcID, outgoingHtlcID uint64 + err = ReadElements(r, &incomingHtlcID, &outgoingHtlcID) + switch { + case err == nil: + f.IncomingHtlcID = fn.Some(incomingHtlcID) + f.OutgoingHtlcID = fn.Some(outgoingHtlcID) + + return nil + + case errors.Is(err, io.EOF): + return nil + + default: + return err + } } // AddForwardingEvents adds a series of forwarding events to the database. diff --git a/channeldb/forwarding_log_test.go b/channeldb/forwarding_log_test.go index 7ac2dfbc5..c86de52db 100644 --- a/channeldb/forwarding_log_test.go +++ b/channeldb/forwarding_log_test.go @@ -1,12 +1,15 @@ package channeldb import ( + "bytes" "math/rand" "reflect" "testing" "time" "github.com/davecgh/go-spew/spew" + "github.com/lightningnetwork/lnd/fn/v2" + "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -41,6 +44,8 @@ func TestForwardingLogBasicStorageAndQuery(t *testing.T) { OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), AmtIn: lnwire.MilliSatoshi(rand.Int63()), AmtOut: lnwire.MilliSatoshi(rand.Int63()), + IncomingHtlcID: fn.Some(uint64(i)), + OutgoingHtlcID: fn.Some(uint64(i)), } timestamp = timestamp.Add(time.Minute * 10) @@ -109,6 +114,8 @@ func TestForwardingLogQueryOptions(t *testing.T) { OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), AmtIn: lnwire.MilliSatoshi(rand.Int63()), AmtOut: lnwire.MilliSatoshi(rand.Int63()), + IncomingHtlcID: fn.Some(uint64(i)), + OutgoingHtlcID: fn.Some(uint64(i)), } endTime = endTime.Add(time.Minute * 10) @@ -208,6 +215,8 @@ func TestForwardingLogQueryLimit(t *testing.T) { OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), AmtIn: lnwire.MilliSatoshi(rand.Int63()), AmtOut: lnwire.MilliSatoshi(rand.Int63()), + IncomingHtlcID: fn.Some(uint64(i)), + OutgoingHtlcID: fn.Some(uint64(i)), } endTime = endTime.Add(time.Minute * 10) @@ -317,6 +326,8 @@ func TestForwardingLogStoreEvent(t *testing.T) { OutgoingChanID: lnwire.NewShortChanIDFromInt(uint64(rand.Int63())), AmtIn: lnwire.MilliSatoshi(rand.Int63()), AmtOut: lnwire.MilliSatoshi(rand.Int63()), + IncomingHtlcID: fn.Some(uint64(i)), + OutgoingHtlcID: fn.Some(uint64(i)), } } @@ -360,3 +371,107 @@ func TestForwardingLogStoreEvent(t *testing.T) { } } } + +// TestForwardingLogDecodeForwardingEvent tests that we're able to decode +// forwarding events that don't have the new incoming and outgoing htlc +// indices. +func TestForwardingLogDecodeForwardingEvent(t *testing.T) { + t.Parallel() + + // First, we'll set up a test database, and use that to instantiate the + // forwarding event log that we'll be using for the duration of the + // test. + db, err := MakeTestDB(t) + require.NoError(t, err) + + log := ForwardingLog{ + db: db, + } + + initialTime := time.Unix(1234, 0) + endTime := time.Unix(1234, 0) + + // We'll create forwarding events that don't have the incoming and + // outgoing htlc indices. + numEvents := 10 + events := make([]ForwardingEvent, numEvents) + for i := range numEvents { + events[i] = ForwardingEvent{ + Timestamp: endTime, + IncomingChanID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + OutgoingChanID: lnwire.NewShortChanIDFromInt( + uint64(rand.Int63()), + ), + AmtIn: lnwire.MilliSatoshi(rand.Int63()), + AmtOut: lnwire.MilliSatoshi(rand.Int63()), + } + + endTime = endTime.Add(time.Minute * 10) + } + + // Now that all of our events are constructed, we'll add them to the + // database. + err = writeOldFormatEvents(db, events) + require.NoError(t, err) + + // With all of our events added, we'll now query for them and ensure + // that the incoming and outgoing htlc indices are set to 0 (default + // value) for all events. + eventQuery := ForwardingEventQuery{ + StartTime: initialTime, + EndTime: endTime, + IndexOffset: 0, + NumMaxEvents: uint32(numEvents * 3), + } + timeSlice, err := log.Query(eventQuery) + require.NoError(t, err) + require.Equal(t, numEvents, len(timeSlice.ForwardingEvents)) + + for _, event := range timeSlice.ForwardingEvents { + require.Equal(t, fn.None[uint64](), event.IncomingHtlcID) + require.Equal(t, fn.None[uint64](), event.OutgoingHtlcID) + } +} + +// writeOldFormatEvents writes forwarding events to the database in the old +// format (without incoming and outgoing htlc indices). This is used to test +// backward compatibility. +func writeOldFormatEvents(db *DB, events []ForwardingEvent) error { + return kvdb.Batch(db.Backend, func(tx kvdb.RwTx) error { + bucket, err := tx.CreateTopLevelBucket(forwardingLogBucket) + if err != nil { + return err + } + + for _, event := range events { + var timestamp [8]byte + byteOrder.PutUint64(timestamp[:], uint64( + event.Timestamp.UnixNano(), + )) + + // Use the old event size (32 bytes) for writing old + // format events. + var eventBytes [32]byte + eventBuf := bytes.NewBuffer(eventBytes[0:0:32]) + + // Write only the original fields without incoming and + // outgoing htlc indices. + if err := WriteElements( + eventBuf, event.IncomingChanID, + event.OutgoingChanID, event.AmtIn, event.AmtOut, + ); err != nil { + return err + } + + if err := bucket.Put( + timestamp[:], eventBuf.Bytes(), + ); err != nil { + return err + } + } + + return nil + }) +} diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index d034e732e..1009ba302 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -3074,6 +3074,12 @@ func (s *Switch) handlePacketSettle(packet *htlcPacket) error { OutgoingChanID: circuit.Outgoing.ChanID, AmtIn: circuit.IncomingAmount, AmtOut: circuit.OutgoingAmount, + IncomingHtlcID: fn.Some( + circuit.Incoming.HtlcID, + ), + OutgoingHtlcID: fn.Some( + circuit.Outgoing.HtlcID, + ), }, ) s.fwdEventMtx.Unlock()