channeldb: add new AMPInvoiceState field to store AMP sub-invoice metadata

In this commit, we add a new type `AMPInvoiceState` that's used to store
AMP sub-invoice meta data alongside the main invoice. This will be used
to allow changes to be made to an AMP invoices without reading out all
the HTLCs. In addition, callers can use this metadata to look up
information about the current sub-invoice state of AMP HTLCs.
This commit is contained in:
Olaoluwa Osuntokun 2021-10-14 18:11:11 +02:00
parent 8299d632e8
commit 65cca8dd1c
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
2 changed files with 431 additions and 18 deletions

View File

@ -1,6 +1,7 @@
package channeldb package channeldb
import ( import (
"bytes"
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"math" "math"
@ -12,6 +13,7 @@ import (
"github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lntypes"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@ -49,7 +51,8 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) {
Value: value, Value: value,
Features: emptyFeatures, Features: emptyFeatures,
}, },
Htlcs: map[CircuitKey]*InvoiceHTLC{}, Htlcs: map[CircuitKey]*InvoiceHTLC{},
AMPState: map[SetID]InvoiceStateAMP{},
} }
i.Memo = []byte("memo") i.Memo = []byte("memo")
@ -2474,3 +2477,89 @@ func TestAddInvoiceInvalidFeatureDeps(t *testing.T) {
lnwire.PaymentAddrOptional, lnwire.PaymentAddrOptional,
)) ))
} }
// TestEncodeDecodeAmpInvoiceState asserts that the nested TLV
// encoding+decoding for the AMPInvoiceState struct works as expected.
func TestEncodeDecodeAmpInvoiceState(t *testing.T) {
t.Parallel()
setID1 := [32]byte{1}
setID2 := [32]byte{2}
setID3 := [32]byte{3}
circuitKey1 := CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 1,
}
circuitKey2 := CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(2), HtlcID: 2,
}
circuitKey3 := CircuitKey{
ChanID: lnwire.NewShortChanIDFromInt(2), HtlcID: 3,
}
// Make a sample invoice state map that we'll encode then decode to
// assert equality of.
ampState := AMPInvoiceState{
setID1: InvoiceStateAMP{
State: HtlcStateSettled,
SettleDate: testNow,
SettleIndex: 1,
InvoiceKeys: map[CircuitKey]struct{}{
circuitKey1: struct{}{},
circuitKey2: struct{}{},
},
},
setID2: InvoiceStateAMP{
State: HtlcStateCanceled,
SettleDate: testNow,
SettleIndex: 2,
InvoiceKeys: map[CircuitKey]struct{}{
circuitKey1: struct{}{},
},
},
setID3: InvoiceStateAMP{
State: HtlcStateAccepted,
SettleDate: testNow,
SettleIndex: 3,
InvoiceKeys: map[CircuitKey]struct{}{
circuitKey1: struct{}{},
circuitKey2: struct{}{},
circuitKey3: struct{}{},
},
},
}
// We'll now make a sample invoice stream, and use that to encode the
// amp state we created above.
tlvStream, err := tlv.NewStream(
tlv.MakeDynamicRecord(
invoiceAmpStateType, &ampState, ampState.recordSize,
ampStateEncoder, ampStateDecoder,
),
)
require.Nil(t, err)
// Next encode the stream into a set of raw bytes.
var b bytes.Buffer
err = tlvStream.Encode(&b)
require.Nil(t, err)
// Now create a new blank ampState map, which we'll use to decode the
// bytes into.
ampState2 := make(AMPInvoiceState)
// Decode from the raw stream into this blank mpa.
tlvStream, err = tlv.NewStream(
tlv.MakeDynamicRecord(
invoiceAmpStateType, &ampState2, nil,
ampStateEncoder, ampStateDecoder,
),
)
require.Nil(t, err)
err = tlvStream.Decode(&b)
require.Nil(t, err)
// The two states should match.
require.Equal(t, ampState, ampState2)
}

View File

