diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index 5cfe39284..7615fb077 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -579,15 +579,21 @@ func assertRevocationLogEntryEqual(t *testing.T, c *ChannelCommitment, require.Equal(t, len(r.HTLCEntries), len(c.Htlcs), "HTLCs len mismatch") for i, rHtlc := range r.HTLCEntries { cHtlc := c.Htlcs[i] - require.Equal(t, rHtlc.RHash, cHtlc.RHash, "RHash mismatch") - require.Equal(t, rHtlc.Amt, cHtlc.Amt.ToSatoshis(), - "Amt mismatch") - require.Equal(t, rHtlc.RefundTimeout, cHtlc.RefundTimeout, - "RefundTimeout mismatch") - require.EqualValues(t, rHtlc.OutputIndex, cHtlc.OutputIndex, - "OutputIndex mismatch") - require.Equal(t, rHtlc.Incoming, cHtlc.Incoming, - "Incoming mismatch") + require.Equal(t, rHtlc.RHash.Val[:], cHtlc.RHash[:], "RHash") + require.Equal( + t, rHtlc.Amt.Val.Int(), cHtlc.Amt.ToSatoshis(), "Amt", + ) + require.Equal( + t, rHtlc.RefundTimeout.Val, cHtlc.RefundTimeout, + "RefundTimeout", + ) + require.EqualValues( + t, rHtlc.OutputIndex.Val, cHtlc.OutputIndex, + "OutputIndex", + ) + require.Equal( + t, rHtlc.Incoming.Val, cHtlc.Incoming, "Incoming", + ) } } diff --git a/channeldb/revocation_log.go b/channeldb/revocation_log.go index f062ac086..cfe65095f 100644 --- a/channeldb/revocation_log.go +++ b/channeldb/revocation_log.go @@ -54,6 +54,82 @@ var ( ErrOutputIndexTooBig = errors.New("output index is over uint16") ) +// SparsePayHash is a type alias for a 32 byte array, which when serialized is +// able to save some space by not including an empty payment hash on disk. +type SparsePayHash [32]byte + +// NewSparsePayHash creates a new SparsePayHash from a 32 byte array. +func NewSparsePayHash(rHash [32]byte) SparsePayHash { + return SparsePayHash(rHash) +} + +// Record returns a tlv record for the SparsePayHash. +func (s *SparsePayHash) Record() tlv.Record { + // We use a zero for the type here, as this'll be used along with the + // RecordT type. + return tlv.MakeDynamicRecord( + 0, s, s.hashLen, + sparseHashEncoder, sparseHashDecoder, + ) +} + +// hashLen is used by MakeDynamicRecord to return the size of the RHash. +// +// NOTE: for zero hash, we return a length 0. +func (s *SparsePayHash) hashLen() uint64 { + if bytes.Equal(s[:], lntypes.ZeroHash[:]) { + return 0 + } + + return 32 +} + +// sparseHashEncoder is the customized encoder which skips encoding the empty +// hash. +func sparseHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the value is an empty hash, we will skip encoding it. + if bytes.Equal(v[:], lntypes.ZeroHash[:]) { + return nil + } + + vArray := (*[32]byte)(v) + + return tlv.EBytes32(w, vArray, buf) +} + +// sparseHashDecoder is the customized decoder which skips decoding the empty +// hash. +func sparseHashDecoder(r io.Reader, val interface{}, buf *[8]byte, + l uint64) error { + + v, ok := val.(*SparsePayHash) + if !ok { + return tlv.NewTypeForEncodingErr(val, "SparsePayHash") + } + + // If the length is zero, we will skip encoding the empty hash. + if l == 0 { + return nil + } + + vArray := (*[32]byte)(v) + + if err := tlv.DBytes32(r, vArray, buf, 32); err != nil { + return err + } + + vHash := SparsePayHash(*vArray) + + v = &vHash + + return nil +} + // HTLCEntry specifies the minimal info needed to be stored on disk for ALL the // historical HTLCs, which is useful for constructing RevocationLog when a // breach is detected. @@ -72,118 +148,62 @@ var ( // made into tlv records without further conversion. type HTLCEntry struct { // RHash is the payment hash of the HTLC. - RHash [32]byte + RHash tlv.RecordT[tlv.TlvType0, SparsePayHash] // RefundTimeout is the absolute timeout on the HTLC that the sender // must wait before reclaiming the funds in limbo. - RefundTimeout uint32 + RefundTimeout tlv.RecordT[tlv.TlvType1, uint32] // OutputIndex is the output index for this particular HTLC output // within the commitment transaction. // // NOTE: we use uint16 instead of int32 here to save us 2 bytes, which // gives us a max number of HTLCs of 65K. - OutputIndex uint16 + OutputIndex tlv.RecordT[tlv.TlvType2, uint16] // Incoming denotes whether we're the receiver or the sender of this // HTLC. // // NOTE: this field is the memory representation of the field // incomingUint. - Incoming bool + Incoming tlv.RecordT[tlv.TlvType3, bool] // Amt is the amount of satoshis this HTLC escrows. // // NOTE: this field is the memory representation of the field amtUint. - Amt btcutil.Amount - - // amtTlv is the uint64 format of Amt. This field is created so we can - // easily make it into a tlv record and save it to disk. - // - // NOTE: we keep this field for accounting purpose only. If the disk - // space becomes an issue, we could delete this field to save us extra - // 8 bytes. - amtTlv uint64 - - // incomingTlv is the uint8 format of Incoming. This field is created - // so we can easily make it into a tlv record and save it to disk. - incomingTlv uint8 -} - -// RHashLen is used by MakeDynamicRecord to return the size of the RHash. -// -// NOTE: for zero hash, we return a length 0. -func (h *HTLCEntry) RHashLen() uint64 { - if h.RHash == lntypes.ZeroHash { - return 0 - } - return 32 -} - -// RHashEncoder is the customized encoder which skips encoding the empty hash. -func RHashEncoder(w io.Writer, val interface{}, buf *[8]byte) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the value is an empty hash, we will skip encoding it. - if *v == lntypes.ZeroHash { - return nil - } - - return tlv.EBytes32(w, v, buf) -} - -// RHashDecoder is the customized decoder which skips decoding the empty hash. -func RHashDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { - v, ok := val.(*[32]byte) - if !ok { - return tlv.NewTypeForEncodingErr(val, "RHash") - } - - // If the length is zero, we will skip encoding the empty hash. - if l == 0 { - return nil - } - - return tlv.DBytes32(r, v, buf, 32) + Amt tlv.RecordT[tlv.TlvType4, tlv.BigSizeT[btcutil.Amount]] } // toTlvStream converts an HTLCEntry record into a tlv representation. func (h *HTLCEntry) toTlvStream() (*tlv.Stream, error) { - const ( - // A set of tlv type definitions used to serialize htlc entries - // to the database. We define it here instead of the head of - // the file to avoid naming conflicts. - // - // NOTE: A migration should be added whenever this list - // changes. - rHashType tlv.Type = 0 - refundTimeoutType tlv.Type = 1 - outputIndexType tlv.Type = 2 - incomingType tlv.Type = 3 - amtType tlv.Type = 4 - ) - return tlv.NewStream( - tlv.MakeDynamicRecord( - rHashType, &h.RHash, h.RHashLen, - RHashEncoder, RHashDecoder, - ), - tlv.MakePrimitiveRecord( - refundTimeoutType, &h.RefundTimeout, - ), - tlv.MakePrimitiveRecord( - outputIndexType, &h.OutputIndex, - ), - tlv.MakePrimitiveRecord(incomingType, &h.incomingTlv), - // We will save 3 bytes if the amount is less or equal to - // 4,294,967,295 msat, or roughly 0.043 bitcoin. - tlv.MakeBigSizeRecord(amtType, &h.amtTlv), + h.RHash.Record(), + h.RefundTimeout.Record(), + h.OutputIndex.Record(), + h.Incoming.Record(), + h.Amt.Record(), ) } +// NewHTLCEntryFromHTLC creates a new HTLCEntry from an HTLC. +func NewHTLCEntryFromHTLC(htlc HTLC) *HTLCEntry { + return &HTLCEntry{ + RHash: tlv.NewRecordT[tlv.TlvType0]( + NewSparsePayHash(htlc.RHash), + ), + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1]( + htlc.RefundTimeout, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlc.OutputIndex), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](htlc.Incoming), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(htlc.Amt.ToSatoshis()), + ), + } +} + // RevocationLog stores the info needed to construct a breach retribution. Its // fields can be viewed as a subset of a ChannelCommitment's. In the database, // all historical versions of the RevocationLog are saved using the @@ -265,13 +285,7 @@ func putRevocationLog(bucket kvdb.RwBucket, commit *ChannelCommitment, return ErrOutputIndexTooBig } - entry := &HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - Incoming: htlc.Incoming, - OutputIndex: uint16(htlc.OutputIndex), - Amt: htlc.Amt.ToSatoshis(), - } + entry := NewHTLCEntryFromHTLC(htlc) rl.HTLCEntries = append(rl.HTLCEntries, entry) } @@ -351,14 +365,6 @@ func serializeRevocationLog(w io.Writer, rl *RevocationLog) error { // format. func serializeHTLCEntries(w io.Writer, htlcs []*HTLCEntry) error { for _, htlc := range htlcs { - // Patch the incomingTlv field. - if htlc.Incoming { - htlc.incomingTlv = 1 - } - - // Patch the amtTlv field. - htlc.amtTlv = uint64(htlc.Amt) - // Create the tlv stream. tlvStream, err := htlc.toTlvStream() if err != nil { @@ -447,14 +453,6 @@ func deserializeHTLCEntries(r io.Reader) ([]*HTLCEntry, error) { return nil, err } - // Patch the Incoming field. - if htlc.incomingTlv == 1 { - htlc.Incoming = true - } - - // Patch the Amt field. - htlc.Amt = btcutil.Amount(htlc.amtTlv) - // Append the entry. htlcs = append(htlcs, &htlc) } @@ -469,6 +467,7 @@ func writeTlvStream(w io.Writer, s *tlv.Stream) error { if err := s.Encode(&b); err != nil { return err } + // Write the stream's length as a varint. err := tlv.WriteVarInt(w, uint64(b.Len()), &[8]byte{}) if err != nil { diff --git a/channeldb/revocation_log_test.go b/channeldb/revocation_log_test.go index fc5303a48..b1a96e7c3 100644 --- a/channeldb/revocation_log_test.go +++ b/channeldb/revocation_log_test.go @@ -34,12 +34,16 @@ var ( } testHTLCEntry = HTLCEntry{ - RefundTimeout: 740_000, - OutputIndex: 10, - Incoming: true, - Amt: 1000_000, - amtTlv: 1000_000, - incomingTlv: 1, + RefundTimeout: tlv.NewPrimitiveRecord[tlv.TlvType1, uint32]( + 740_000, + ), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16]( + 10, + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(1_000_000)), + ), } testHTLCEntryBytes = []byte{ // Body length 23. @@ -68,11 +72,11 @@ var ( CommitTx: channels.TestFundingTx, CommitSig: bytes.Repeat([]byte{1}, 71), Htlcs: []HTLC{{ - RefundTimeout: testHTLCEntry.RefundTimeout, - OutputIndex: int32(testHTLCEntry.OutputIndex), - Incoming: testHTLCEntry.Incoming, + RefundTimeout: testHTLCEntry.RefundTimeout.Val, + OutputIndex: int32(testHTLCEntry.OutputIndex.Val), + Incoming: testHTLCEntry.Incoming.Val, Amt: lnwire.NewMSatFromSatoshis( - testHTLCEntry.Amt, + testHTLCEntry.Amt.Val.Int(), ), }}, } @@ -193,11 +197,6 @@ func TestSerializeHTLCEntriesEmptyRHash(t *testing.T) { // Copy the testHTLCEntry. entry := testHTLCEntry - // Set the internal fields to empty values so we can test the bytes are - // padded. - entry.incomingTlv = 0 - entry.amtTlv = 0 - // Write the tlv stream. buf := bytes.NewBuffer([]byte{}) err := serializeHTLCEntries(buf, []*HTLCEntry{&entry}) @@ -215,7 +214,7 @@ func TestSerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // @@ -330,7 +329,7 @@ func TestDerializeHTLCEntries(t *testing.T) { // Create a fake rHash. rHashBytes := bytes.Repeat([]byte{10}, 32) - copy(entry.RHash[:], rHashBytes) + copy(entry.RHash.Val[:], rHashBytes) // Construct the serialized bytes. // diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 804c8fd5f..c2953b5bc 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2673,8 +2673,8 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, // then from the PoV of the remote commitment state, they're the // receiver of this HTLC. scriptInfo, err := genHtlcScript( - chanState.ChanType, htlc.Incoming, false, - htlc.RefundTimeout, htlc.RHash, keyRing, + chanState.ChanType, htlc.Incoming.Val, false, + htlc.RefundTimeout.Val, htlc.RHash.Val, keyRing, ) if err != nil { return emptyRetribution, err @@ -2687,7 +2687,7 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, WitnessScript: scriptInfo.WitnessScriptToSign(), Output: &wire.TxOut{ PkScript: scriptInfo.PkScript(), - Value: int64(htlc.Amt), + Value: int64(htlc.Amt.Val.Int()), }, HashType: sweepSigHash(chanState.ChanType), } @@ -2720,10 +2720,10 @@ func createHtlcRetribution(chanState *channeldb.OpenChannel, SignDesc: signDesc, OutPoint: wire.OutPoint{ Hash: commitHash, - Index: uint32(htlc.OutputIndex), + Index: uint32(htlc.OutputIndex.Val), }, SecondLevelWitnessScript: secondLevelWitnessScript, - IsIncoming: htlc.Incoming, + IsIncoming: htlc.Incoming.Val, SecondLevelTapTweak: secondLevelTapTweak, }, nil } @@ -2885,13 +2885,7 @@ func createBreachRetributionLegacy(revokedLog *channeldb.ChannelCommitment, continue } - entry := &channeldb.HTLCEntry{ - RHash: htlc.RHash, - RefundTimeout: htlc.RefundTimeout, - OutputIndex: uint16(htlc.OutputIndex), - Incoming: htlc.Incoming, - Amt: htlc.Amt.ToSatoshis(), - } + entry := channeldb.NewHTLCEntryFromHTLC(htlc) hr, err := createHtlcRetribution( chanState, keyRing, commitHash, commitmentSecret, leaseExpiry, entry, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index d08eeea29..84c694c61 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -9950,9 +9950,11 @@ func TestCreateHtlcRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: testAmt, - Incoming: true, - OutputIndex: 1, + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(testAmt), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2, uint16](1), } // Create the htlc retribution. @@ -9966,8 +9968,8 @@ func TestCreateHtlcRetribution(t *testing.T) { // Check the fields have expected values. require.EqualValues(t, testAmt, hr.SignDesc.Output.Value) require.Equal(t, commitHash, hr.OutPoint.Hash) - require.EqualValues(t, htlc.OutputIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.EqualValues(t, htlc.OutputIndex.Val, hr.OutPoint.Index) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } // TestCreateBreachRetribution checks that `createBreachRetribution` behaves as @@ -10007,9 +10009,13 @@ func TestCreateBreachRetribution(t *testing.T) { aliceChannel.channelState, ) htlc := &channeldb.HTLCEntry{ - Amt: btcutil.Amount(testAmt), - Incoming: true, - OutputIndex: uint16(htlcIndex), + Amt: tlv.NewRecordT[tlv.TlvType4]( + tlv.NewBigSizeT(btcutil.Amount(testAmt)), + ), + Incoming: tlv.NewPrimitiveRecord[tlv.TlvType3](true), + OutputIndex: tlv.NewPrimitiveRecord[tlv.TlvType2]( + uint16(htlcIndex), + ), } // Create a dummy revocation log. @@ -10136,11 +10142,12 @@ func TestCreateBreachRetribution(t *testing.T) { require.Equal(t, remote, br.RemoteOutpoint) for _, hr := range br.HtlcRetributions { - require.EqualValues(t, testAmt, - hr.SignDesc.Output.Value) + require.EqualValues( + t, testAmt, hr.SignDesc.Output.Value, + ) require.Equal(t, commitHash, hr.OutPoint.Hash) require.EqualValues(t, htlcIndex, hr.OutPoint.Index) - require.Equal(t, htlc.Incoming, hr.IsIncoming) + require.Equal(t, htlc.Incoming.Val, hr.IsIncoming) } }