mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-28 22:50:58 +02:00
lnwire: verify failure message length
Adds extra checks to make sure the failure message is well-formed.
This commit is contained in:
@@ -7,6 +7,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/go-errors/errors"
|
||||
@@ -1222,18 +1223,41 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) {
|
||||
// is a 2 byte length followed by the payload itself.
|
||||
var failureLength uint16
|
||||
if err := ReadElement(r, &failureLength); err != nil {
|
||||
return nil, fmt.Errorf("unable to read error len: %v", err)
|
||||
}
|
||||
if failureLength > FailureMessageLength {
|
||||
return nil, fmt.Errorf("failure message is too "+
|
||||
"long: %v", failureLength)
|
||||
return nil, fmt.Errorf("unable to read failure len: %w", err)
|
||||
}
|
||||
|
||||
failureData := make([]byte, failureLength)
|
||||
if _, err := io.ReadFull(r, failureData); err != nil {
|
||||
return nil, fmt.Errorf("unable to full read payload of "+
|
||||
"%v: %v", failureLength, err)
|
||||
"%v: %w", failureLength, err)
|
||||
}
|
||||
|
||||
// Read the padding.
|
||||
var padLength uint16
|
||||
if err := ReadElement(r, &padLength); err != nil {
|
||||
return nil, fmt.Errorf("unable to read pad len: %w", err)
|
||||
}
|
||||
|
||||
if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil {
|
||||
return nil, fmt.Errorf("unable to read padding %w", err)
|
||||
}
|
||||
|
||||
// Verify that we are at the end of the stream now.
|
||||
scratch := make([]byte, 1)
|
||||
_, err := r.Read(scratch)
|
||||
if err != io.EOF {
|
||||
return nil, fmt.Errorf("unexpected failure bytes")
|
||||
}
|
||||
|
||||
// Check the total length. Convert to 32 bits to prevent overflow.
|
||||
totalLength := uint32(padLength) + uint32(failureLength)
|
||||
if totalLength != FailureMessageLength {
|
||||
return nil, fmt.Errorf("failure message length is "+
|
||||
"incorrect: msg=%v, pad=%v, total=%v",
|
||||
failureLength, padLength, totalLength)
|
||||
}
|
||||
|
||||
// Decode the failure message.
|
||||
dataReader := bytes.NewReader(failureData)
|
||||
|
||||
return DecodeFailureMessage(dataReader, pver)
|
||||
|
Reference in New Issue
Block a user