diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index a3fcf25db..db6cf9281 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -163,19 +163,17 @@ type HTLCEntry struct { // Incoming denotes whether we're the receiver or the sender of this // HTLC. - // - // NOTE: this field is the memory representation of the field - // incomingUint. Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. - // - // NOTE: this field is the memory representation of the field amtUint. Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] // CustomBlob is an optional blob that can be used to store information // specific to revocation handling for a custom channel type. CustomBlob tlv.OptionalRecordT[tlv.TlvType5, tlv.Blob] + + // HtlcIndex is the index of the HTLC in the channel. + HtlcIndex tlv.RecordT[tlv.TlvType6, uint16] } // toTlvStream converts an HTLCEntry record into a tlv representation. @@ -186,12 +184,15 @@ func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { h.OutputIndex.Record(), h.Incoming.Record(), h.Amt.Record(), + h.HtlcIndex.Record(), } h.CustomBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType5, tlv.Blob]) { records = append(records, r.Record()) }) + tlv.SortRecords(records) + return tlv.NewStream(records...) } @@ -211,6 +212,9 @@ func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { Amt: tlv.NewRecordT[tlv.TlvType4]( tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6]( + uint16(htlc.HtlcIndex), + ), } if len(htlc.ExtraData) != 0 { @@ -520,6 +524,7 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { htlc.Incoming.Record(), htlc.Amt.Record(), customBlob.Record(), + htlc.HtlcIndex.Record(), } tlvStream, err := tlv.NewStream(records...) diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index 33f20c45c..01088757e 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -52,10 +52,11 @@ var ( CustomBlob: tlv.SomeRecordT( tlv.NewPrimitiveRecord[tlv.TlvType5](blobBytes), ), + HtlcIndex: tlv.NewPrimitiveRecord[tlv.TlvType6, uint16](3), } testHTLCEntryBytes = []byte{ - // Body length 28. - 0x1c, + // Body length 32. + 0x20, // Rhash tlv. 0x0, 0x0, // RefundTimeout tlv. @@ -68,6 +69,8 @@ var ( 0x4, 0x5, 0xfe, 0x0, 0xf, 0x42, 0x40, // Custom blob tlv. 0x5, 0x4, 0x1, 0x2, 0x3, 0x4, + // HLTC index tlv. + 0x6, 0x2, 0x0, 0x03, } localBalance = lnwire.MilliSatoshi(9000) @@ -84,6 +87,7 @@ var ( Htlcs: []HTLC{{ RefundTimeout: testHTLCEntry.RefundTimeout.Val, OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + HtlcIndex: uint64(testHTLCEntry.HtlcIndex.Val), Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( testHTLCEntry.Amt.Val.Int(), @@ -235,7 +239,7 @@ func TestSerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - expectedBytes := []byte{0x3c, 0x0, 0x20} + expectedBytes := []byte{0x40, 0x0, 0x20} expectedBytes = append(expectedBytes, rHashBytes...) // Append the rest. @@ -350,7 +354,7 @@ func TestDerializeHTLCEntries(t *testing.T) { partialBytes := testHTLCEntryBytes[3:] // Write the total length and RHash tlv. - testBytes := append([]byte{0x3c, 0x0, 0x20}, rHashBytes...) + testBytes := append([]byte{0x40, 0x0, 0x20}, rHashBytes...) // Append the rest. testBytes = append(testBytes, partialBytes...)