From 4a6f5d8d3d4485d33245ced96aee5aab86729655 Mon Sep 17 00:00:00 2001 From: Conner Fromknecht Date: Mon, 4 Nov 2019 15:10:00 -0800 Subject: [PATCH] htlcswitch/hop/payload: parse option_mpp --- htlcswitch/hop/payload.go | 22 +++++++ htlcswitch/hop/payload_test.go | 102 ++++++++++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 9 deletions(-) diff --git a/htlcswitch/hop/payload.go b/htlcswitch/hop/payload.go index 2edd9aa8b..ad164e287 100644 --- a/htlcswitch/hop/payload.go +++ b/htlcswitch/hop/payload.go @@ -81,6 +81,10 @@ type Payload struct { // FwdInfo holds the basic parameters required for HTLC forwarding, e.g. // amount, cltv, and next hop. FwdInfo ForwardingInfo + + // MPP holds the info provided in an option_mpp record when parsed from + // a TLV onion payload. + MPP *record.MPP } // NewLegacyPayload builds a Payload from the amount, cltv, and next hop @@ -105,12 +109,14 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { cid uint64 amt uint64 cltv uint32 + mpp = &record.MPP{} ) tlvStream, err := tlv.NewStream( record.NewAmtToFwdRecord(&amt), record.NewLockTimeRecord(&cltv), record.NewNextHopIDRecord(&cid), + mpp.Record(), ) if err != nil { return nil, err @@ -151,6 +157,12 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { return nil, err } + // If no MPP field was parsed, set the MPP field on the resulting + // payload to nil. + if _, ok := parsedTypes[record.MPPOnionType]; !ok { + mpp = nil + } + return &Payload{ FwdInfo: ForwardingInfo{ Network: BitcoinNetwork, @@ -158,6 +170,7 @@ func NewPayloadFromReader(r io.Reader) (*Payload, error) { AmountToForward: lnwire.MilliSatoshi(amt), OutgoingCTLV: cltv, }, + MPP: mpp, }, nil } @@ -179,6 +192,7 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, _, hasAmt := parsedTypes[record.AmtOnionType] _, hasLockTime := parsedTypes[record.LockTimeOnionType] _, hasNextHop := parsedTypes[record.NextHopOnionType] + _, hasMPP := parsedTypes[record.MPPOnionType] switch { @@ -207,6 +221,14 @@ func ValidateParsedPayloadTypes(parsedTypes tlv.TypeSet, Violation: IncludedViolation, FinalHop: true, } + + // Intermediate nodes should never receive MPP fields. + case !isFinalHop && hasMPP: + return ErrInvalidPayload{ + Type: record.MPPOnionType, + Violation: IncludedViolation, + FinalHop: isFinalHop, + } } return nil diff --git a/htlcswitch/hop/payload_test.go b/htlcswitch/hop/payload_test.go index a49cd2b11..7ef35e8b4 100644 --- a/htlcswitch/hop/payload_test.go +++ b/htlcswitch/hop/payload_test.go @@ -6,13 +6,15 @@ import ( "testing" "github.com/lightningnetwork/lnd/htlcswitch/hop" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" ) type decodePayloadTest struct { - name string - payload []byte - expErr error + name string + payload []byte + expErr error + shouldHaveMPP bool } var decodePayloadTests = []decodePayloadTest{ @@ -79,9 +81,9 @@ var decodePayloadTests = []decodePayloadTest{ }, { name: "required type after omitted hop id", - payload: []byte{0x02, 0x00, 0x04, 0x00, 0x08, 0x00}, + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x0a, 0x00}, expErr: hop.ErrInvalidPayload{ - Type: 8, + Type: 10, Violation: hop.RequiredViolation, FinalHop: true, }, @@ -89,10 +91,10 @@ var decodePayloadTests = []decodePayloadTest{ { name: "required type after included hop id", payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x08, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: 8, + Type: 10, Violation: hop.RequiredViolation, FinalHop: false, }, @@ -112,7 +114,7 @@ var decodePayloadTests = []decodePayloadTest{ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, }, expErr: hop.ErrInvalidPayload{ - Type: 6, + Type: record.NextHopOnionType, Violation: hop.IncludedViolation, FinalHop: true, }, @@ -128,6 +130,60 @@ var decodePayloadTests = []decodePayloadTest{ FinalHop: false, }, }, + { + name: "valid intermediate hop", + payload: []byte{0x02, 0x00, 0x04, 0x00, 0x06, 0x08, 0x01, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, + expErr: nil, + }, + { + name: "valid final hop", + payload: []byte{0x02, 0x00, 0x04, 0x00}, + expErr: nil, + }, + { + name: "intermediate hop with mpp", + payload: []byte{ + // amount + 0x02, 0x00, + // cltv + 0x04, 0x00, + // next hop id + 0x06, 0x08, + 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + // mpp + 0x08, 0x21, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x08, + }, + expErr: hop.ErrInvalidPayload{ + Type: record.MPPOnionType, + Violation: hop.IncludedViolation, + FinalHop: false, + }, + }, + { + name: "final hop with mpp", + payload: []byte{ + // amount + 0x02, 0x00, + // cltv + 0x04, 0x00, + // mpp + 0x08, 0x21, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x08, + }, + expErr: nil, + shouldHaveMPP: true, + }, } // TestDecodeHopPayloadRecordValidation asserts that parsing the payloads in the @@ -142,9 +198,37 @@ func TestDecodeHopPayloadRecordValidation(t *testing.T) { } func testDecodeHopPayloadValidation(t *testing.T, test decodePayloadTest) { - _, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) + var ( + testTotalMsat = lnwire.MilliSatoshi(8) + testAddr = [32]byte{ + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, 0x11, + } + ) + + p, err := hop.NewPayloadFromReader(bytes.NewReader(test.payload)) if !reflect.DeepEqual(test.expErr, err) { t.Fatalf("expected error mismatch, want: %v, got: %v", test.expErr, err) } + if err != nil { + return + } + + // Assert MPP fields if we expect them. + if test.shouldHaveMPP { + if p.MPP == nil { + t.Fatalf("payload should have MPP record") + } + if p.MPP.TotalMsat() != testTotalMsat { + t.Fatalf("invalid total msat") + } + if p.MPP.PaymentAddr() != testAddr { + t.Fatalf("invalid payment addr") + } + } else if p.MPP != nil { + t.Fatalf("unexpected MPP payload") + } }