lnwire: verify failure message length

Adds extra checks to make sure the failure message
is well-formed.
This commit is contained in:
Joost Jager
2022-10-28 08:45:27 +02:00
parent e9440a24a2
commit 9730bc1ca0

View File

@@ -7,6 +7,7 @@ import (
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"github.com/davecgh/go-spew/spew"
"github.com/go-errors/errors"
@@ -1222,18 +1223,41 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
// is a 2 byte length followed by the payload itself.
var failureLength uint16
if err := ReadElement(r, &failureLength); err != nil {
return nil, fmt.Errorf("unable to read error len: %v", err)
}
if failureLength > FailureMessageLength {
return nil, fmt.Errorf("failure message is too "+
"long: %v", failureLength)
return nil, fmt.Errorf("unable to read failure len: %w", err)
}
failureData := make([]byte, failureLength)
if _, err := io.ReadFull(r, failureData); err != nil {
return nil, fmt.Errorf("unable to full read payload of "+
"%v: %v", failureLength, err)
"%v: %w", failureLength, err)
}
// Read the padding.
var padLength uint16
if err := ReadElement(r, &padLength); err != nil {
return nil, fmt.Errorf("unable to read pad len: %w", err)
}
if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
return nil, fmt.Errorf("unable to read padding %w", err)
}
// Verify that we are at the end of the stream now.
scratch := make([]byte, 1)
_, err := r.Read(scratch)
if err != io.EOF {
return nil, fmt.Errorf("unexpected failure bytes")
}
// Check the total length. Convert to 32 bits to prevent overflow.
totalLength := uint32(padLength) + uint32(failureLength)
if totalLength != FailureMessageLength {
return nil, fmt.Errorf("failure message length is "+
"incorrect: msg=%v, pad=%v, total=%v",
failureLength, padLength, totalLength)
}
// Decode the failure message.
dataReader := bytes.NewReader(failureData)
return DecodeFailureMessage(dataReader, pver)