mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-12-09 20:33:45 +01:00
channeldb+routing: persist first hop custom records
With this commit we make sure the first hop custom records aren't lost on restart/resume of a payment, so we persist it as part of the PaymentCreationInfo struct.
This commit is contained in:
@@ -195,6 +195,11 @@ type PaymentCreationInfo struct {
|
|||||||
|
|
||||||
// PaymentRequest is the full payment request, if any.
|
// PaymentRequest is the full payment request, if any.
|
||||||
PaymentRequest []byte
|
PaymentRequest []byte
|
||||||
|
|
||||||
|
// FirstHopCustomRecords are the TLV records that are to be sent to the
|
||||||
|
// first hop of this payment. These records will be transmitted via the
|
||||||
|
// wire message only and therefore do not affect the onion payload size.
|
||||||
|
FirstHopCustomRecords lnwire.CustomRecords
|
||||||
}
|
}
|
||||||
|
|
||||||
// htlcBucketKey creates a composite key from prefix and id where the result is
|
// htlcBucketKey creates a composite key from prefix and id where the result is
|
||||||
@@ -1010,10 +1015,21 @@ func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Any remaining bytes are TLV encoded records. Currently, these are
|
||||||
|
// only the custom records provided by the user to be sent to the first
|
||||||
|
// hop. But this can easily be extended with further records by merging
|
||||||
|
// the records into a single TLV stream.
|
||||||
|
err := c.FirstHopCustomRecords.SerializeTo(w)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
|
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo,
|
||||||
|
error) {
|
||||||
|
|
||||||
var scratch [8]byte
|
var scratch [8]byte
|
||||||
|
|
||||||
c := &PaymentCreationInfo{}
|
c := &PaymentCreationInfo{}
|
||||||
@@ -1046,6 +1062,15 @@ func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
|
|||||||
}
|
}
|
||||||
c.PaymentRequest = payReq
|
c.PaymentRequest = payReq
|
||||||
|
|
||||||
|
// Any remaining bytes are TLV encoded records. Currently, these are
|
||||||
|
// only the custom records provided by the user to be sent to the first
|
||||||
|
// hop. But this can easily be extended with further records by merging
|
||||||
|
// the records into a single TLV stream.
|
||||||
|
c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lightningnetwork/lnd/kvdb"
|
"github.com/lightningnetwork/lnd/kvdb"
|
||||||
"github.com/lightningnetwork/lnd/lntypes"
|
"github.com/lightningnetwork/lnd/lntypes"
|
||||||
|
"github.com/lightningnetwork/lnd/lnwire"
|
||||||
"github.com/lightningnetwork/lnd/record"
|
"github.com/lightningnetwork/lnd/record"
|
||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -108,7 +109,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) {
|
|||||||
// Use single second precision to avoid false positive test
|
// Use single second precision to avoid false positive test
|
||||||
// failures due to the monotonic time component.
|
// failures due to the monotonic time component.
|
||||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||||
PaymentRequest: []byte(""),
|
PaymentRequest: []byte("test"),
|
||||||
}
|
}
|
||||||
|
|
||||||
a := NewHtlcAttempt(
|
a := NewHtlcAttempt(
|
||||||
@@ -124,36 +125,40 @@ func TestSentPaymentSerialization(t *testing.T) {
|
|||||||
c, s := makeFakeInfo()
|
c, s := makeFakeInfo()
|
||||||
|
|
||||||
var b bytes.Buffer
|
var b bytes.Buffer
|
||||||
if err := serializePaymentCreationInfo(&b, c); err != nil {
|
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
|
||||||
t.Fatalf("unable to serialize creation info: %v", err)
|
|
||||||
}
|
// Assert the length of the serialized creation info is as expected,
|
||||||
|
// without any custom records.
|
||||||
|
baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest)
|
||||||
|
require.Len(t, b.Bytes(), baseLength)
|
||||||
|
|
||||||
newCreationInfo, err := deserializePaymentCreationInfo(&b)
|
newCreationInfo, err := deserializePaymentCreationInfo(&b)
|
||||||
require.NoError(t, err, "unable to deserialize creation info")
|
require.NoError(t, err, "deserialize")
|
||||||
|
require.Equal(t, c, newCreationInfo)
|
||||||
if !reflect.DeepEqual(c, newCreationInfo) {
|
|
||||||
t.Fatalf("Payments do not match after "+
|
|
||||||
"serialization/deserialization %v vs %v",
|
|
||||||
spew.Sdump(c), spew.Sdump(newCreationInfo),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.Reset()
|
b.Reset()
|
||||||
if err := serializeHTLCAttemptInfo(&b, s); err != nil {
|
|
||||||
t.Fatalf("unable to serialize info: %v", err)
|
// Now we add some custom records to the creation info and serialize it
|
||||||
|
// again.
|
||||||
|
c.FirstHopCustomRecords = lnwire.CustomRecords{
|
||||||
|
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
|
||||||
}
|
}
|
||||||
|
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
|
||||||
|
|
||||||
|
newCreationInfo, err = deserializePaymentCreationInfo(&b)
|
||||||
|
require.NoError(t, err, "deserialize")
|
||||||
|
require.Equal(t, c, newCreationInfo)
|
||||||
|
|
||||||
|
require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize")
|
||||||
|
|
||||||
newWireInfo, err := deserializeHTLCAttemptInfo(&b)
|
newWireInfo, err := deserializeHTLCAttemptInfo(&b)
|
||||||
require.NoError(t, err, "unable to deserialize info")
|
require.NoError(t, err, "deserialize")
|
||||||
newWireInfo.AttemptID = s.AttemptID
|
newWireInfo.AttemptID = s.AttemptID
|
||||||
|
|
||||||
// First we verify all the records match up porperly, as they aren't
|
// First we verify all the records match up properly, as they aren't
|
||||||
// able to be properly compared using reflect.DeepEqual.
|
// able to be properly compared using reflect.DeepEqual.
|
||||||
err = assertRouteEqual(&s.Route, &newWireInfo.Route)
|
err = assertRouteEqual(&s.Route, &newWireInfo.Route)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Fatalf("Routes do not match after "+
|
|
||||||
"serialization/deserialization: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear routes to allow DeepEqual to compare the remaining fields.
|
// Clear routes to allow DeepEqual to compare the remaining fields.
|
||||||
newWireInfo.Route = route.Route{}
|
newWireInfo.Route = route.Route{}
|
||||||
@@ -163,12 +168,7 @@ func TestSentPaymentSerialization(t *testing.T) {
|
|||||||
// DeepEqual, and assert that our key equals the original key.
|
// DeepEqual, and assert that our key equals the original key.
|
||||||
require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey())
|
require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey())
|
||||||
|
|
||||||
if !reflect.DeepEqual(s, newWireInfo) {
|
require.Equal(t, s, newWireInfo)
|
||||||
t.Fatalf("Payments do not match after "+
|
|
||||||
"serialization/deserialization %v vs %v",
|
|
||||||
spew.Sdump(s), spew.Sdump(newWireInfo),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// assertRouteEquals compares to routes for equality and returns an error if
|
// assertRouteEquals compares to routes for equality and returns an error if
|
||||||
|
|||||||
@@ -1026,6 +1026,7 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
|
|||||||
Value: payment.Amount,
|
Value: payment.Amount,
|
||||||
CreationTime: r.cfg.Clock.Now(),
|
CreationTime: r.cfg.Clock.Now(),
|
||||||
PaymentRequest: payment.PaymentRequest,
|
PaymentRequest: payment.PaymentRequest,
|
||||||
|
FirstHopCustomRecords: payment.FirstHopCustomRecords,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create a new ShardTracker that we'll use during the life cycle of
|
// Create a new ShardTracker that we'll use during the life cycle of
|
||||||
@@ -1124,6 +1125,7 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
|
|||||||
Value: amt,
|
Value: amt,
|
||||||
CreationTime: r.cfg.Clock.Now(),
|
CreationTime: r.cfg.Clock.Now(),
|
||||||
PaymentRequest: nil,
|
PaymentRequest: nil,
|
||||||
|
FirstHopCustomRecords: firstHopCustomRecords,
|
||||||
}
|
}
|
||||||
|
|
||||||
err := r.cfg.Control.InitPayment(paymentIdentifier, info)
|
err := r.cfg.Control.InitPayment(paymentIdentifier, info)
|
||||||
@@ -1483,7 +1485,7 @@ func (r *ChannelRouter) resumePayments() error {
|
|||||||
noTimeout := time.Duration(0)
|
noTimeout := time.Duration(0)
|
||||||
_, _, err := r.sendPayment(
|
_, _, err := r.sendPayment(
|
||||||
context.Background(), 0, payHash, noTimeout, paySession,
|
context.Background(), 0, payHash, noTimeout, paySession,
|
||||||
shardTracker, nil,
|
shardTracker, payment.Info.FirstHopCustomRecords,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Resuming payment %v failed: %v", payHash,
|
log.Errorf("Resuming payment %v failed: %v", payHash,
|
||||||
|
|||||||
Reference in New Issue
Block a user