channeldb+routing: persist first hop custom records

With this commit we make sure the first hop custom records aren't lost
on restart/resume of a payment, so we persist it as part of the
PaymentCreationInfo struct.
This commit is contained in:
Oliver Gugger 2024-08-30 13:10:34 +02:00
parent 857a16d838
commit ff1a45549e
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
3 changed files with 63 additions and 36 deletions

View File

@ -195,6 +195,11 @@ type PaymentCreationInfo struct {
// PaymentRequest is the full payment request, if any.
PaymentRequest []byte
// FirstHopCustomRecords are the TLV records that are to be sent to the
// first hop of this payment. These records will be transmitted via the
// wire message only and therefore do not affect the onion payload size.
FirstHopCustomRecords lnwire.CustomRecords
}
// htlcBucketKey creates a composite key from prefix and id where the result is
@ -1010,10 +1015,21 @@ func serializePaymentCreationInfo(w io.Writer, c *PaymentCreationInfo) error {
return err
}
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
err := c.FirstHopCustomRecords.SerializeTo(w)
if err != nil {
return err
}
return nil
}
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo,
error) {
var scratch [8]byte
c := &PaymentCreationInfo{}
@ -1046,6 +1062,15 @@ func deserializePaymentCreationInfo(r io.Reader) (*PaymentCreationInfo, error) {
}
c.PaymentRequest = payReq
// Any remaining bytes are TLV encoded records. Currently, these are
// only the custom records provided by the user to be sent to the first
// hop. But this can easily be extended with further records by merging
// the records into a single TLV stream.
c.FirstHopCustomRecords, err = lnwire.ParseCustomRecordsFrom(r)
if err != nil {
return nil, err
}
return c, nil
}

View File

@ -13,6 +13,7 @@ import (
"github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route"
"github.com/stretchr/testify/require"
@ -108,7 +109,7 @@ func makeFakeInfo() (*PaymentCreationInfo, *HTLCAttemptInfo) {
// Use single second precision to avoid false positive test
// failures due to the monotonic time component.
CreationTime: time.Unix(time.Now().Unix(), 0),
PaymentRequest: []byte(""),
PaymentRequest: []byte("test"),
}
a := NewHtlcAttempt(
@ -124,36 +125,40 @@ func TestSentPaymentSerialization(t *testing.T) {
c, s := makeFakeInfo()
var b bytes.Buffer
if err := serializePaymentCreationInfo(&b, c); err != nil {
t.Fatalf("unable to serialize creation info: %v", err)
}
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
// Assert the length of the serialized creation info is as expected,
// without any custom records.
baseLength := 32 + 8 + 8 + 4 + len(c.PaymentRequest)
require.Len(t, b.Bytes(), baseLength)
newCreationInfo, err := deserializePaymentCreationInfo(&b)
require.NoError(t, err, "unable to deserialize creation info")
if !reflect.DeepEqual(c, newCreationInfo) {
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(c), spew.Sdump(newCreationInfo),
)
}
require.NoError(t, err, "deserialize")
require.Equal(t, c, newCreationInfo)
b.Reset()
if err := serializeHTLCAttemptInfo(&b, s); err != nil {
t.Fatalf("unable to serialize info: %v", err)
// Now we add some custom records to the creation info and serialize it
// again.
c.FirstHopCustomRecords = lnwire.CustomRecords{
lnwire.MinCustomRecordsTlvType: []byte{1, 2, 3},
}
require.NoError(t, serializePaymentCreationInfo(&b, c), "serialize")
newCreationInfo, err = deserializePaymentCreationInfo(&b)
require.NoError(t, err, "deserialize")
require.Equal(t, c, newCreationInfo)
require.NoError(t, serializeHTLCAttemptInfo(&b, s), "serialize")
newWireInfo, err := deserializeHTLCAttemptInfo(&b)
require.NoError(t, err, "unable to deserialize info")
require.NoError(t, err, "deserialize")
newWireInfo.AttemptID = s.AttemptID
// First we verify all the records match up porperly, as they aren't
// First we verify all the records match up properly, as they aren't
// able to be properly compared using reflect.DeepEqual.
err = assertRouteEqual(&s.Route, &newWireInfo.Route)
if err != nil {
t.Fatalf("Routes do not match after "+
"serialization/deserialization: %v", err)
}
require.NoError(t, err)
// Clear routes to allow DeepEqual to compare the remaining fields.
newWireInfo.Route = route.Route{}
@ -163,12 +168,7 @@ func TestSentPaymentSerialization(t *testing.T) {
// DeepEqual, and assert that our key equals the original key.
require.Equal(t, s.cachedSessionKey, newWireInfo.SessionKey())
if !reflect.DeepEqual(s, newWireInfo) {
t.Fatalf("Payments do not match after "+
"serialization/deserialization %v vs %v",
spew.Sdump(s), spew.Sdump(newWireInfo),
)
}
require.Equal(t, s, newWireInfo)
}
// assertRouteEquals compares to routes for equality and returns an error if

View File

@ -1022,10 +1022,11 @@ func (r *ChannelRouter) PreparePayment(payment *LightningPayment) (
//
// TODO(roasbeef): store records as part of creation info?
info := &channeldb.PaymentCreationInfo{
PaymentIdentifier: payment.Identifier(),
Value: payment.Amount,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: payment.PaymentRequest,
PaymentIdentifier: payment.Identifier(),
Value: payment.Amount,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: payment.PaymentRequest,
FirstHopCustomRecords: payment.FirstHopCustomRecords,
}
// Create a new ShardTracker that we'll use during the life cycle of
@ -1120,10 +1121,11 @@ func (r *ChannelRouter) sendToRoute(htlcHash lntypes.Hash, rt *route.Route,
// Record this payment hash with the ControlTower, ensuring it is not
// already in-flight.
info := &channeldb.PaymentCreationInfo{
PaymentIdentifier: paymentIdentifier,
Value: amt,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: nil,
PaymentIdentifier: paymentIdentifier,
Value: amt,
CreationTime: r.cfg.Clock.Now(),
PaymentRequest: nil,
FirstHopCustomRecords: firstHopCustomRecords,
}
err := r.cfg.Control.InitPayment(paymentIdentifier, info)
@ -1483,7 +1485,7 @@ func (r *ChannelRouter) resumePayments() error {
noTimeout := time.Duration(0)
_, _, err := r.sendPayment(
context.Background(), 0, payHash, noTimeout, paySession,
shardTracker, nil,
shardTracker, payment.Info.FirstHopCustomRecords,
)
if err != nil {
log.Errorf("Resuming payment %v failed: %v", payHash,