mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-30 15:40:59 +02:00
lnwire: verify failure message length
Adds extra checks to make sure the failure message is well-formed.
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/go-errors/errors"
|
"github.com/go-errors/errors"
|
||||||
@@ -1222,18 +1223,41 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
|
|||||||
// is a 2 byte length followed by the payload itself.
|
// is a 2 byte length followed by the payload itself.
|
||||||
var failureLength uint16
|
var failureLength uint16
|
||||||
if err := ReadElement(r, &failureLength); err != nil {
|
if err := ReadElement(r, &failureLength); err != nil {
|
||||||
return nil, fmt.Errorf("unable to read error len: %v", err)
|
return nil, fmt.Errorf("unable to read failure len: %w", err)
|
||||||
}
|
|
||||||
if failureLength > FailureMessageLength {
|
|
||||||
return nil, fmt.Errorf("failure message is too "+
|
|
||||||
"long: %v", failureLength)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
failureData := make([]byte, failureLength)
|
failureData := make([]byte, failureLength)
|
||||||
if _, err := io.ReadFull(r, failureData); err != nil {
|
if _, err := io.ReadFull(r, failureData); err != nil {
|
||||||
return nil, fmt.Errorf("unable to full read payload of "+
|
return nil, fmt.Errorf("unable to full read payload of "+
|
||||||
"%v: %v", failureLength, err)
|
"%v: %w", failureLength, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read the padding.
|
||||||
|
var padLength uint16
|
||||||
|
if err := ReadElement(r, &padLength); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read pad len: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
|
||||||
|
return nil, fmt.Errorf("unable to read padding %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify that we are at the end of the stream now.
|
||||||
|
scratch := make([]byte, 1)
|
||||||
|
_, err := r.Read(scratch)
|
||||||
|
if err != io.EOF {
|
||||||
|
return nil, fmt.Errorf("unexpected failure bytes")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check the total length. Convert to 32 bits to prevent overflow.
|
||||||
|
totalLength := uint32(padLength) + uint32(failureLength)
|
||||||
|
if totalLength != FailureMessageLength {
|
||||||
|
return nil, fmt.Errorf("failure message length is "+
|
||||||
|
"incorrect: msg=%v, pad=%v, total=%v",
|
||||||
|
failureLength, padLength, totalLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the failure message.
|
||||||
dataReader := bytes.NewReader(failureData)
|
dataReader := bytes.NewReader(failureData)
|
||||||
|
|
||||||
return DecodeFailureMessage(dataReader, pver)
|
return DecodeFailureMessage(dataReader, pver)
|
||||||
|
Reference in New Issue
Block a user