mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-04-05 10:39:03 +02:00
channeldb+routing: persist first hop custom records
With this commit we make sure the first hop custom records aren't lost on restart/resume of a payment, so we persist it as part of the PaymentCreationInfo struct.
This commit is contained in:
parent
857a16d838
commit
ff1a45549e
@ -195,6 +195,11 @@ type PaymentCreationInfo struct {
|
||||
|
||||
// PaymentRequest is the full payment request, if any.
|
||||
PaymentRequest []byte
|
||||
|
||||
// FirstHopCustomRecords are the TLV records that are to be sent to the
|
||||
// first hop of this payment. These records will be transmitted via the
|
||||
// wire message only and therefore do not affect the onion payload size.
|
||||
FirstHopCustomRecords lnwire.CustomRecords
|
||||
}
|
||||
|
||||
// htlcBucketKey creates a composite key from prefix and id where the result is
|
||||
@ -1010,10 +1015,21 @@ func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Any remaining bytes are TLV encoded records. Currently, these are
|
||||
// only the custom records provided by the user to be sent to the first
|
||||
// hop. But this can easily be extended with further records by merging
|
||||
// the records into a single TLV stream.
|
||||
err := c.FirstHopCustomRecords.SerializeTo(w)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
|
||||
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo,
|
||||
error) {
|
||||
|
||||
var scratch [8]byte
|
||||
|
||||
c := &PaymentCreationInfo{}
|
||||
@ -1046,6 +1062,15 @@ func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
|
||||
}
|
||||
c.PaymentRequest = payReq
|
||||
|
||||
// Any remaining bytes are TLV encoded records. Currently, these are
|
||||
// only the custom records provided by the user to be sent to the first
|
||||
// hop. But this can easily be extended with further records by merging
|
||||
// the records into a single TLV stream.
|
||||
c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
@ -13,6 +13,7 @@ import (
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lntypes"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/record"
|
||||
"github.com/lightningnetwork/lnd/routing/route"
|
||||
"github.com/stretchr/testify/require"
|
||||
@ -108,7 +109,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) {
|
||||
// Use single second precision to avoid false positive test
|
||||
// failures due to the monotonic time component.
|
||||
CreationTime: time.Unix(time.Now().Unix(), 0),
|
||||
PaymentRequest: []byte(""),
|
||||
PaymentRequest: []byte("test"),
|
||||
}
|
||||
|
||||
a := NewHtlcAttempt(
|
||||
@ -124,36 +125,40 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
c, s := makeFakeInfo()
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := serializePaymentCreationInfo(&b, c); err != nil {
|
||||
t.Fatalf("unable to serialize creation info: %v", err)
|
||||
}
|
||||
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
|
||||
|
||||
// Assert the length of the serialized creation info is as expected,
|
||||
// without any custom records.
|
||||
baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest)
|
||||
require.Len(t, b.Bytes(), baseLength)
|
||||
|
||||
newCreationInfo, err := deserializePaymentCreationInfo(&b)
|
||||
require.NoError(t, err, "unable to deserialize creation info")
|
||||
|
||||
if !reflect.DeepEqual(c, newCreationInfo) {
|
||||
t.Fatalf("Payments do not match after "+
|
||||
"serialization/deserialization %v vs %v",
|
||||
spew.Sdump(c), spew.Sdump(newCreationInfo),
|
||||
)
|
||||
}
|
||||
require.NoError(t, err, "deserialize")
|
||||
require.Equal(t, c, newCreationInfo)
|
||||
|
||||
b.Reset()
|
||||
if err := serializeHTLCAttemptInfo(&b, s); err != nil {
|
||||
t.Fatalf("unable to serialize info: %v", err)
|
||||
|
||||
// Now we add some custom records to the creation info and serialize it
|
||||
// again.
|
||||
c.FirstHopCustomRecords = lnwire.CustomRecords{
|
||||
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
|
||||
}
|
||||
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
|
||||
|
||||
newCreationInfo, err = deserializePaymentCreationInfo(&b)
|
||||
require.NoError(t, err, "deserialize")
|
||||
require.Equal(t, c, newCreationInfo)
|
||||
|
||||
require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize")
|
||||
|
||||
newWireInfo, err := deserializeHTLCAttemptInfo(&b)
|
||||
require.NoError(t, err, "unable to deserialize info")
|
||||
require.NoError(t, err, "deserialize")
|
||||
newWireInfo.AttemptID = s.AttemptID
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// First we verify all the records match up properly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
err = assertRouteEqual(&s.Route, &newWireInfo.Route)
|
||||
if err != nil {
|
||||
t.Fatalf("Routes do not match after "+
|
||||
"serialization/deserialization: %v", err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
|
||||
// Clear routes to allow DeepEqual to compare the remaining fields.
|
||||
newWireInfo.Route = route.Route{}
|
||||
@ -163,12 +168,7 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
// DeepEqual, and assert that our key equals the original key.
|
||||
require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey())
|
||||
|
||||
if !reflect.DeepEqual(s, newWireInfo) {
|
||||
t.Fatalf("Payments do not match after "+
|
||||
"serialization/deserialization %v vs %v",
|
||||
spew.Sdump(s), spew.Sdump(newWireInfo),
|
||||
)
|
||||
}
|
||||
require.Equal(t, s, newWireInfo)
|
||||
}
|
||||
|
||||
// assertRouteEquals compares to routes for equality and returns an error if
|
||||
|
@ -1022,10 +1022,11 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
|
||||
//
|
||||
// TODO(roasbeef): store records as part of creation info?
|
||||
info := &channeldb.PaymentCreationInfo{
|
||||
PaymentIdentifier: payment.Identifier(),
|
||||
Value: payment.Amount,
|
||||
CreationTime: r.cfg.Clock.Now(),
|
||||
PaymentRequest: payment.PaymentRequest,
|
||||
PaymentIdentifier: payment.Identifier(),
|
||||
Value: payment.Amount,
|
||||
CreationTime: r.cfg.Clock.Now(),
|
||||
PaymentRequest: payment.PaymentRequest,
|
||||
FirstHopCustomRecords: payment.FirstHopCustomRecords,
|
||||
}
|
||||
|
||||
// Create a new ShardTracker that we'll use during the life cycle of
|
||||
@ -1120,10 +1121,11 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
|
||||
// Record this payment hash with the ControlTower, ensuring it is not
|
||||
// already in-flight.
|
||||
info := &channeldb.PaymentCreationInfo{
|
||||
PaymentIdentifier: paymentIdentifier,
|
||||
Value: amt,
|
||||
CreationTime: r.cfg.Clock.Now(),
|
||||
PaymentRequest: nil,
|
||||
PaymentIdentifier: paymentIdentifier,
|
||||
Value: amt,
|
||||
CreationTime: r.cfg.Clock.Now(),
|
||||
PaymentRequest: nil,
|
||||
FirstHopCustomRecords: firstHopCustomRecords,
|
||||
}
|
||||
|
||||
err := r.cfg.Control.InitPayment(paymentIdentifier, info)
|
||||
@ -1483,7 +1485,7 @@ func (r *ChannelRouter) resumePayments() error {
|
||||
noTimeout := time.Duration(0)
|
||||
_, _, err := r.sendPayment(
|
||||
context.Background(), 0, payHash, noTimeout, paySession,
|
||||
shardTracker, nil,
|
||||
shardTracker, payment.Info.FirstHopCustomRecords,
|
||||
)
|
||||
if err != nil {
|
||||
log.Errorf("Resuming payment %v failed: %v", payHash,
|
||||
|
Loading…
x
Reference in New Issue
Block a user