From b5afd905d17520d2de8ac19d284053234acf4dd0 Mon Sep 17 00:00:00 2001 From: Carla Kirk-Cohen Date: Wed, 1 Nov 2023 10:39:33 -0400 Subject: [PATCH] htlcswitch/hop: explicitly signal final hop from sphinx packet Previously, we were using nextChanID to determine whether a hop payload is for the final recipient. This is no longer suitable in a route-blinding world where intermediate hops are allowed to have zero nextChanID TLVs (as this information is provided to forwarding nodes in their encrypted data). This commit updates payload reading to use the signal provided by sphinx that we are on the last packet, rather than implying it from the contents of a hop. --- htlcswitch/hop/fuzz_test.go | 20 ++++++-- htlcswitch/hop/iterator.go | 7 +-- htlcswitch/hop/payload.go | 21 ++++----- htlcswitch/hop/payload_test.go | 86 ++++++++++++++++++++++------------ 4 files changed, 87 insertions(+), 47 deletions(-) diff --git a/htlcswitch/hop/fuzz_test.go b/htlcswitch/hop/fuzz_test.go index 82a92eb22..c70c380c0 100644 --- a/htlcswitch/hop/fuzz_test.go +++ b/htlcswitch/hop/fuzz_test.go @@ -92,7 +92,21 @@ func hopFromPayload(p *Payload) (*route.Hop, uint64) { }, p.FwdInfo.NextHop.ToUint64() } -func FuzzPayload(f *testing.F) { +// FuzzPayloadFinal fuzzes final hop payloads, providing the additional context +// that the hop should be final (which is usually obtained by the structure +// of the sphinx packet). +func FuzzPayloadFinal(f *testing.F) { + fuzzPayload(f, true) +} + +// FuzzPayloadIntermediate fuzzes intermediate hop payloads, providing the +// additional context that a hop should be intermediate (which is usually +// obtained by the structure of the sphinx packet). +func FuzzPayloadIntermediate(f *testing.F) { + fuzzPayload(f, false) +} + +func fuzzPayload(f *testing.F, finalPayload bool) { f.Fuzz(func(t *testing.T, data []byte) { if len(data) > sphinx.MaxPayloadSize { return @@ -100,7 +114,7 @@ func FuzzPayload(f *testing.F) { r := bytes.NewReader(data) - payload1, err := NewPayloadFromReader(r) + payload1, err := NewPayloadFromReader(r, finalPayload) if err != nil { return } @@ -118,7 +132,7 @@ func FuzzPayload(f *testing.F) { } require.NoError(t, err) - payload2, err := NewPayloadFromReader(&b) + payload2, err := NewPayloadFromReader(&b, finalPayload) require.NoError(t, err) require.Equal(t, payload1, payload2) diff --git a/htlcswitch/hop/iterator.go b/htlcswitch/hop/iterator.go index b5c7d0c83..96a5c5f2b 100644 --- a/htlcswitch/hop/iterator.go +++ b/htlcswitch/hop/iterator.go @@ -90,9 +90,10 @@ func (r *sphinxHopIterator) HopPayload() (*Payload, error) { // Otherwise, if this is the TLV payload, then we'll make a new stream // to decode only what we need to make routing decisions. case sphinx.PayloadTLV: - return NewPayloadFromReader(bytes.NewReader( - r.processedPacket.Payload.Payload, - )) + return NewPayloadFromReader( + bytes.NewReader(r.processedPacket.Payload.Payload), + r.processedPacket.Action == sphinx.ExitNode, + ) default: return nil, fmt.Errorf("unknown sphinx payload type: %v", diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 03fdd94f3..7c2e607ae 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -127,8 +127,10 @@ func NewLegacyPayload(f *sphinx.HopData) *Payload { } // NewPayloadFromReader builds a new Hop from the passed io.Reader. The reader -// should correspond to the bytes encapsulated in a TLV onion payload. -func NewPayloadFromReader(r io.Reader) (*Payload, error) { +// should correspond to the bytes encapsulated in a TLV onion payload. The +// final hop bool signals that this payload was the final packet parsed by +// sphinx. +func NewPayloadFromReader(r io.Reader, finalHop bool) (*Payload, error) { var ( cid uint64 amt uint64 @@ -165,8 +167,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { // Validate whether the sender properly included or omitted tlv records // in accordance with BOLT 04. - nextHop := lnwire.NewShortChanIDFromInt(cid) - err = ValidateParsedPayloadTypes(parsedTypes, nextHop) + err = ValidateParsedPayloadTypes(parsedTypes, finalHop) if err != nil { return nil, err } @@ -177,7 +178,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, ErrInvalidPayload{ Type: *violatingType, Violation: RequiredViolation, - FinalHop: nextHop == Exit, + FinalHop: finalHop, } } @@ -210,7 +211,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return &Payload{ FwdInfo: ForwardingInfo{ - NextHop: nextHop, + NextHop: lnwire.NewShortChanIDFromInt(cid), AmountToForward: lnwire.MilliSatoshi(amt), OutgoingCTLV: cltv, }, @@ -248,9 +249,7 @@ func NewCustomRecords(parsedTypes tlv.TypeMap) record.CustomSet { // boolean should be true if the payload was parsed for an exit hop. The // requirements for this method are described in BOLT 04. func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, - nextHop lnwire.ShortChannelID) error { - - isFinalHop := nextHop == Exit + isFinalHop bool) error { _, hasAmt := parsedTypes[record.AmtOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType] @@ -276,8 +275,8 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeMap, FinalHop: isFinalHop, } - // The exit hop should omit the next hop id. If nextHop != Exit, the - // sender must have included a record, so we don't need to test for its + // The exit hop should omit the next hop id, otherwise the sender must + // have included a record, so we don't need to test for its // inclusion at intermediate hops directly. case isFinalHop && hasNextHop: return ErrInvalidPayload{ diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index 89cea6df8..594b98bc5 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -24,6 +24,7 @@ const testUnknownRequiredType = 0x80 type decodePayloadTest struct { name string payload []byte + isFinalHop bool expErr error expCustomRecords map[uint64][]byte shouldHaveMPP bool @@ -36,18 +37,21 @@ type decodePayloadTest struct { var decodePayloadTests = []decodePayloadTest{ { - name: "final hop valid", - payload: []byte{0x02, 0x00, 0x04, 0x00}, + name: "final hop valid", + isFinalHop: true, + payload: []byte{0x02, 0x00, 0x04, 0x00}, }, { - name: "intermediate hop valid", + name: "intermediate hop valid", + isFinalHop: false, payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, }, { - name: "final hop no amount", - payload: []byte{0x04, 0x00}, + name: "final hop no amount", + payload: []byte{0x04, 0x00}, + isFinalHop: true, expErr: hop.ErrInvalidPayload{ Type: record.AmtOnionType, Violation: hop.OmittedViolation, @@ -55,7 +59,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate hop no amount", + name: "intermediate hop no amount", + isFinalHop: false, payload: []byte{0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, @@ -66,8 +71,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "final hop no expiry", - payload: []byte{0x02, 0x00}, + name: "final hop no expiry", + isFinalHop: true, + payload: []byte{0x02, 0x00}, expErr: hop.ErrInvalidPayload{ Type: record.LockTimeOnionType, Violation: hop.OmittedViolation, @@ -75,7 +81,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate hop no expiry", + name: "intermediate hop no expiry", + isFinalHop: false, payload: []byte{0x02, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, @@ -86,7 +93,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "final hop next sid present", + name: "final hop next sid present", + isFinalHop: true, payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, @@ -97,7 +105,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type after omitted hop id", + name: "required type after omitted hop id", + isFinalHop: true, payload: []byte{ 0x02, 0x00, 0x04, 0x00, testUnknownRequiredType, 0x00, @@ -109,7 +118,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type after included hop id", + name: "required type after included hop id", + isFinalHop: false, payload: []byte{ 0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -122,8 +132,9 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type zero final hop", - payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00}, + name: "required type zero final hop", + isFinalHop: true, + payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00}, expErr: hop.ErrInvalidPayload{ Type: 0, Violation: hop.RequiredViolation, @@ -131,7 +142,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type zero final hop zero sid", + name: "required type zero final hop zero sid", + isFinalHop: true, payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, @@ -142,7 +154,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type zero intermediate hop", + name: "required type zero intermediate hop", + isFinalHop: false, payload: []byte{0x00, 0x00, 0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, @@ -153,7 +166,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "required type in custom range", + name: "required type in custom range", + isFinalHop: false, payload: []byte{0x02, 0x00, 0x04, 0x00, 0xfe, 0x00, 0x01, 0x00, 0x00, 0x02, 0x10, 0x11, }, @@ -162,19 +176,22 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "valid intermediate hop", + name: "valid intermediate hop", + isFinalHop: false, payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, expErr: nil, }, { - name: "valid final hop", - payload: []byte{0x02, 0x00, 0x04, 0x00}, - expErr: nil, + name: "valid final hop", + isFinalHop: true, + payload: []byte{0x02, 0x00, 0x04, 0x00}, + expErr: nil, }, { - name: "intermediate hop with mpp", + name: "intermediate hop with mpp", + isFinalHop: false, payload: []byte{ // amount 0x02, 0x00, @@ -198,7 +215,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate hop with amp", + name: "intermediate hop with amp", + isFinalHop: false, payload: []byte{ // amount 0x02, 0x00, @@ -229,7 +247,8 @@ var decodePayloadTests = []decodePayloadTest{ }, }, { - name: "intermediate hop with encrypted data", + name: "intermediate hop with encrypted data", + isFinalHop: false, payload: []byte{ // encrypted data 0x0a, 0x03, 0x03, 0x02, 0x01, @@ -237,7 +256,8 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveEncData: true, }, { - name: "intermediate hop with blinding point", + name: "intermediate hop with blinding point", + isFinalHop: false, payload: append([]byte{ // encrypted data 0x0a, 0x03, 0x03, 0x02, 0x01, @@ -251,7 +271,8 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveEncData: true, }, { - name: "final hop with mpp", + name: "final hop with mpp", + isFinalHop: true, payload: []byte{ // amount 0x02, 0x00, @@ -269,7 +290,8 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveMPP: true, }, { - name: "final hop with amp", + name: "final hop with amp", + isFinalHop: true, payload: []byte{ // amount 0x02, 0x00, @@ -293,7 +315,8 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveAMP: true, }, { - name: "final hop with metadata", + name: "final hop with metadata", + isFinalHop: true, payload: []byte{ // amount 0x02, 0x00, @@ -305,7 +328,8 @@ var decodePayloadTests = []decodePayloadTest{ shouldHaveMetadata: true, }, { - name: "final hop with total amount", + name: "final hop with total amount", + isFinalHop: true, payload: []byte{ // amount 0x02, 0x00, @@ -356,7 +380,9 @@ func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { testChildIndex = uint32(9) ) - p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) + p, err := hop.NewPayloadFromReader( + bytes.NewReader(test.payload), test.isFinalHop, + ) if !reflect.DeepEqual(test.expErr, err) { t.Fatalf("expected error mismatch, want: %v, got: %v", test.expErr, err)