mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-10-05 18:13:31 +02:00
channeldb+invoices: refactor invoice logic when updating
This commit is contained in:
@@ -494,6 +494,31 @@ type InvoiceStateAMP struct {
|
|||||||
AmtPaid lnwire.MilliSatoshi
|
AmtPaid lnwire.MilliSatoshi
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copy makes a deep copy of the underlying InvoiceStateAMP.
|
||||||
|
func (i *InvoiceStateAMP) copy() (InvoiceStateAMP, error) {
|
||||||
|
result := *i
|
||||||
|
|
||||||
|
// Make a copy of the InvoiceKeys map.
|
||||||
|
result.InvoiceKeys = make(map[CircuitKey]struct{})
|
||||||
|
for k := range i.InvoiceKeys {
|
||||||
|
result.InvoiceKeys[k] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// As a safety measure, copy SettleDate. time.Time is concurrency safe
|
||||||
|
// except when using any of the (un)marshalling methods.
|
||||||
|
settleDateBytes, err := i.SettleDate.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
return InvoiceStateAMP{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = result.SettleDate.UnmarshalBinary(settleDateBytes)
|
||||||
|
if err != nil {
|
||||||
|
return InvoiceStateAMP{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
// AMPInvoiceState represents a type that stores metadata related to the set of
|
// AMPInvoiceState represents a type that stores metadata related to the set of
|
||||||
// settled AMP "sub-invoices".
|
// settled AMP "sub-invoices".
|
||||||
type AMPInvoiceState map[SetID]InvoiceStateAMP
|
type AMPInvoiceState map[SetID]InvoiceStateAMP
|
||||||
@@ -2418,7 +2443,7 @@ func copySlice(src []byte) []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// copyInvoice makes a deep copy of the supplied invoice.
|
// copyInvoice makes a deep copy of the supplied invoice.
|
||||||
func copyInvoice(src *Invoice) *Invoice {
|
func copyInvoice(src *Invoice) (*Invoice, error) {
|
||||||
dest := Invoice{
|
dest := Invoice{
|
||||||
Memo: copySlice(src.Memo),
|
Memo: copySlice(src.Memo),
|
||||||
PaymentRequest: copySlice(src.PaymentRequest),
|
PaymentRequest: copySlice(src.PaymentRequest),
|
||||||
@@ -2432,6 +2457,7 @@ func copyInvoice(src *Invoice) *Invoice {
|
|||||||
Htlcs: make(
|
Htlcs: make(
|
||||||
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
|
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
|
||||||
),
|
),
|
||||||
|
AMPState: make(map[SetID]InvoiceStateAMP),
|
||||||
HodlInvoice: src.HodlInvoice,
|
HodlInvoice: src.HodlInvoice,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2446,7 +2472,17 @@ func copyInvoice(src *Invoice) *Invoice {
|
|||||||
dest.Htlcs[k] = v.Copy()
|
dest.Htlcs[k] = v.Copy()
|
||||||
}
|
}
|
||||||
|
|
||||||
return &dest
|
// Lastly, copy the amp invoice state.
|
||||||
|
for k, v := range src.AMPState {
|
||||||
|
ampInvState, err := v.copy()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
dest.AMPState[k] = ampInvState
|
||||||
|
}
|
||||||
|
|
||||||
|
return &dest, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// invoiceSetIDKeyLen is the length of the key that's used to store the
|
// invoiceSetIDKeyLen is the length of the key that's used to store the
|
||||||
@@ -2628,7 +2664,10 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices,
|
|||||||
|
|
||||||
// Create deep copy to prevent any accidental modification in the
|
// Create deep copy to prevent any accidental modification in the
|
||||||
// callback.
|
// callback.
|
||||||
invoiceCopy := copyInvoice(&invoice)
|
invoiceCopy, err := copyInvoice(&invoice)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// Call the callback and obtain the update descriptor.
|
// Call the callback and obtain the update descriptor.
|
||||||
update, err := callback(invoiceCopy)
|
update, err := callback(invoiceCopy)
|
||||||
|
@@ -697,8 +697,9 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef,
|
|||||||
// Try to mark the specified htlc as canceled in the invoice database.
|
// Try to mark the specified htlc as canceled in the invoice database.
|
||||||
// Intercept the update descriptor to set the local updated variable. If
|
// Intercept the update descriptor to set the local updated variable. If
|
||||||
// no invoice update is performed, we can return early.
|
// no invoice update is performed, we can return early.
|
||||||
|
setID := (*channeldb.SetID)(invoiceRef.SetID())
|
||||||
var updated bool
|
var updated bool
|
||||||
invoice, err := i.cdb.UpdateInvoice(invoiceRef, nil,
|
invoice, err := i.cdb.UpdateInvoice(invoiceRef, setID,
|
||||||
func(invoice *channeldb.Invoice) (
|
func(invoice *channeldb.Invoice) (
|
||||||
*channeldb.InvoiceUpdateDesc, error) {
|
*channeldb.InvoiceUpdateDesc, error) {
|
||||||
|
|
||||||
@@ -958,8 +959,15 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
|
|||||||
// main event loop.
|
// main event loop.
|
||||||
case *htlcAcceptResolution:
|
case *htlcAcceptResolution:
|
||||||
if r.autoRelease {
|
if r.autoRelease {
|
||||||
|
var invRef channeldb.InvoiceRef
|
||||||
|
if ctx.amp != nil {
|
||||||
|
invRef = channeldb.InvoiceRefBySetID(*ctx.setID())
|
||||||
|
} else {
|
||||||
|
invRef = ctx.invoiceRef()
|
||||||
|
}
|
||||||
|
|
||||||
err := i.startHtlcTimer(
|
err := i.startHtlcTimer(
|
||||||
ctx.invoiceRef(), circuitKey, r.acceptTime,
|
invRef, circuitKey, r.acceptTime,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1015,6 +1023,14 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
|
|||||||
return updateDesc, nil
|
return updateDesc, nil
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _, ok := err.(channeldb.ErrDuplicateSetID); ok {
|
||||||
|
return NewFailResolution(
|
||||||
|
ctx.circuitKey, ctx.currentHeight,
|
||||||
|
ResultInvoiceNotFound,
|
||||||
|
), nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
switch err {
|
switch err {
|
||||||
case channeldb.ErrInvoiceNotFound:
|
case channeldb.ErrInvoiceNotFound:
|
||||||
// If the invoice was not found, return a failure resolution
|
// If the invoice was not found, return a failure resolution
|
||||||
@@ -1024,6 +1040,12 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
|
|||||||
ResultInvoiceNotFound,
|
ResultInvoiceNotFound,
|
||||||
), nil, nil
|
), nil, nil
|
||||||
|
|
||||||
|
case channeldb.ErrInvRefEquivocation:
|
||||||
|
return NewFailResolution(
|
||||||
|
ctx.circuitKey, ctx.currentHeight,
|
||||||
|
ResultInvoiceNotFound,
|
||||||
|
), nil, nil
|
||||||
|
|
||||||
case nil:
|
case nil:
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@@ -227,6 +227,10 @@ func updateMpp(ctx *invoiceUpdateCtx,
|
|||||||
return nil, ctx.failRes(ResultExpiryTooSoon), nil
|
return nil, ctx.failRes(ResultExpiryTooSoon), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if setID != nil && *setID == channeldb.BlankPayAddr {
|
||||||
|
return nil, ctx.failRes(ResultAmpError), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Record HTLC in the invoice database.
|
// Record HTLC in the invoice database.
|
||||||
newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{
|
newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{
|
||||||
ctx.circuitKey: acceptDesc,
|
ctx.circuitKey: acceptDesc,
|
||||||
|
Reference in New Issue
Block a user