lnwire: add timestamps to ReplyChannelRange msg

This commit is contained in:
Elle Mouton
2023-09-19 20:53:58 +02:00
parent 4872010779
commit 49a0370dcd
3 changed files with 368 additions and 20 deletions

View File

@@ -2,11 +2,13 @@ package lnwire
import (
"bytes"
"fmt"
"io"
"math"
"sort"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/lightningnetwork/lnd/tlv"
)
// ReplyChannelRange is the response to the QueryChannelRange message. It
@@ -39,6 +41,12 @@ type ReplyChannelRange struct {
// ShortChanIDs is a slice of decoded short channel ID's.
ShortChanIDs []ShortChannelID
// Timestamps is an optional set of timestamps corresponding to the
// latest timestamps for the channel update messages corresponding to
// those referenced in the ShortChanIDs list. If this field is used,
// then the length must match the length of ShortChanIDs.
Timestamps Timestamps
// ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can
// be used to specify optional data such as custom TLV fields.
@@ -53,7 +61,9 @@ type ReplyChannelRange struct {
// NewReplyChannelRange creates a new empty ReplyChannelRange message.
func NewReplyChannelRange() *ReplyChannelRange {
return &ReplyChannelRange{}
return &ReplyChannelRange{
ExtraData: make([]byte, 0),
}
}
// A compile time check to ensure ReplyChannelRange implements the
@@ -80,7 +90,27 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
return err
}
return c.ExtraData.Decode(r)
var tlvRecords ExtraOpaqueData
if err := ReadElements(r, &tlvRecords); err != nil {
return err
}
var timeStamps Timestamps
typeMap, err := tlvRecords.ExtractRecords(&timeStamps)
if err != nil {
return err
}
// Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[TimestampsRecordType]; ok && val == nil {
c.Timestamps = timeStamps
}
if len(tlvRecords) != 0 {
c.ExtraData = tlvRecords
}
return nil
}
// Encode serializes the target ReplyChannelRange into the passed io.Writer
@@ -108,10 +138,48 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
// sorted in place, so we'll do that now. The sorting is applied unless
// we were specifically requested not to for testing purposes.
if !c.noSort {
var scidPreSortIndex map[uint64]int
if len(c.Timestamps) != 0 {
// Sanity check that a timestamp was provided for each
// SCID.
if len(c.Timestamps) != len(c.ShortChanIDs) {
return fmt.Errorf("must provide a timestamp " +
"pair for each of the given SCIDs")
}
// Create a map from SCID value to the original index of
// the SCID in the unsorted list.
scidPreSortIndex = make(
map[uint64]int, len(c.ShortChanIDs),
)
for i, scid := range c.ShortChanIDs {
scidPreSortIndex[scid.ToUint64()] = i
}
// Sanity check that there were no duplicates in the
// SCID list.
if len(scidPreSortIndex) != len(c.ShortChanIDs) {
return fmt.Errorf("scid list should not " +
"contain duplicates")
}
}
// Now sort the SCIDs.
sort.Slice(c.ShortChanIDs, func(i, j int) bool {
return c.ShortChanIDs[i].ToUint64() <
c.ShortChanIDs[j].ToUint64()
})
if len(c.Timestamps) != 0 {
timestamps := make(Timestamps, len(c.Timestamps))
for i, scid := range c.ShortChanIDs {
timestamps[i] = []ChanUpdateTimestamps(
c.Timestamps,
)[scidPreSortIndex[scid.ToUint64()]]
}
c.Timestamps = timestamps
}
}
err := encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs)
@@ -119,6 +187,15 @@ func (c *ReplyChannelRange) Encode(w *bytes.Buffer, pver uint32) error {
return err
}
recordProducers := make([]tlv.RecordProducer, 0, 1)
if len(c.Timestamps) != 0 {
recordProducers = append(recordProducers, &c.Timestamps)
}
err = EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil {
return err
}
return WriteBytes(w, c.ExtraData)
}