diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index b7b73a35a..3abc73f81 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -155,19 +155,17 @@ type HTLCEntry struct { // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] // CustomBlob is an optional blob that can be used to store information // specific to revocation handling for a custom channel type. CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] + + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.OptionalRecordT[tlv.TlvType6, uint16] } // toTlvStream converts an HTLCEntry record into a tlv representation. @@ -184,6 +182,12 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { records = append(records, r.Record()) }) + h.HtlcIndex.WhenSome(func(r tlv.RecordT[tlv.TlvType6, uint16]) { + records = append(records, r.Record()) + }) + + tlv.SortRecords(records) + return tlv.NewStream(records...) } @@ -203,6 +207,9 @@ func NewHTLCEntryFromHTLC(htlc HTLC) (*HTLCEntry, error) { Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), + HtlcIndex: tlv.SomeRecordT(tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + )), } if len(htlc.CustomRecords) != 0 { @@ -509,6 +516,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { var htlc HTLCEntry customBlob := htlc.CustomBlob.Zero() + htlcIndex := htlc.HtlcIndex.Zero() // Create the tlv stream. records := []tlv.Record{ @@ -518,6 +526,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.Incoming.Record(), htlc.Amt.Record(), customBlob.Record(), + htlcIndex.Record(), } tlvStream, err := tlv.NewStream(records...) @@ -539,6 +548,10 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.CustomBlob = tlv.SomeRecordT(customBlob) } + if t, ok := parsedTypes[htlcIndex.TlvType()]; ok && t == nil { + htlc.HtlcIndex = tlv.SomeRecordT(htlcIndex) + } + // Append the entry. htlcs = append(htlcs, &htlc) } diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 139a02d52..4290552ee 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -59,10 +59,13 @@ var ( CustomBlob: tlv.SomeRecordT( tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), ), + HtlcIndex: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](0x33), + ), } testHTLCEntryBytes = []byte{ - // Body length 41. - 0x29, + // Body length 45. + 0x2d, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -76,6 +79,8 @@ var ( // Custom blob tlv. 0x5, 0x11, 0xfe, 0x00, 0x01, 0x00, 0x01, 0x0b, 0x63, 0x75, 0x73, 0x74, 0x6f, 0x6d, 0x20, 0x64, 0x61, 0x74, 0x61, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x33, } testHTLCEntryHash = HTLCEntry{ @@ -126,7 +131,11 @@ var ( Htlcs: []HTLC{{ RefundTimeout: testHTLCEntry.RefundTimeout.Val, OutputIndex: int32(testHTLCEntry.OutputIndex.Val), - Incoming: testHTLCEntry.Incoming.Val, + HtlcIndex: uint64( + testHTLCEntry.HtlcIndex.ValOpt(). + UnsafeFromSome(), + ), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), ), @@ -294,7 +303,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x49, 0x0, 0x20} + expectedBytes := []byte{0x4d, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest.