diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index f6719ea7d..ffefb5e4e 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -366,7 +366,20 @@ func (f FailUnknownPaymentHash) Error() string { // // NOTE: Part of the Serializable interface. func (f *FailUnknownPaymentHash) Decode(r io.Reader, pver uint32) error { - return ReadElement(r, &f.amount) + err := ReadElement(r, &f.amount) + switch { + // This is an optional tack on that was added later in the protocol. As + // a result, older nodes may not include this value. We'll account for + // this by checking for io.EOF here which means that no bytes were read + // at all. + case err == io.EOF: + return nil + + case err != nil: + return err + } + + return nil } // Encode writes the failure in bytes stream. diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 3cc7d49cc..e05bd2a18 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -33,10 +33,10 @@ var onionFailures = []FailureMessage{ &FailPermanentChannelFailure{}, &FailRequiredChannelFeatureMissing{}, &FailUnknownNextPeer{}, - &FailUnknownPaymentHash{}, &FailIncorrectPaymentAmount{}, &FailFinalExpiryTooSoon{}, + NewFailUnknownPaymentHash(99), NewInvalidOnionVersion(testOnionHash), NewInvalidOnionHmac(testOnionHash), NewInvalidOnionKey(testOnionHash), @@ -167,3 +167,30 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { trueUpdateLength, encodedLen) } } + +// TestFailUnknownPaymentHashOptionalAmount tests that we're able to decode an +// UnknownPaymentHash error that doesn't have the optional amount. This ensures +// we're able to decode FailUnknownPaymentHash messages from older nodes. +func TestFailUnknownPaymentHashOptionalAmount(t *testing.T) { + t.Parallel() + + // Creation an error that is a non-pointer will allow us to skip the + // type assertion for the Serializable interface. As a result, the + // amount body won't be written. + onionError := FailUnknownPaymentHash{} + + var b bytes.Buffer + if err := EncodeFailure(&b, onionError, 0); err != nil { + t.Fatalf("unable to encode failure: %v", err) + } + + onionError2, err := DecodeFailure(bytes.NewReader(b.Bytes()), 0) + if err != nil { + t.Fatalf("unable to decode error: %v", err) + } + + if !reflect.DeepEqual(onionError, onionError) { + t.Fatalf("expected %v, got %v", spew.Sdump(onionError), + spew.Sdump(onionError2)) + } +}