channeldb+invoices: refactor invoice logic when updating

This commit is contained in:
eugene
2022-04-14 16:07:06 -04:00
parent c2adb03e38
commit 4eea395a7f
3 changed files with 70 additions and 5 deletions

View File

@@ -494,6 +494,31 @@ type InvoiceStateAMP struct {
AmtPaid lnwire.MilliSatoshi AmtPaid lnwire.MilliSatoshi
} }
// copy makes a deep copy of the underlying InvoiceStateAMP.
func (i *InvoiceStateAMP) copy() (InvoiceStateAMP, error) {
result := *i
// Make a copy of the InvoiceKeys map.
result.InvoiceKeys = make(map[CircuitKey]struct{})
for k := range i.InvoiceKeys {
result.InvoiceKeys[k] = struct{}{}
}
// As a safety measure, copy SettleDate. time.Time is concurrency safe
// except when using any of the (un)marshalling methods.
settleDateBytes, err := i.SettleDate.MarshalBinary()
if err != nil {
return InvoiceStateAMP{}, err
}
err = result.SettleDate.UnmarshalBinary(settleDateBytes)
if err != nil {
return InvoiceStateAMP{}, err
}
return result, nil
}
// AMPInvoiceState represents a type that stores metadata related to the set of // AMPInvoiceState represents a type that stores metadata related to the set of
// settled AMP "sub-invoices". // settled AMP "sub-invoices".
type AMPInvoiceState map[SetID]InvoiceStateAMP type AMPInvoiceState map[SetID]InvoiceStateAMP
@@ -2418,7 +2443,7 @@ func copySlice(src []byte) []byte {
} }
// copyInvoice makes a deep copy of the supplied invoice. // copyInvoice makes a deep copy of the supplied invoice.
func copyInvoice(src *Invoice) *Invoice { func copyInvoice(src *Invoice) (*Invoice, error) {
dest := Invoice{ dest := Invoice{
Memo: copySlice(src.Memo), Memo: copySlice(src.Memo),
PaymentRequest: copySlice(src.PaymentRequest), PaymentRequest: copySlice(src.PaymentRequest),
@@ -2432,6 +2457,7 @@ func copyInvoice(src *Invoice) *Invoice {
Htlcs: make( Htlcs: make(
map[CircuitKey]*InvoiceHTLC, len(src.Htlcs), map[CircuitKey]*InvoiceHTLC, len(src.Htlcs),
), ),
AMPState: make(map[SetID]InvoiceStateAMP),
HodlInvoice: src.HodlInvoice, HodlInvoice: src.HodlInvoice,
} }
@@ -2446,7 +2472,17 @@ func copyInvoice(src *Invoice) *Invoice {
dest.Htlcs[k] = v.Copy() dest.Htlcs[k] = v.Copy()
} }
return &dest // Lastly, copy the amp invoice state.
for k, v := range src.AMPState {
ampInvState, err := v.copy()
if err != nil {
return nil, err
}
dest.AMPState[k] = ampInvState
}
return &dest, nil
} }
// invoiceSetIDKeyLen is the length of the key that's used to store the // invoiceSetIDKeyLen is the length of the key that's used to store the
@@ -2628,7 +2664,10 @@ func (d *DB) updateInvoice(hash *lntypes.Hash, refSetID *SetID, invoices,
// Create deep copy to prevent any accidental modification in the // Create deep copy to prevent any accidental modification in the
// callback. // callback.
invoiceCopy := copyInvoice(&invoice) invoiceCopy, err := copyInvoice(&invoice)
if err != nil {
return nil, err
}
// Call the callback and obtain the update descriptor. // Call the callback and obtain the update descriptor.
update, err := callback(invoiceCopy) update, err := callback(invoiceCopy)

View File

@@ -697,8 +697,9 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef channeldb.InvoiceRef,
// Try to mark the specified htlc as canceled in the invoice database. // Try to mark the specified htlc as canceled in the invoice database.
// Intercept the update descriptor to set the local updated variable. If // Intercept the update descriptor to set the local updated variable. If
// no invoice update is performed, we can return early. // no invoice update is performed, we can return early.
setID := (*channeldb.SetID)(invoiceRef.SetID())
var updated bool var updated bool
invoice, err := i.cdb.UpdateInvoice(invoiceRef, nil, invoice, err := i.cdb.UpdateInvoice(invoiceRef, setID,
func(invoice *channeldb.Invoice) ( func(invoice *channeldb.Invoice) (
*channeldb.InvoiceUpdateDesc, error) { *channeldb.InvoiceUpdateDesc, error) {
@@ -958,8 +959,15 @@ func (i *InvoiceRegistry) NotifyExitHopHtlc(rHash lntypes.Hash,
// main event loop. // main event loop.
case *htlcAcceptResolution: case *htlcAcceptResolution:
if r.autoRelease { if r.autoRelease {
var invRef channeldb.InvoiceRef
if ctx.amp != nil {
invRef = channeldb.InvoiceRefBySetID(*ctx.setID())
} else {
invRef = ctx.invoiceRef()
}
err := i.startHtlcTimer( err := i.startHtlcTimer(
ctx.invoiceRef(), circuitKey, r.acceptTime, invRef, circuitKey, r.acceptTime,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@@ -1015,6 +1023,14 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
return updateDesc, nil return updateDesc, nil
}, },
) )
if _, ok := err.(channeldb.ErrDuplicateSetID); ok {
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
}
switch err { switch err {
case channeldb.ErrInvoiceNotFound: case channeldb.ErrInvoiceNotFound:
// If the invoice was not found, return a failure resolution // If the invoice was not found, return a failure resolution
@@ -1024,6 +1040,12 @@ func (i *InvoiceRegistry) notifyExitHopHtlcLocked(
ResultInvoiceNotFound, ResultInvoiceNotFound,
), nil, nil ), nil, nil
case channeldb.ErrInvRefEquivocation:
return NewFailResolution(
ctx.circuitKey, ctx.currentHeight,
ResultInvoiceNotFound,
), nil, nil
case nil: case nil:
default: default:

View File

@@ -227,6 +227,10 @@ func updateMpp(ctx *invoiceUpdateCtx,
return nil, ctx.failRes(ResultExpiryTooSoon), nil return nil, ctx.failRes(ResultExpiryTooSoon), nil
} }
if setID != nil && *setID == channeldb.BlankPayAddr {
return nil, ctx.failRes(ResultAmpError), nil
}
// Record HTLC in the invoice database. // Record HTLC in the invoice database.
newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{ newHtlcs := map[channeldb.CircuitKey]*channeldb.HtlcAcceptDesc{
ctx.circuitKey: acceptDesc, ctx.circuitKey: acceptDesc,