From 40981dfdab262b1b55fc8c4bebba4483f305410d Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Tue, 7 Nov 2023 11:53:27 +0200 Subject: [PATCH] multi: use ChannelUpdate interface for failure messages --- htlcswitch/interceptable_switch.go | 2 +- htlcswitch/link.go | 8 +- htlcswitch/switch.go | 4 +- htlcswitch/switch_test.go | 13 ++- lnrpc/routerrpc/router_backend.go | 103 ++++++++++++++---- lnwire/onion_error.go | 145 +++++++++++++++++--------- lnwire/onion_error_test.go | 18 ++-- routing/missioncontrol_test.go | 6 +- routing/result_interpretation_test.go | 7 +- routing/router.go | 14 +-- routing/router_test.go | 20 ++-- 11 files changed, 221 insertions(+), 119 deletions(-) diff --git a/htlcswitch/interceptable_switch.go b/htlcswitch/interceptable_switch.go index 5517eb82d..07e22bec5 100644 --- a/htlcswitch/interceptable_switch.go +++ b/htlcswitch/interceptable_switch.go @@ -702,7 +702,7 @@ func (f *interceptedForward) FailWithCode(code lnwire.FailCode) error { return err } - failureMsg = lnwire.NewExpiryTooSoon(*update) + failureMsg = lnwire.NewExpiryTooSoon(update) default: return ErrUnsupportedFailureCode diff --git a/htlcswitch/link.go b/htlcswitch/link.go index 5934048eb..f5516d4eb 100644 --- a/htlcswitch/link.go +++ b/htlcswitch/link.go @@ -3055,7 +3055,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewFeeInsufficient(amtToForward, *upd) + return lnwire.NewFeeInsufficient(amtToForward, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -3084,7 +3084,7 @@ func (l *channelLink) CheckHtlcForward(payHash [32]byte, // date with our current policy. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { return lnwire.NewIncorrectCltvExpiry( - incomingTimeout, *upd, + incomingTimeout, upd, ) } failure := l.createFailureWithUpdate(false, originalScid, cb) @@ -3132,7 +3132,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, // As part of the returned error, we'll send our latest routing // policy so the sending node obtains the most up to date data. cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewAmountBelowMinimum(amt, *upd) + return lnwire.NewAmountBelowMinimum(amt, upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) @@ -3162,7 +3162,7 @@ func (l *channelLink) canSendHtlc(policy models.ForwardingPolicy, timeout, heightNow) cb := func(upd *lnwire.ChannelUpdate1) lnwire.FailureMessage { - return lnwire.NewExpiryTooSoon(*upd) + return lnwire.NewExpiryTooSoon(upd) } failure := l.createFailureWithUpdate(false, originalScid, cb) return NewLinkError(failure) diff --git a/htlcswitch/switch.go b/htlcswitch/switch.go index 8d9eaf0b5..02bda7999 100644 --- a/htlcswitch/switch.go +++ b/htlcswitch/switch.go @@ -2777,7 +2777,9 @@ func (s *Switch) handlePacketAdd(packet *htlcPacket, // sure that HTLC is not from the source node. if s.cfg.RejectHTLC { failure := NewDetailedLinkError( - &lnwire.FailChannelDisabled{}, + &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, OutgoingFailureForwardsDisabled, ) diff --git a/htlcswitch/switch_test.go b/htlcswitch/switch_test.go index 825ee6c65..d9e68f355 100644 --- a/htlcswitch/switch_test.go +++ b/htlcswitch/switch_test.go @@ -3382,7 +3382,10 @@ func TestHtlcNotifier(t *testing.T) { return getThreeHopEvents( channels, htlcID, ts, htlc, hops, &LinkError{ - msg: &lnwire.FailChannelDisabled{}, + //nolint:lll + msg: &lnwire.FailChannelDisabled{ + Update: &lnwire.ChannelUpdate1{}, + }, FailureDetail: OutgoingFailureForwardsDisabled, }, preimage, @@ -5045,7 +5048,7 @@ func testSwitchForwardFailAlias(t *testing.T, zeroConf bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, aliceAlias, failMsg.Update.ShortChannelID) + require.Equal(t, aliceAlias, failMsg.Update.SCID()) case <-s2.quit: t.Fatal("switch shutting down, failed to forward packet") } @@ -5228,7 +5231,7 @@ func testSwitchAliasFailAdd(t *testing.T, zeroConf, private, useAlias bool) { msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive fail packet") } @@ -5428,7 +5431,7 @@ func testSwitchHandlePacketForward(t *testing.T, zeroConf, private, msg := failPacket.linkFailure.msg failMsg, ok := msg.(*lnwire.FailAmountBelowMinimum) require.True(t, ok) - require.Equal(t, outgoingChanID, failMsg.Update.ShortChannelID) + require.Equal(t, outgoingChanID, failMsg.Update.SCID()) case <-s.quit: t.Fatal("switch shutting down, failed to receive failure") } @@ -5583,7 +5586,7 @@ func testSwitchAliasInterceptFail(t *testing.T, zeroConf bool) { failureMsg, ok := failure.(*lnwire.FailTemporaryChannelFailure) require.True(t, ok) - failScid := failureMsg.Update.ShortChannelID + failScid := failureMsg.Update.SCID() isAlias := failScid == aliceAlias || failScid == aliceAlias2 require.True(t, isAlias) diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index b37df7218..49c53d3c5 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -1523,8 +1523,14 @@ func marshallWireError(msg lnwire.FailureMessage, response.Code = lnrpc.Failure_INVALID_REALM case *lnwire.FailExpiryTooSoon: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_EXPIRY_TOO_SOON - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 case *lnwire.FailExpiryTooFar: response.Code = lnrpc.Failure_EXPIRY_TOO_FAR @@ -1542,28 +1548,58 @@ func marshallWireError(msg lnwire.FailureMessage, response.OnionSha_256 = onionErr.OnionSHA256[:] case *lnwire.FailAmountBelowMinimum: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_AMOUNT_BELOW_MINIMUM - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailFeeInsufficient: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_FEE_INSUFFICIENT - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.HtlcMsat = uint64(onionErr.HtlcMsat) case *lnwire.FailIncorrectCltvExpiry: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_INCORRECT_CLTV_EXPIRY - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.CltvExpiry = onionErr.CltvExpiry case *lnwire.FailChannelDisabled: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_CHANNEL_DISABLED - response.ChannelUpdate = marshallChannelUpdate(&onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 response.Flags = uint32(onionErr.Flags) case *lnwire.FailTemporaryChannelFailure: + update1, update2, err := marshallChannelUpdate(onionErr.Update) + if err != nil { + return err + } + response.Code = lnrpc.Failure_TEMPORARY_CHANNEL_FAILURE - response.ChannelUpdate = marshallChannelUpdate(onionErr.Update) + response.ChannelUpdate = update1 + response.ChannelUpdate_2 = update2 case *lnwire.FailRequiredNodeFeatureMissing: response.Code = lnrpc.Failure_REQUIRED_NODE_FEATURE_MISSING @@ -1605,24 +1641,49 @@ func marshallWireError(msg lnwire.FailureMessage, // marshallChannelUpdate marshalls a channel update as received over the wire to // the router rpc format. -func marshallChannelUpdate(update *lnwire.ChannelUpdate1) *lnrpc.ChannelUpdate { +func marshallChannelUpdate(update lnwire.ChannelUpdate) (*lnrpc.ChannelUpdate, + *lnrpc.ChannelUpdate2, error) { + if update == nil { - return nil + return nil, nil, nil } - return &lnrpc.ChannelUpdate{ - Signature: update.Signature.RawBytes(), - ChainHash: update.ChainHash[:], - ChanId: update.ShortChannelID.ToUint64(), - Timestamp: update.Timestamp, - MessageFlags: uint32(update.MessageFlags), - ChannelFlags: uint32(update.ChannelFlags), - TimeLockDelta: uint32(update.TimeLockDelta), - HtlcMinimumMsat: uint64(update.HtlcMinimumMsat), - BaseFee: update.BaseFee, - FeeRate: update.FeeRate, - HtlcMaximumMsat: uint64(update.HtlcMaximumMsat), - ExtraOpaqueData: update.ExtraOpaqueData, + switch upd := update.(type) { + case *lnwire.ChannelUpdate1: + return &lnrpc.ChannelUpdate{ + Signature: upd.Signature.RawBytes(), + ChainHash: upd.ChainHash[:], + ChanId: upd.ShortChannelID.ToUint64(), + Timestamp: upd.Timestamp, + MessageFlags: uint32(upd.MessageFlags), + ChannelFlags: uint32(upd.ChannelFlags), + TimeLockDelta: uint32(upd.TimeLockDelta), + HtlcMinimumMsat: uint64(upd.HtlcMinimumMsat), + BaseFee: upd.BaseFee, + FeeRate: upd.FeeRate, + HtlcMaximumMsat: uint64(upd.HtlcMaximumMsat), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil, nil + + case *lnwire.ChannelUpdate2: + return nil, &lnrpc.ChannelUpdate2{ + Signature: upd.Signature.RawBytes(), + ChainHash: upd.ChainHash.Val[:], + ChanId: upd.ShortChannelID.Val.ToUint64(), + BlockHeight: upd.BlockHeight.Val, + DisabledFlags: uint32(upd.DisabledFlags.Val), + Direction: upd.SecondPeer.IsSome(), + TimeLockDelta: uint32(upd.CLTVExpiryDelta.Val), + BaseFee: upd.FeeBaseMsat.Val, + FeeRate: upd.FeeProportionalMillionths.Val, + HtlcMinimumMsat: uint64(upd.HTLCMinimumMsat.Val), + HtlcMaximumMsat: uint64(upd.HTLCMaximumMsat.Val), + ExtraOpaqueData: upd.ExtraOpaqueData, + }, nil + + default: + return nil, nil, fmt.Errorf("unhandled implementation of "+ + "lnwire.ChannelUpdate: %T", update) } } diff --git a/lnwire/onion_error.go b/lnwire/onion_error.go index 9a669271d..1def2d9f9 100644 --- a/lnwire/onion_error.go +++ b/lnwire/onion_error.go @@ -600,7 +600,7 @@ func (f *FailInvalidOnionKey) Error() string { // unable to pull out a fully valid version, then we'll fall back to the // regular parsing mechanism which includes the length prefix an NO type byte. func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, - chanUpdate *ChannelUpdate1, pver uint32) error { + pver uint32) (ChannelUpdate, error) { // Instantiate a LimitReader because there may be additional data // present after the channel update. Without limiting the stream, the @@ -613,28 +613,50 @@ func parseChannelUpdateCompatibilityMode(reader io.Reader, length uint16, // buffer so we can decide how to parse the remainder of it. maybeTypeBytes, err := r.Peek(2) if err != nil { - return err + return nil, err } - // Some nodes well prefix an additional set of bytes in front of their - // channel updates. These bytes will _almost_ always be 258 or the type - // of the ChannelUpdate message. - typeInt := binary.BigEndian.Uint16(maybeTypeBytes) - if typeInt == MsgChannelUpdate { + var ( + typeInt = binary.BigEndian.Uint16(maybeTypeBytes) + chanUpdate ChannelUpdate + hasTypeBytes bool + ) + switch typeInt { + case MsgChannelUpdate: + chanUpdate = &ChannelUpdate1{} + hasTypeBytes = true + + case MsgChannelUpdate2: + chanUpdate = &ChannelUpdate2{} + hasTypeBytes = true + + default: + // Some older nodes will not have the type prefix in front of + // their channel updates as there was initially some ambiguity + // in the spec. This should ony be the case for the + // ChannelUpdate2 message. + chanUpdate = &ChannelUpdate1{} + } + + if hasTypeBytes { // At this point it's likely the case that this is a channel // update message with its type prefixed, so we'll snip off the // first two bytes and parse it as normal. var throwAwayTypeBytes [2]byte _, err := r.Read(throwAwayTypeBytes[:]) if err != nil { - return err + return nil, err } } // At this pint, we've either decided to keep the entire thing, or snip // off the first two bytes. In either case, we can just read it as // normal. - return chanUpdate.Decode(r, pver) + if err = chanUpdate.Decode(r, pver); err != nil { + return nil, err + } + + return chanUpdate, nil } // FailTemporaryChannelFailure is if an otherwise unspecified transient error @@ -647,12 +669,12 @@ type FailTemporaryChannelFailure struct { // which caused the failure. // // NOTE: This field is optional. - Update *ChannelUpdate1 + Update ChannelUpdate } // NewTemporaryChannelFailure creates new instance of the FailTemporaryChannelFailure. func NewTemporaryChannelFailure( - update *ChannelUpdate1) *FailTemporaryChannelFailure { + update ChannelUpdate) *FailTemporaryChannelFailure { return &FailTemporaryChannelFailure{Update: update} } @@ -687,11 +709,14 @@ func (f *FailTemporaryChannelFailure) Decode(r io.Reader, pver uint32) error { } if length != 0 { - f.Update = &ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update } return nil @@ -722,12 +747,12 @@ type FailAmountBelowMinimum struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewAmountBelowMinimum creates new instance of the FailAmountBelowMinimum. func NewAmountBelowMinimum(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailAmountBelowMinimum { + update ChannelUpdate) *FailAmountBelowMinimum { return &FailAmountBelowMinimum{ HtlcMsat: htlcMsat, @@ -763,11 +788,16 @@ func (f *FailAmountBelowMinimum) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} - - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, + update, err := parseChannelUpdateCompatibilityMode( + r, length, pver, ) + if err != nil { + return err + } + + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -778,7 +808,7 @@ func (f *FailAmountBelowMinimum) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFeeInsufficient is returned if the HTLC does not pay sufficient fee, we @@ -792,12 +822,13 @@ type FailFeeInsufficient struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewFeeInsufficient creates new instance of the FailFeeInsufficient. func NewFeeInsufficient(htlcMsat MilliSatoshi, - update ChannelUpdate1) *FailFeeInsufficient { + update ChannelUpdate) *FailFeeInsufficient { + return &FailFeeInsufficient{ HtlcMsat: htlcMsat, Update: update, @@ -832,11 +863,14 @@ func (f *FailFeeInsufficient) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -847,7 +881,7 @@ func (f *FailFeeInsufficient) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailIncorrectCltvExpiry is returned if outgoing cltv value does not match @@ -863,12 +897,12 @@ type FailIncorrectCltvExpiry struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewIncorrectCltvExpiry creates new instance of the FailIncorrectCltvExpiry. func NewIncorrectCltvExpiry(cltvExpiry uint32, - update ChannelUpdate1) *FailIncorrectCltvExpiry { + update ChannelUpdate) *FailIncorrectCltvExpiry { return &FailIncorrectCltvExpiry{ CltvExpiry: cltvExpiry, @@ -901,11 +935,14 @@ func (f *FailIncorrectCltvExpiry) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -916,7 +953,7 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailExpiryTooSoon is returned if the ctlv-expiry is too near, we tell them @@ -926,11 +963,11 @@ func (f *FailIncorrectCltvExpiry) Encode(w *bytes.Buffer, pver uint32) error { type FailExpiryTooSoon struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewExpiryTooSoon creates new instance of the FailExpiryTooSoon. -func NewExpiryTooSoon(update ChannelUpdate1) *FailExpiryTooSoon { +func NewExpiryTooSoon(update ChannelUpdate) *FailExpiryTooSoon { return &FailExpiryTooSoon{ Update: update, } @@ -959,18 +996,21 @@ func (f *FailExpiryTooSoon) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. // // NOTE: Part of the Serializable interface. func (f *FailExpiryTooSoon) Encode(w *bytes.Buffer, pver uint32) error { - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailChannelDisabled is returned if the channel is disabled, we tell them the @@ -985,12 +1025,12 @@ type FailChannelDisabled struct { // Update is used to update information about state of the channel // which caused the failure. - Update ChannelUpdate1 + Update ChannelUpdate } // NewChannelDisabled creates new instance of the FailChannelDisabled. func NewChannelDisabled(flags uint16, - update ChannelUpdate1) *FailChannelDisabled { + update ChannelUpdate) *FailChannelDisabled { return &FailChannelDisabled{ Flags: flags, @@ -1026,11 +1066,14 @@ func (f *FailChannelDisabled) Decode(r io.Reader, pver uint32) error { return err } - f.Update = ChannelUpdate1{} + update, err := parseChannelUpdateCompatibilityMode(r, length, pver) + if err != nil { + return err + } - return parseChannelUpdateCompatibilityMode( - r, length, &f.Update, pver, - ) + f.Update = update + + return nil } // Encode writes the failure in bytes stream. @@ -1041,7 +1084,7 @@ func (f *FailChannelDisabled) Encode(w *bytes.Buffer, pver uint32) error { return err } - return writeOnionErrorChanUpdate(w, &f.Update, pver) + return writeOnionErrorChanUpdate(w, f.Update, pver) } // FailFinalIncorrectCltvExpiry is returned if the outgoing_cltv_value does not @@ -1514,7 +1557,7 @@ func makeEmptyOnionError(code FailCode) (FailureMessage, error) { // writeOnionErrorChanUpdate writes out a ChannelUpdate using the onion error // format. The format is that we first write out the true serialized length of // the channel update, followed by the serialized channel update itself. -func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate *ChannelUpdate1, +func writeOnionErrorChanUpdate(w *bytes.Buffer, chanUpdate ChannelUpdate, pver uint32) error { // First, we encode the channel update in a temporary buffer in order diff --git a/lnwire/onion_error_test.go b/lnwire/onion_error_test.go index 37cd94b8a..6d5405869 100644 --- a/lnwire/onion_error_test.go +++ b/lnwire/onion_error_test.go @@ -20,7 +20,7 @@ var ( testType = uint64(3) testOffset = uint16(24) sig, _ = NewSigFromSignature(testSig) - testChannelUpdate = ChannelUpdate1{ + testChannelUpdate = &ChannelUpdate1{ Signature: sig, ShortChannelID: NewShortChanIDFromInt(1), Timestamp: 1, @@ -46,7 +46,7 @@ var onionFailures = []FailureMessage{ NewInvalidOnionVersion(testOnionHash), NewInvalidOnionHmac(testOnionHash), NewInvalidOnionKey(testOnionHash), - NewTemporaryChannelFailure(&testChannelUpdate), + NewTemporaryChannelFailure(testChannelUpdate), NewTemporaryChannelFailure(nil), NewAmountBelowMinimum(testAmount, testChannelUpdate), NewFeeInsufficient(testAmount, testChannelUpdate), @@ -137,9 +137,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // Now that we have the set of bytes encoded, we'll ensure that we're // able to decode it using our compatibility method, as it's a regular // encoded channel update message. - var newChanUpdate ChannelUpdate1 - err := parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate, 0, + newChanUpdate, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -164,9 +163,8 @@ func TestChannelUpdateCompatibilityParsing(t *testing.T) { // We should be able to properly parse the encoded channel update // message even with the extra two bytes. - var newChanUpdate2 ChannelUpdate1 - err = parseChannelUpdateCompatibilityMode( - &b, uint16(b.Len()), &newChanUpdate2, 0, + newChanUpdate2, err := parseChannelUpdateCompatibilityMode( + &b, uint16(b.Len()), 0, ) require.NoError(t, err, "unable to parse channel update") @@ -185,7 +183,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // raw serialized length. var b bytes.Buffer update := testChannelUpdate - trueUpdateLength, err := WriteMessage(&b, &update, 0) + trueUpdateLength, err := WriteMessage(&b, update, 0) if err != nil { t.Fatalf("unable to write update: %v", err) } @@ -193,7 +191,7 @@ func TestWriteOnionErrorChanUpdate(t *testing.T) { // Next, we'll use the function to encode the update as we would in a // onion error message. var errorBuf bytes.Buffer - err = writeOnionErrorChanUpdate(&errorBuf, &update, 0) + err = writeOnionErrorChanUpdate(&errorBuf, update, 0) require.NoError(t, err, "unable to encode onion error") // Finally, read the length encoded and ensure that it matches the raw diff --git a/routing/missioncontrol_test.go b/routing/missioncontrol_test.go index 4a0f73871..f68e433d3 100644 --- a/routing/missioncontrol_test.go +++ b/routing/missioncontrol_test.go @@ -197,7 +197,7 @@ func TestMissionControl(t *testing.T) { // A node level failure should bring probability of all known channels // back to zero. - ctx.reportFailure(0, lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{})) + ctx.reportFailure(0, lnwire.NewExpiryTooSoon(&lnwire.ChannelUpdate1{})) ctx.expectP(1000, 0) // Check whether history snapshot looks sane. @@ -219,14 +219,14 @@ func TestMissionControlChannelUpdate(t *testing.T) { // Report a policy related failure. Because it is the first, we don't // expect a penalty. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, testAprioriHopProbability) // Report another failure for the same channel. We expect it to be // pruned. ctx.reportFailure( - 0, lnwire.NewFeeInsufficient(0, lnwire.ChannelUpdate1{}), + 0, lnwire.NewFeeInsufficient(0, &lnwire.ChannelUpdate1{}), ) ctx.expectP(100, 0) } diff --git a/routing/result_interpretation_test.go b/routing/result_interpretation_test.go index 68b527e5a..2a9c15383 100644 --- a/routing/result_interpretation_test.go +++ b/routing/result_interpretation_test.go @@ -164,8 +164,9 @@ var resultTestCases = []resultTestCase{ name: "fail expiry too soon", route: &routeFourHop, failureSrcIdx: 3, - failure: lnwire.NewExpiryTooSoon(lnwire.ChannelUpdate1{}), - + failure: lnwire.NewExpiryTooSoon( + &lnwire.ChannelUpdate1{}, + ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ getTestPair(0, 1): failPairResult(0), @@ -267,7 +268,7 @@ var resultTestCases = []resultTestCase{ route: &routeFourHop, failureSrcIdx: 2, failure: lnwire.NewFeeInsufficient( - 0, lnwire.ChannelUpdate1{}, + 0, &lnwire.ChannelUpdate1{}, ), expectedResult: &interpretedResult{ pairResults: map[DirectedNodePair]pairResult{ diff --git a/routing/router.go b/routing/router.go index 954df93e3..01d2ebb13 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1270,20 +1270,20 @@ func (r *ChannelRouter) sendPayment(ctx context.Context, // extractChannelUpdate examines the error and extracts the channel update. func (r *ChannelRouter) extractChannelUpdate( - failure lnwire.FailureMessage) *lnwire.ChannelUpdate1 { + failure lnwire.FailureMessage) lnwire.ChannelUpdate { - var update *lnwire.ChannelUpdate1 + var update lnwire.ChannelUpdate switch onionErr := failure.(type) { case *lnwire.FailExpiryTooSoon: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailAmountBelowMinimum: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailFeeInsufficient: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailIncorrectCltvExpiry: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailChannelDisabled: - update = &onionErr.Update + update = onionErr.Update case *lnwire.FailTemporaryChannelFailure: update = onionErr.Update } diff --git a/routing/router_test.go b/routing/router_test.go index 8a0b88c7f..cf2c64794 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -550,7 +550,7 @@ func TestChannelUpdateValidation(t *testing.T) { func(firstHop lnwire.ShortChannelID) ([32]byte, error) { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) @@ -657,9 +657,6 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { signErrChanUpdate(t, ctx.privKeys["songoku"], errChanUpdate) - chanUpd, ok := errChanUpdate.(*lnwire.ChannelUpdate1) - require.True(t, ok) - // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to Son goku. This will be a fee related error, so // it should only cause the edge to be pruned after the second attempt. @@ -680,7 +677,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: *chanUpd, + Update: errChanUpdate, }, 1, ) } @@ -790,7 +787,7 @@ func TestSendPaymentErrorFeeInsufficientPrivateEdge(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -916,7 +913,7 @@ func TestSendPaymentPrivateEdgeUpdateFeeExceedsLimit(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: &errChanUpdate, }, 1, ) }, @@ -1000,9 +997,6 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { ) require.NoError(t, err) - chanUpd, ok := errChanUpdate.(*lnwire.ChannelUpdate1) - require.True(t, ok) - // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to son goku. Since this is a time lock related // error, we should fail the payment flow all together, as Goku is the @@ -1012,7 +1006,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ - Update: *chanUpd, + Update: errChanUpdate, }, 1, ) } @@ -1060,7 +1054,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ - Update: *chanUpd, + Update: errChanUpdate, }, 1, ) } @@ -1454,7 +1448,7 @@ func TestSendToRouteStructuredError(t *testing.T) { testCases := map[int]lnwire.FailureMessage{ finalHopIndex: lnwire.NewFailIncorrectDetails(payAmt, 100), 1: &lnwire.FailFeeInsufficient{ - Update: lnwire.ChannelUpdate1{}, + Update: &lnwire.ChannelUpdate1{}, }, }