@ -200,21 +200,31 @@ const (
// prevents against the database being rolled back to an older // prevents against the database being rolled back to an older
// format where the surrounding logic might assume a different set of // format where the surrounding logic might assume a different set of
// fields are known. // fields are known.
memoType tlv.Type = 0 memoType tlv.Type = 0
payReqType tlv.Type = 1 payReqType tlv.Type = 1
createTimeType tlv.Type = 2 createTimeType tlv.Type = 2
settleTimeType tlv.Type = 3 settleTimeType tlv.Type = 3
addIndexType tlv.Type = 4 addIndexType tlv.Type = 4
settleIndexType tlv.Type = 5 settleIndexType tlv.Type = 5
preimageType tlv.Type = 6 preimageType tlv.Type = 6
valueType tlv.Type = 7 valueType tlv.Type = 7
cltvDeltaType tlv.Type = 8 cltvDeltaType tlv.Type = 8
expiryType tlv.Type = 9 expiryType tlv.Type = 9
paymentAddrType tlv.Type = 10 paymentAddrType tlv.Type = 10
featuresType tlv.Type = 11 featuresType tlv.Type = 11
invStateType tlv.Type = 12 invStateType tlv.Type = 12
amtPaidType tlv.Type = 13 amtPaidType tlv.Type = 13
hodlInvoiceType tlv.Type = 14 hodlInvoiceType tlv.Type = 14
invoiceAmpStateType tlv.Type = 15
// A set of tlv type definitions used to serialize the invoice AMP
// state along-side the main invoice body.
ampStateSetIDType tlv.Type = 0
ampStateHtlcStateType tlv.Type = 1
ampStateSettleIndexType tlv.Type = 2
ampStateSettleDateType tlv.Type = 3
ampStateCircuitKeysType tlv.Type = 4
ampStateAmtPaidType tlv.Type = 5
) )
// InvoiceRef is a composite identifier for invoices. Invoices can be referenced // InvoiceRef is a composite identifier for invoices. Invoices can be referenced
@ -401,6 +411,63 @@ func (c ContractTerm) String() string {
c.Expiry, c.FinalCltvDelta) c.Expiry, c.FinalCltvDelta)
} }
// SetID is the extra unique tuple item for AMP invoices. In addition to
// setting a payment address, each repeated payment to an AMP invoice will also
// contain a set ID as well.
type SetID [32]byte
// InvoiceStateAMP is a struct that associates the current state of an AMP
// invoice identified by its set ID along with the set of invoices identified
// by the circuit key. This allows callers to easily look up the latest state
// of an AMP "sub-invoice" and also look up the invoice HLTCs themselves in the
// greater HTLC map index.
type InvoiceStateAMP struct {
// State is the state of this sub-AMP invoice.
State HtlcState
// SettleIndex indicates the location in the settle index that
// references this instance of InvoiceStateAMP, but only if
// this value is set (non-zero), and State is HtlcStateSettled.
SettleIndex uint64
// SettleDate is the date that the setID was settled.
SettleDate time.Time
// InvoiceKeys is the set of circuit keys that can be used to locate
// the invoices for a given set ID.
InvoiceKeys map[CircuitKey]struct{}
// AmtPaid is the total amount that was paid in the AMP sub-invoice.
// Fetching the full HTLC/invoice state allows one to extract the
// custom records as well as the break down of the payment splits used
// when paying.
AmtPaid lnwire.MilliSatoshi
}
// AMPInvoiceState represents a type that stores metadata related to the set of
// settled AMP "sub-invoices".
type AMPInvoiceState map[SetID]InvoiceStateAMP
// recordSize returns the amount of bytes this TLV record will occupy when
// encoded.
func (a *AMPInvoiceState) recordSize() uint64 {
var (
b bytes.Buffer
buf [8]byte
)
// We know that encoding works since the tests pass in the build this file
// is checked into, so we'll simplify things and simply encode it ourselves
// then report the total amount of bytes used.
if err := ampStateEncoder(&b, a, &buf); err != nil {
// This should never error out, but we log it just in case it
// does.
log.Errorf("encoding the amp invoice state failed: %v", err)
}
return uint64(len(b.Bytes()))
}
// Invoice is a payment invoice generated by a payee in order to request // Invoice is a payment invoice generated by a payee in order to request
// payment for some good or service. The inclusion of invoices within Lightning // payment for some good or service. The inclusion of invoices within Lightning
// creates a payment work flow for merchants very similar to that of the // creates a payment work flow for merchants very similar to that of the
@ -453,7 +520,9 @@ type Invoice struct {
// NOTE: This index starts at 1. // NOTE: This index starts at 1.
SettleIndex uint64 SettleIndex uint64
// State describes the state the invoice is in. // State describes the state the invoice is in. This is the global
// state of the invoice which may remain open even when a series of
// sub-invoices for this invoice has been settled.
State ContractState State ContractState
// AmtPaid is the final amount that we ultimately accepted for pay for // AmtPaid is the final amount that we ultimately accepted for pay for
@ -466,6 +535,13 @@ type Invoice struct {
// htlcs may have been marked as canceled. // htlcs may have been marked as canceled.
Htlcs map[CircuitKey]*InvoiceHTLC Htlcs map[CircuitKey]*InvoiceHTLC
// AMPState describes the state of any related sub-invoices AMP to this
// greater invoice. A sub-invoice is defined by a set of HTLCs with the
// same set ID that attempt to make one time or recurring payments to
// this greater invoice. It's possible for a sub-invoice to be canceled
// or settled, but the greater invoice still open.
AMPState AMPInvoiceState
// HodlInvoice indicates whether the invoice should be held in the // HodlInvoice indicates whether the invoice should be held in the
// Accepted state or be settled right away. // Accepted state or be settled right away.
HodlInvoice bool HodlInvoice bool
@ -635,7 +711,7 @@ type InvoiceHtlcAMPData struct {
// reconstruction of the shares in the AMP payload. // reconstruction of the shares in the AMP payload.
// //
// NOTE: Preimage will only be present once the HTLC is in // NOTE: Preimage will only be present once the HTLC is in
// HltcStateSetteled. // HtlcStateSettled.
Preimage *lntypes.Preimage Preimage *lntypes.Preimage
} }
@ -1472,6 +1548,13 @@ func serializeInvoice(w io.Writer, i *Invoice) error {
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice), tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
// Invoice AMP state.
tlv.MakeDynamicRecord(
invoiceAmpStateType, &i.AMPState,
i.AMPState.recordSize,
ampStateEncoder, ampStateDecoder,
),
) )
if err != nil { if err != nil {
return err return err
@ -1623,6 +1706,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
) )
var i Invoice var i Invoice
i.AMPState = make(AMPInvoiceState)
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
// Memo and payreq. // Memo and payreq.
tlv.MakePrimitiveRecord(memoType, &i.Memo), tlv.MakePrimitiveRecord(memoType, &i.Memo),
@ -1647,6 +1731,12 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(amtPaidType, &amtPaid),
tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice), tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice),
// Invoice AMP state.
tlv.MakeDynamicRecord(
invoiceAmpStateType, &i.AMPState, nil,
ampStateEncoder, ampStateDecoder,
),
) )
if err != nil { if err != nil {
return i, err return i, err
@ -1704,6 +1794,240 @@ func deserializeInvoice(r io.Reader) (Invoice, error) {
return i, err return i, err
} }
func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*map[CircuitKey]struct{}); ok {
// We encode the set of circuit keys as a varint length prefix.
// followed by a series of fixed sized uint8 integers.
numKeys := uint64(len(*v))
if err := tlv.WriteVarInt(w, numKeys, buf); err != nil {
return err
}
for key := range *v {
scidInt := key.ChanID.ToUint64()
if err := tlv.EUint64(w, &scidInt, buf); err != nil {
return err
}
if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}")
}
func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*map[CircuitKey]struct{}); ok {
// First, we'll read out the varint that encodes the number of
// circuit keys encoded.
numKeys, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many keys to expect, iterate reading each
// one until we're done.
for i := uint64(0); i < numKeys; i++ {
var (
key CircuitKey
scid uint64
)
if err := tlv.DUint64(r, &scid, buf, 8); err != nil {
return err
}
key.ChanID = lnwire.NewShortChanIDFromInt(scid)
if err := tlv.DUint64(r, &key.HtlcID, buf, 8); err != nil {
return err
}
(*v)[key] = struct{}{}
}
return nil
}
return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l)
}
// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record.
func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error {
if v, ok := val.(*AMPInvoiceState); ok {
// We'll encode the AMP state as a series of KV pairs on the
// wire with a length prefix.
numRecords := uint64(len(*v))
// First, we'll write out the number of records as a var int.
if err := tlv.WriteVarInt(w, numRecords, buf); err != nil {
return err
}
// With that written out, we'll now encode the entries
// themselves as a sub-TLV record, which includes its _own_
// inner length prefix.
for setID, ampState := range *v {
setID := [32]byte(setID)
ampState := ampState
htlcState := uint8(ampState.State)
settleDateBytes, err := ampState.SettleDate.MarshalBinary()
if err != nil {
return err
}
amtPaid := uint64(ampState.AmtPaid)
var ampStateTlvBytes bytes.Buffer
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(
ampStateSetIDType, &setID,
),
tlv.MakePrimitiveRecord(
ampStateHtlcStateType, &htlcState,
),
tlv.MakePrimitiveRecord(
ampStateSettleIndexType, &ampState.SettleIndex,
),
tlv.MakePrimitiveRecord(
ampStateSettleDateType, &settleDateBytes,
),
tlv.MakeDynamicRecord(
ampStateCircuitKeysType,
&ampState.InvoiceKeys,
func() uint64 {
// The record takes 8 bytes to encode the
// set of circuits, 8 bytes for the scid
// for the key, and 8 bytes for the HTLC
// index.
numKeys := uint64(len(ampState.InvoiceKeys))
return tlv.VarIntSize(numKeys) + (numKeys * 16)
},
encodeCircuitKeys, decodeCircuitKeys,
),
tlv.MakePrimitiveRecord(
ampStateAmtPaidType, &amtPaid,
),
)
if err != nil {
return err
}
if err := tlvStream.Encode(&ampStateTlvBytes); err != nil {
return err
}
// We encode the record with a varint length followed by
// the _raw_ TLV bytes.
tlvLen := uint64(len(ampStateTlvBytes.Bytes()))
if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil {
return err
}
if _, err := w.Write(ampStateTlvBytes.Bytes()); err != nil {
return err
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
}
// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record.
func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error {
if v, ok := val.(*AMPInvoiceState); ok {
// First, we'll decode the varint that encodes how many set IDs
// are encoded within the greater map.
numRecords, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Now that we know how many records we'll need to read, we can
// iterate and read them all out in series.
for i := uint64(0); i < numRecords; i++ {
// Read out the varint that encodes the size of this inner
// TLV record
stateRecordSize, err := tlv.ReadVarInt(r, buf)
if err != nil {
return err
}
// Using this information, we'll create a new limited
// reader that'll return an EOF once the end has been
// reached so the stream stops consuming bytes.
innerTlvReader := io.LimitedReader{
R: r,
N: int64(stateRecordSize),
}
var (
setID [32]byte
htlcState uint8
settleIndex uint64
settleDateBytes []byte
invoiceKeys = make(map[CircuitKey]struct{})
amtPaid uint64
)
tlvStream, err := tlv.NewStream(
tlv.MakePrimitiveRecord(
ampStateSetIDType, &setID,
),
tlv.MakePrimitiveRecord(
ampStateHtlcStateType, &htlcState,
),
tlv.MakePrimitiveRecord(
ampStateSettleIndexType, &settleIndex,
),
tlv.MakePrimitiveRecord(
ampStateSettleDateType, &settleDateBytes,
),
tlv.MakeDynamicRecord(
ampStateCircuitKeysType,
&invoiceKeys, nil,
encodeCircuitKeys, decodeCircuitKeys,
),
tlv.MakePrimitiveRecord(
ampStateAmtPaidType, &amtPaid,
),
)
if err != nil {
return err
}
if err := tlvStream.Decode(&innerTlvReader); err != nil {
return err
}
var settleDate time.Time
err = settleDate.UnmarshalBinary(settleDateBytes)
if err != nil {
return err
}
(*v)[setID] = InvoiceStateAMP{
State: HtlcState(htlcState),
SettleIndex: settleIndex,
SettleDate: settleDate,
InvoiceKeys: invoiceKeys,
AmtPaid: lnwire.MilliSatoshi(amtPaid),
}
}
return nil
}
return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState")
}
// deserializeHtlcs reads a list of invoice htlcs from a reader and returns it // deserializeHtlcs reads a list of invoice htlcs from a reader and returns it
// as a map. // as a map.
func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) { func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {