diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index fc61736a3..2a0d01fb9 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -1,6 +1,7 @@ package channeldb import ( + "bytes" "crypto/rand" "fmt" "math" @@ -12,6 +13,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/record" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -49,7 +51,8 @@ func randInvoice(value lnwire.MilliSatoshi) (*Invoice, error) { Value: value, Features: emptyFeatures, }, - Htlcs: map[CircuitKey]*InvoiceHTLC{}, + Htlcs: map[CircuitKey]*InvoiceHTLC{}, + AMPState: map[SetID]InvoiceStateAMP{}, } i.Memo = []byte("memo") @@ -2474,3 +2477,89 @@ func TestAddInvoiceInvalidFeatureDeps(t *testing.T) { lnwire.PaymentAddrOptional, )) } + +// TestEncodeDecodeAmpInvoiceState asserts that the nested TLV +// encoding+decoding for the AMPInvoiceState struct works as expected. +func TestEncodeDecodeAmpInvoiceState(t *testing.T) { + t.Parallel() + + setID1 := [32]byte{1} + setID2 := [32]byte{2} + setID3 := [32]byte{3} + + circuitKey1 := CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt(1), HtlcID: 1, + } + circuitKey2 := CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt(2), HtlcID: 2, + } + circuitKey3 := CircuitKey{ + ChanID: lnwire.NewShortChanIDFromInt(2), HtlcID: 3, + } + + // Make a sample invoice state map that we'll encode then decode to + // assert equality of. + ampState := AMPInvoiceState{ + setID1: InvoiceStateAMP{ + State: HtlcStateSettled, + SettleDate: testNow, + SettleIndex: 1, + InvoiceKeys: map[CircuitKey]struct{}{ + circuitKey1: struct{}{}, + circuitKey2: struct{}{}, + }, + }, + setID2: InvoiceStateAMP{ + State: HtlcStateCanceled, + SettleDate: testNow, + SettleIndex: 2, + InvoiceKeys: map[CircuitKey]struct{}{ + circuitKey1: struct{}{}, + }, + }, + setID3: InvoiceStateAMP{ + State: HtlcStateAccepted, + SettleDate: testNow, + SettleIndex: 3, + InvoiceKeys: map[CircuitKey]struct{}{ + circuitKey1: struct{}{}, + circuitKey2: struct{}{}, + circuitKey3: struct{}{}, + }, + }, + } + + // We'll now make a sample invoice stream, and use that to encode the + // amp state we created above. + tlvStream, err := tlv.NewStream( + tlv.MakeDynamicRecord( + invoiceAmpStateType, &State, ampState.recordSize, + ampStateEncoder, ampStateDecoder, + ), + ) + require.Nil(t, err) + + // Next encode the stream into a set of raw bytes. + var b bytes.Buffer + err = tlvStream.Encode(&b) + require.Nil(t, err) + + // Now create a new blank ampState map, which we'll use to decode the + // bytes into. + ampState2 := make(AMPInvoiceState) + + // Decode from the raw stream into this blank mpa. + tlvStream, err = tlv.NewStream( + tlv.MakeDynamicRecord( + invoiceAmpStateType, &State2, nil, + ampStateEncoder, ampStateDecoder, + ), + ) + require.Nil(t, err) + + err = tlvStream.Decode(&b) + require.Nil(t, err) + + // The two states should match. + require.Equal(t, ampState, ampState2) +} diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 951b30092..2b26d5800 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -200,21 +200,31 @@ const ( // prevents against the database being rolled back to an older // format where the surrounding logic might assume a different set of // fields are known. - memoType tlv.Type = 0 - payReqType tlv.Type = 1 - createTimeType tlv.Type = 2 - settleTimeType tlv.Type = 3 - addIndexType tlv.Type = 4 - settleIndexType tlv.Type = 5 - preimageType tlv.Type = 6 - valueType tlv.Type = 7 - cltvDeltaType tlv.Type = 8 - expiryType tlv.Type = 9 - paymentAddrType tlv.Type = 10 - featuresType tlv.Type = 11 - invStateType tlv.Type = 12 - amtPaidType tlv.Type = 13 - hodlInvoiceType tlv.Type = 14 + memoType tlv.Type = 0 + payReqType tlv.Type = 1 + createTimeType tlv.Type = 2 + settleTimeType tlv.Type = 3 + addIndexType tlv.Type = 4 + settleIndexType tlv.Type = 5 + preimageType tlv.Type = 6 + valueType tlv.Type = 7 + cltvDeltaType tlv.Type = 8 + expiryType tlv.Type = 9 + paymentAddrType tlv.Type = 10 + featuresType tlv.Type = 11 + invStateType tlv.Type = 12 + amtPaidType tlv.Type = 13 + hodlInvoiceType tlv.Type = 14 + invoiceAmpStateType tlv.Type = 15 + + // A set of tlv type definitions used to serialize the invoice AMP + // state along-side the main invoice body. + ampStateSetIDType tlv.Type = 0 + ampStateHtlcStateType tlv.Type = 1 + ampStateSettleIndexType tlv.Type = 2 + ampStateSettleDateType tlv.Type = 3 + ampStateCircuitKeysType tlv.Type = 4 + ampStateAmtPaidType tlv.Type = 5 ) // InvoiceRef is a composite identifier for invoices. Invoices can be referenced @@ -401,6 +411,63 @@ func (c ContractTerm) String() string { c.Expiry, c.FinalCltvDelta) } +// SetID is the extra unique tuple item for AMP invoices. In addition to +// setting a payment address, each repeated payment to an AMP invoice will also +// contain a set ID as well. +type SetID [32]byte + +// InvoiceStateAMP is a struct that associates the current state of an AMP +// invoice identified by its set ID along with the set of invoices identified +// by the circuit key. This allows callers to easily look up the latest state +// of an AMP "sub-invoice" and also look up the invoice HLTCs themselves in the +// greater HTLC map index. +type InvoiceStateAMP struct { + // State is the state of this sub-AMP invoice. + State HtlcState + + // SettleIndex indicates the location in the settle index that + // references this instance of InvoiceStateAMP, but only if + // this value is set (non-zero), and State is HtlcStateSettled. + SettleIndex uint64 + + // SettleDate is the date that the setID was settled. + SettleDate time.Time + + // InvoiceKeys is the set of circuit keys that can be used to locate + // the invoices for a given set ID. + InvoiceKeys map[CircuitKey]struct{} + + // AmtPaid is the total amount that was paid in the AMP sub-invoice. + // Fetching the full HTLC/invoice state allows one to extract the + // custom records as well as the break down of the payment splits used + // when paying. + AmtPaid lnwire.MilliSatoshi +} + +// AMPInvoiceState represents a type that stores metadata related to the set of +// settled AMP "sub-invoices". +type AMPInvoiceState map[SetID]InvoiceStateAMP + +// recordSize returns the amount of bytes this TLV record will occupy when +// encoded. +func (a *AMPInvoiceState) recordSize() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + + // We know that encoding works since the tests pass in the build this file + // is checked into, so we'll simplify things and simply encode it ourselves + // then report the total amount of bytes used. + if err := ampStateEncoder(&b, a, &buf); err != nil { + // This should never error out, but we log it just in case it + // does. + log.Errorf("encoding the amp invoice state failed: %v", err) + } + + return uint64(len(b.Bytes())) +} + // Invoice is a payment invoice generated by a payee in order to request // payment for some good or service. The inclusion of invoices within Lightning // creates a payment work flow for merchants very similar to that of the @@ -453,7 +520,9 @@ type Invoice struct { // NOTE: This index starts at 1. SettleIndex uint64 - // State describes the state the invoice is in. + // State describes the state the invoice is in. This is the global + // state of the invoice which may remain open even when a series of + // sub-invoices for this invoice has been settled. State ContractState // AmtPaid is the final amount that we ultimately accepted for pay for @@ -466,6 +535,13 @@ type Invoice struct { // htlcs may have been marked as canceled. Htlcs map[CircuitKey]*InvoiceHTLC + // AMPState describes the state of any related sub-invoices AMP to this + // greater invoice. A sub-invoice is defined by a set of HTLCs with the + // same set ID that attempt to make one time or recurring payments to + // this greater invoice. It's possible for a sub-invoice to be canceled + // or settled, but the greater invoice still open. + AMPState AMPInvoiceState + // HodlInvoice indicates whether the invoice should be held in the // Accepted state or be settled right away. HodlInvoice bool @@ -635,7 +711,7 @@ type InvoiceHtlcAMPData struct { // reconstruction of the shares in the AMP payload. // // NOTE: Preimage will only be present once the HTLC is in - // HltcStateSetteled. + // HtlcStateSettled. Preimage *lntypes.Preimage } @@ -1472,6 +1548,13 @@ func serializeInvoice(w io.Writer, i *Invoice) error { tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice), + + // Invoice AMP state. + tlv.MakeDynamicRecord( + invoiceAmpStateType, &i.AMPState, + i.AMPState.recordSize, + ampStateEncoder, ampStateDecoder, + ), ) if err != nil { return err @@ -1623,6 +1706,7 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { ) var i Invoice + i.AMPState = make(AMPInvoiceState) tlvStream, err := tlv.NewStream( // Memo and payreq. tlv.MakePrimitiveRecord(memoType, &i.Memo), @@ -1647,6 +1731,12 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { tlv.MakePrimitiveRecord(amtPaidType, &amtPaid), tlv.MakePrimitiveRecord(hodlInvoiceType, &hodlInvoice), + + // Invoice AMP state. + tlv.MakeDynamicRecord( + invoiceAmpStateType, &i.AMPState, nil, + ampStateEncoder, ampStateDecoder, + ), ) if err != nil { return i, err @@ -1704,6 +1794,240 @@ func deserializeInvoice(r io.Reader) (Invoice, error) { return i, err } +func encodeCircuitKeys(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*map[CircuitKey]struct{}); ok { + // We encode the set of circuit keys as a varint length prefix. + // followed by a series of fixed sized uint8 integers. + numKeys := uint64(len(*v)) + + if err := tlv.WriteVarInt(w, numKeys, buf); err != nil { + return err + } + + for key := range *v { + scidInt := key.ChanID.ToUint64() + + if err := tlv.EUint64(w, &scidInt, buf); err != nil { + return err + } + if err := tlv.EUint64(w, &key.HtlcID, buf); err != nil { + return err + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "*map[CircuitKey]struct{}") +} + +func decodeCircuitKeys(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*map[CircuitKey]struct{}); ok { + // First, we'll read out the varint that encodes the number of + // circuit keys encoded. + numKeys, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Now that we know how many keys to expect, iterate reading each + // one until we're done. + for i := uint64(0); i < numKeys; i++ { + var ( + key CircuitKey + scid uint64 + ) + + if err := tlv.DUint64(r, &scid, buf, 8); err != nil { + return err + } + + key.ChanID = lnwire.NewShortChanIDFromInt(scid) + + if err := tlv.DUint64(r, &key.HtlcID, buf, 8); err != nil { + return err + } + + (*v)[key] = struct{}{} + } + + return nil + } + + return tlv.NewTypeForDecodingErr(val, "*map[CircuitKey]struct{}", l, l) +} + +// ampStateEncoder is a custom TLV encoder for the AMPInvoiceState record. +func ampStateEncoder(w io.Writer, val interface{}, buf *[8]byte) error { + if v, ok := val.(*AMPInvoiceState); ok { + // We'll encode the AMP state as a series of KV pairs on the + // wire with a length prefix. + numRecords := uint64(len(*v)) + + // First, we'll write out the number of records as a var int. + if err := tlv.WriteVarInt(w, numRecords, buf); err != nil { + return err + } + + // With that written out, we'll now encode the entries + // themselves as a sub-TLV record, which includes its _own_ + // inner length prefix. + for setID, ampState := range *v { + setID := [32]byte(setID) + ampState := ampState + + htlcState := uint8(ampState.State) + settleDateBytes, err := ampState.SettleDate.MarshalBinary() + if err != nil { + return err + } + + amtPaid := uint64(ampState.AmtPaid) + + var ampStateTlvBytes bytes.Buffer + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord( + ampStateSetIDType, &setID, + ), + tlv.MakePrimitiveRecord( + ampStateHtlcStateType, &htlcState, + ), + tlv.MakePrimitiveRecord( + ampStateSettleIndexType, &State.SettleIndex, + ), + tlv.MakePrimitiveRecord( + ampStateSettleDateType, &settleDateBytes, + ), + tlv.MakeDynamicRecord( + ampStateCircuitKeysType, + &State.InvoiceKeys, + func() uint64 { + // The record takes 8 bytes to encode the + // set of circuits, 8 bytes for the scid + // for the key, and 8 bytes for the HTLC + // index. + numKeys := uint64(len(ampState.InvoiceKeys)) + return tlv.VarIntSize(numKeys) + (numKeys * 16) + }, + encodeCircuitKeys, decodeCircuitKeys, + ), + tlv.MakePrimitiveRecord( + ampStateAmtPaidType, &amtPaid, + ), + ) + if err != nil { + return err + } + + if err := tlvStream.Encode(&StateTlvBytes); err != nil { + return err + } + + // We encode the record with a varint length followed by + // the _raw_ TLV bytes. + tlvLen := uint64(len(ampStateTlvBytes.Bytes())) + if err := tlv.WriteVarInt(w, tlvLen, buf); err != nil { + return err + } + + if _, err := w.Write(ampStateTlvBytes.Bytes()); err != nil { + return err + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState") +} + +// ampStateDecoder is a custom TLV decoder for the AMPInvoiceState record. +func ampStateDecoder(r io.Reader, val interface{}, buf *[8]byte, l uint64) error { + if v, ok := val.(*AMPInvoiceState); ok { + // First, we'll decode the varint that encodes how many set IDs + // are encoded within the greater map. + numRecords, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Now that we know how many records we'll need to read, we can + // iterate and read them all out in series. + for i := uint64(0); i < numRecords; i++ { + // Read out the varint that encodes the size of this inner + // TLV record + stateRecordSize, err := tlv.ReadVarInt(r, buf) + if err != nil { + return err + } + + // Using this information, we'll create a new limited + // reader that'll return an EOF once the end has been + // reached so the stream stops consuming bytes. + innerTlvReader := io.LimitedReader{ + R: r, + N: int64(stateRecordSize), + } + + var ( + setID [32]byte + htlcState uint8 + settleIndex uint64 + settleDateBytes []byte + invoiceKeys = make(map[CircuitKey]struct{}) + amtPaid uint64 + ) + tlvStream, err := tlv.NewStream( + tlv.MakePrimitiveRecord( + ampStateSetIDType, &setID, + ), + tlv.MakePrimitiveRecord( + ampStateHtlcStateType, &htlcState, + ), + tlv.MakePrimitiveRecord( + ampStateSettleIndexType, &settleIndex, + ), + tlv.MakePrimitiveRecord( + ampStateSettleDateType, &settleDateBytes, + ), + tlv.MakeDynamicRecord( + ampStateCircuitKeysType, + &invoiceKeys, nil, + encodeCircuitKeys, decodeCircuitKeys, + ), + tlv.MakePrimitiveRecord( + ampStateAmtPaidType, &amtPaid, + ), + ) + if err != nil { + return err + } + + if err := tlvStream.Decode(&innerTlvReader); err != nil { + return err + } + + var settleDate time.Time + err = settleDate.UnmarshalBinary(settleDateBytes) + if err != nil { + return err + } + + (*v)[setID] = InvoiceStateAMP{ + State: HtlcState(htlcState), + SettleIndex: settleIndex, + SettleDate: settleDate, + InvoiceKeys: invoiceKeys, + AmtPaid: lnwire.MilliSatoshi(amtPaid), + } + } + + return nil + } + + return tlv.NewTypeForEncodingErr(val, "channeldb.AMPInvoiceState") +} + // deserializeHtlcs reads a list of invoice htlcs from a reader and returns it // as a map. func deserializeHtlcs(r io.Reader) (map[CircuitKey]*InvoiceHTLC, error) {