From 9730bc1ca01a4eb06cb7427874fedb9e25b3e32f Mon Sep 17 00:00:00 2001 From: Joost Jager Date: Fri, 28 Oct 2022 08:45:27 +0200 Subject: [PATCH] lnwire: verify failure message length Adds extra checks to make sure the failure message is well-formed. --- lnwire/onion_error.go | 36 ++++++++++++++++++++++++++++++------ 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 593f58e3d..0f252c59d 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -7,6 +7,7 @@ import ( "encoding/binary" "fmt" "io" + "io/ioutil" "github.com/davecgh/go-spew/spew" "github.com/go-errors/errors" @@ -1222,18 +1223,41 @@ func DecodeFailure(r io.Reader, pver uint32) (FailureMessage, error) { // is a 2 byte length followed by the payload itself. var failureLength uint16 if err := ReadElement(r, &failureLength); err != nil { - return nil, fmt.Errorf("unable to read error len: %v", err) - } - if failureLength > FailureMessageLength { - return nil, fmt.Errorf("failure message is too "+ - "long: %v", failureLength) + return nil, fmt.Errorf("unable to read failure len: %w", err) } + failureData := make([]byte, failureLength) if _, err := io.ReadFull(r, failureData); err != nil { return nil, fmt.Errorf("unable to full read payload of "+ - "%v: %v", failureLength, err) + "%v: %w", failureLength, err) } + // Read the padding. + var padLength uint16 + if err := ReadElement(r, &padLength); err != nil { + return nil, fmt.Errorf("unable to read pad len: %w", err) + } + + if _, err := io.CopyN(ioutil.Discard, r, int64(padLength)); err != nil { + return nil, fmt.Errorf("unable to read padding %w", err) + } + + // Verify that we are at the end of the stream now. + scratch := make([]byte, 1) + _, err := r.Read(scratch) + if err != io.EOF { + return nil, fmt.Errorf("unexpected failure bytes") + } + + // Check the total length. Convert to 32 bits to prevent overflow. + totalLength := uint32(padLength) + uint32(failureLength) + if totalLength != FailureMessageLength { + return nil, fmt.Errorf("failure message length is "+ + "incorrect: msg=%v, pad=%v, total=%v", + failureLength, padLength, totalLength) + } + + // Decode the failure message. dataReader := bytes.NewReader(failureData) return DecodeFailureMessage(dataReader, pver)