diff --git a/channeldb/payments.go b/channeldb/payments.go index 079fedf23..a0db8f22a 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -195,6 +195,11 @@ type PaymentCreationInfo struct { // PaymentRequest is the full payment request, if any. PaymentRequest []byte + + // FirstHopCustomRecords are the TLV records that are to be sent to the + // first hop of this payment. These records will be transmitted via the + // wire message only and therefore do not affect the onion payload size. + FirstHopCustomRecords lnwire.CustomRecords } // htlcBucketKey creates a composite key from prefix and id where the result is @@ -1010,10 +1015,21 @@ func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error { return err } + // Any remaining bytes are TLV encoded records. Currently, these are + // only the custom records provided by the user to be sent to the first + // hop. But this can easily be extended with further records by merging + // the records into a single TLV stream. + err := c.FirstHopCustomRecords.SerializeTo(w) + if err != nil { + return err + } + return nil } -func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { +func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, + error) { + var scratch [8]byte c := &PaymentCreationInfo{} @@ -1046,6 +1062,15 @@ func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) { } c.PaymentRequest = payReq + // Any remaining bytes are TLV encoded records. Currently, these are + // only the custom records provided by the user to be sent to the first + // hop. But this can easily be extended with further records by merging + // the records into a single TLV stream. + c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r) + if err != nil { + return nil, err + } + return c, nil } diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index 769f4cc77..844b818e8 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -13,6 +13,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/stretchr/testify/require" @@ -108,7 +109,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) { // Use single second precision to avoid false positive test // failures due to the monotonic time component. CreationTime: time.Unix(time.Now().Unix(), 0), - PaymentRequest: []byte(""), + PaymentRequest: []byte("test"), } a := NewHtlcAttempt( @@ -124,36 +125,40 @@ func TestSentPaymentSerialization(t *testing.T) { c, s := makeFakeInfo() var b bytes.Buffer - if err := serializePaymentCreationInfo(&b, c); err != nil { - t.Fatalf("unable to serialize creation info: %v", err) - } + require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + + // Assert the length of the serialized creation info is as expected, + // without any custom records. + baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest) + require.Len(t, b.Bytes(), baseLength) newCreationInfo, err := deserializePaymentCreationInfo(&b) - require.NoError(t, err, "unable to deserialize creation info") - - if !reflect.DeepEqual(c, newCreationInfo) { - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(c), spew.Sdump(newCreationInfo), - ) - } + require.NoError(t, err, "deserialize") + require.Equal(t, c, newCreationInfo) b.Reset() - if err := serializeHTLCAttemptInfo(&b, s); err != nil { - t.Fatalf("unable to serialize info: %v", err) + + // Now we add some custom records to the creation info and serialize it + // again. + c.FirstHopCustomRecords = lnwire.CustomRecords{ + lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3}, } + require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize") + + newCreationInfo, err = deserializePaymentCreationInfo(&b) + require.NoError(t, err, "deserialize") + require.Equal(t, c, newCreationInfo) + + require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize") newWireInfo, err := deserializeHTLCAttemptInfo(&b) - require.NoError(t, err, "unable to deserialize info") + require.NoError(t, err, "deserialize") newWireInfo.AttemptID = s.AttemptID - // First we verify all the records match up porperly, as they aren't + // First we verify all the records match up properly, as they aren't // able to be properly compared using reflect.DeepEqual. err = assertRouteEqual(&s.Route, &newWireInfo.Route) - if err != nil { - t.Fatalf("Routes do not match after "+ - "serialization/deserialization: %v", err) - } + require.NoError(t, err) // Clear routes to allow DeepEqual to compare the remaining fields. newWireInfo.Route = route.Route{} @@ -163,12 +168,7 @@ func TestSentPaymentSerialization(t *testing.T) { // DeepEqual, and assert that our key equals the original key. require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey()) - if !reflect.DeepEqual(s, newWireInfo) { - t.Fatalf("Payments do not match after "+ - "serialization/deserialization %v vs %v", - spew.Sdump(s), spew.Sdump(newWireInfo), - ) - } + require.Equal(t, s, newWireInfo) } // assertRouteEquals compares to routes for equality and returns an error if diff --git a/routing/router.go b/routing/router.go index dbd2824bd..de5f45fd8 100644 --- a/routing/router.go +++ b/routing/router.go @@ -1022,10 +1022,11 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) ( // // TODO(roasbeef): store records as part of creation info? info := &channeldb.PaymentCreationInfo{ - PaymentIdentifier: payment.Identifier(), - Value: payment.Amount, - CreationTime: r.cfg.Clock.Now(), - PaymentRequest: payment.PaymentRequest, + PaymentIdentifier: payment.Identifier(), + Value: payment.Amount, + CreationTime: r.cfg.Clock.Now(), + PaymentRequest: payment.PaymentRequest, + FirstHopCustomRecords: payment.FirstHopCustomRecords, } // Create a new ShardTracker that we'll use during the life cycle of @@ -1120,10 +1121,11 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route, // Record this payment hash with the ControlTower, ensuring it is not // already in-flight. info := &channeldb.PaymentCreationInfo{ - PaymentIdentifier: paymentIdentifier, - Value: amt, - CreationTime: r.cfg.Clock.Now(), - PaymentRequest: nil, + PaymentIdentifier: paymentIdentifier, + Value: amt, + CreationTime: r.cfg.Clock.Now(), + PaymentRequest: nil, + FirstHopCustomRecords: firstHopCustomRecords, } err := r.cfg.Control.InitPayment(paymentIdentifier, info) @@ -1483,7 +1485,7 @@ func (r *ChannelRouter) resumePayments() error { noTimeout := time.Duration(0) _, _, err := r.sendPayment( context.Background(), 0, payHash, noTimeout, paySession, - shardTracker, nil, + shardTracker, payment.Info.FirstHopCustomRecords, ) if err != nil { log.Errorf("Resuming payment %v failed: %v", payHash,