diff --git a/lnwire/reply_channel_range.go b/lnwire/reply_channel_range.go index 18432392b..430606025 100644 --- a/lnwire/reply_channel_range.go +++ b/lnwire/reply_channel_range.go @@ -21,6 +21,12 @@ type ReplyChannelRange struct { // ShortChanIDs is a slice of decoded short channel ID's. ShortChanIDs []ShortChannelID + + // noSort indicates whether or not to sort the short channel ids before + // writing them out. + // + // NOTE: This should only be used for testing. + noSort bool } // NewReplyChannelRange creates a new empty ReplyChannelRange message. @@ -64,7 +70,7 @@ func (c *ReplyChannelRange) Encode(w io.Writer, pver uint32) error { return err } - return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, false) + return encodeShortChanIDs(w, c.EncodingType, c.ShortChanIDs, c.noSort) } // MsgType returns the integer uniquely identifying this message type on the diff --git a/lnwire/reply_channel_range_test.go b/lnwire/reply_channel_range_test.go new file mode 100644 index 000000000..c9c1cfc45 --- /dev/null +++ b/lnwire/reply_channel_range_test.go @@ -0,0 +1,34 @@ +package lnwire + +import ( + "bytes" + "testing" +) + +// TestReplyChannelRangeUnsorted tests that decoding a ReplyChannelRange request +// that contains duplicate or unsorted ids returns an ErrUnsortedSIDs failure. +func TestReplyChannelRangeUnsorted(t *testing.T) { + for _, test := range unsortedSidTests { + test := test + t.Run(test.name, func(t *testing.T) { + req := &ReplyChannelRange{ + EncodingType: test.encType, + ShortChanIDs: test.sids, + noSort: true, + } + + var b bytes.Buffer + err := req.Encode(&b, 0) + if err != nil { + t.Fatalf("unable to encode req: %v", err) + } + + var req2 ReplyChannelRange + err = req2.Decode(bytes.NewReader(b.Bytes()), 0) + if _, ok := err.(ErrUnsortedSIDs); !ok { + t.Fatalf("expected ErrUnsortedSIDs, got: %T", + err) + } + }) + } +}