From 585f28c5f5152e7d0ee0f3cba3a04ce678b881ca Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Wed, 1 Nov 2023 10:46:48 -0400 Subject: [PATCH] multi: explicitly signal final hop in pack hop payload Previously, we'd use the value of nextChanID to infer whether a payload was for the final hop in a route. This commit updates our packing logic to explicitly signal to account for blinded routes, which allow zero value nextChanID in intermediate hops. This is a preparatory commit that allows us to more thoroughly validate payloads. --- htlcswitch/hop/fuzz_test.go | 2 +- routing/route/route.go | 21 +++++++++++++-------- routing/route/route_test.go | 12 ++++++------ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/htlcswitch/hop/fuzz_test.go b/htlcswitch/hop/fuzz_test.go index c70c380c0..ca647899c 100644 --- a/htlcswitch/hop/fuzz_test.go +++ b/htlcswitch/hop/fuzz_test.go @@ -121,7 +121,7 @@ func fuzzPayload(f *testing.F, finalPayload bool) { var b bytes.Buffer hop, nextChanID := hopFromPayload(payload1) - err = hop.PackHopPayload(&b, nextChanID) + err = hop.PackHopPayload(&b, nextChanID, finalPayload) if errors.Is(err, route.ErrAMPMissingMPP) { // PackHopPayload refuses to encode an AMP record // without an MPP record. However, NewPayloadFromReader diff --git a/routing/route/route.go b/routing/route/route.go index fbf1b5b84..47c305ba0 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -172,11 +172,15 @@ func (h *Hop) Copy() *Hop { // PackHopPayload writes to the passed io.Writer, the series of byes that can // be placed directly into the per-hop payload (EOB) for this hop. This will // include the required routing fields, as well as serializing any of the -// passed optional TLVRecords. nextChanID is the unique channel ID that -// references the _outgoing_ channel ID that follows this hop. This field -// follows the same semantics as the NextAddress field in the onion: it should -// be set to zero to indicate the terminal hop. -func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { +// passed optional TLVRecords. nextChanID is the unique channel ID that +// references the _outgoing_ channel ID that follows this hop. The lastHop bool +// is used to signal whether this hop is the final hop in a route. Previously, +// a zero nextChanID would be used for this purpose, but with the addition of +// blinded routes which allow zero nextChanID values for intermediate hops we +// add an explicit signal. +func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64, + finalHop bool) error { + // If this is a legacy payload, then we'll exit here as this method // shouldn't be called. if h.LegacyPayload == true { @@ -218,7 +222,7 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { // attach it to the final hop. Otherwise the route was constructed // incorrectly. if h.MPP != nil { - if nextChanID == 0 { + if finalHop { records = append(records, h.MPP.Record()) } else { return ErrIntermediateMPPHop @@ -559,7 +563,8 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) { // If we aren't on the last hop, then we set the "next address" // field to be the channel that directly follows it. - if i != len(r.Hops)-1 { + finalHop := i == len(r.Hops)-1 + if !finalHop { nextHop = r.Hops[i+1].ChannelID } @@ -591,7 +596,7 @@ func (r *Route) ToSphinxPath() (*sphinx.PaymentPath, error) { // channel should be forwarded to so we can construct a // valid payload. var b bytes.Buffer - err := hop.PackHopPayload(&b, nextHop) + err := hop.PackHopPayload(&b, nextHop, finalHop) if err != nil { return nil, err } diff --git a/routing/route/route_test.go b/routing/route/route_test.go index 252ffba84..d73233e00 100644 --- a/routing/route/route_test.go +++ b/routing/route/route_test.go @@ -105,7 +105,7 @@ func TestMPPHop(t *testing.T) { // Encoding an MPP record to an intermediate hop should result in a // failure. var b bytes.Buffer - err := hop.PackHopPayload(&b, 2) + err := hop.PackHopPayload(&b, 2, false) if err != ErrIntermediateMPPHop { t.Fatalf("expected err: %v, got: %v", ErrIntermediateMPPHop, err) @@ -113,7 +113,7 @@ func TestMPPHop(t *testing.T) { // Encoding an MPP record to a final hop should be successful. b.Reset() - err = hop.PackHopPayload(&b, 0) + err = hop.PackHopPayload(&b, 0, true) if err != nil { t.Fatalf("expected err: %v, got: %v", nil, err) } @@ -135,7 +135,7 @@ func TestAMPHop(t *testing.T) { // Encoding an AMP record to an intermediate hop w/o an MPP record // should result in a failure. var b bytes.Buffer - err := hop.PackHopPayload(&b, 2) + err := hop.PackHopPayload(&b, 2, false) if err != ErrAMPMissingMPP { t.Fatalf("expected err: %v, got: %v", ErrAMPMissingMPP, err) @@ -144,7 +144,7 @@ func TestAMPHop(t *testing.T) { // Encoding an AMP record to a final hop w/o an MPP record should result // in a failure. b.Reset() - err = hop.PackHopPayload(&b, 0) + err = hop.PackHopPayload(&b, 0, true) if err != ErrAMPMissingMPP { t.Fatalf("expected err: %v, got: %v", ErrAMPMissingMPP, err) @@ -154,7 +154,7 @@ func TestAMPHop(t *testing.T) { // successful. hop.MPP = record.NewMPP(testAmt, testAddr) b.Reset() - err = hop.PackHopPayload(&b, 0) + err = hop.PackHopPayload(&b, 0, true) if err != nil { t.Fatalf("expected err: %v, got: %v", nil, err) } @@ -170,7 +170,7 @@ func TestNoForwardingParams(t *testing.T) { } var b bytes.Buffer - err := hop.PackHopPayload(&b, 2) + err := hop.PackHopPayload(&b, 2, false) require.NoError(t, err) }