discovery+lnwire: remove embedding within ReplyChannelRange

In order to prep for allowing TLV extensions for the `ReplyChannelRange`
and `QueryChannelRange` messages, we'll need to remove the struct
embedding as is. If we don't remove this, then we'll attempt to decode
TLV extensions from both the embedded and outer struct.

All relevant call sites have been updated to reflect this minor change.
This commit is contained in:
Olaoluwa Osuntokun
2020-01-27 17:30:54 -08:00
committed by Johan T. Halseth
parent 466c079bbe
commit dd6f0ba931
6 changed files with 108 additions and 74 deletions

View File

@@ -1,14 +1,29 @@
package lnwire
import "io"
import (
"io"
"math"
"github.com/btcsuite/btcd/chaincfg/chainhash"
)
// ReplyChannelRange is the response to the QueryChannelRange message. It
// includes the original query, and the next streaming chunk of encoded short
// channel ID's as the response. We'll also include a byte that indicates if
// this is the last query in the message.
type ReplyChannelRange struct {
// QueryChannelRange is the corresponding query to this response.
QueryChannelRange
// ChainHash denotes the target chain that we're trying to synchronize
// channel graph state for.
ChainHash chainhash.Hash
// FirstBlockHeight is the first block in the query range. The
// responder should send all new short channel IDs from this block
// until this block plus the specified number of blocks.
FirstBlockHeight uint32
// NumBlocks is the number of blocks beyond the first block that short
// channel ID's should be sent for.
NumBlocks uint32
// Complete denotes if this is the conclusion of the set of streaming
// responses to the original query.
@@ -43,17 +58,21 @@ var _ Message = (*ReplyChannelRange)(nil)
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
err := c.QueryChannelRange.Decode(r, pver)
err := ReadElements(r,
c.ChainHash[:],
&c.FirstBlockHeight,
&c.NumBlocks,
&c.Complete,
)
if err != nil {
return err
}
if err := ReadElements(r, &c.Complete); err != nil {
c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r)
if err != nil {
return err
}
c.EncodingType, c.ShortChanIDs, err = decodeShortChanIDs(r)
return err
}
@@ -62,15 +81,22 @@ func (c *ReplyChannelRange) Decode(r io.Reader, pver uint32) error {
//
// This is part of the lnwire.Message interface.
func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error {
if err := c.QueryChannelRange.Encode(w, pver); err != nil {
err := WriteElements(w,
c.ChainHash[:],
c.FirstBlockHeight,
c.NumBlocks,
c.Complete,
)
if err != nil {
return err
}
if err := WriteElements(w, c.Complete); err != nil {
err = encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
if err != nil {
return err
}
return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort)
return nil
}
// MsgType returns the integer uniquely identifying this message type on the
@@ -88,3 +114,14 @@ func (c *ReplyChannelRange) MsgType() MessageType {
func (c *ReplyChannelRange) MaxPayloadLength(uint32) uint32 {
return MaxMessagePayload
}
// LastBlockHeight returns the last block height covered by the range of a
// QueryChannelRange message.
func (c *ReplyChannelRange) LastBlockHeight() uint32 {
// Handle overflows by casting to uint64.
lastBlockHeight := uint64(c.FirstBlockHeight) + uint64(c.NumBlocks) - 1
if lastBlockHeight > math.MaxUint32 {
return math.MaxUint32
}
return uint32(lastBlockHeight)
}