diff --git a/invoices/invoiceregistry.go b/invoices/invoiceregistry.go index 49f0cb0f0..90ffa4e26 100644 --- a/invoices/invoiceregistry.go +++ b/invoices/invoiceregistry.go @@ -654,7 +654,9 @@ func (i *InvoiceRegistry) startHtlcTimer(invoiceRef InvoiceRef, func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, key CircuitKey, result FailResolutionResult) error { - updateInvoice := func(invoice *Invoice) (*InvoiceUpdateDesc, error) { + updateInvoice := func(invoice *Invoice, setID *SetID) ( + *InvoiceUpdateDesc, error) { + // Only allow individual htlc cancellation on open invoices. if invoice.State != ContractOpen { log.Debugf("cancelSingleHtlc: invoice %v no longer "+ @@ -663,37 +665,16 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, return nil, nil } - // Lookup the current status of the htlc in the database. - var ( - htlcState HtlcState - setID *SetID - ) + // Also for AMP invoices we fetch the relevant HTLCs, so + // the HTLC should be found, otherwise we return an error. htlc, ok := invoice.Htlcs[key] if !ok { - // If this is an AMP invoice, then all the HTLCs won't - // be read out, so we'll consult the other mapping to - // try to find the HTLC state in question here. - var found bool - for ampSetID, htlcSet := range invoice.AMPState { - ampSetID := ampSetID - for htlcKey := range htlcSet.InvoiceKeys { - if htlcKey == key { - htlcState = htlcSet.State - setID = &SetID - - found = true - break - } - } - } - - if !found { - return nil, fmt.Errorf("htlc %v not found", key) - } - } else { - htlcState = htlc.State + return nil, fmt.Errorf("htlc %v not found on "+ + "invoice %v", key, invoiceRef) } + htlcState := htlc.State + // Cancellation is only possible if the htlc wasn't already // resolved. if htlcState != HtlcStateAccepted { @@ -729,7 +710,7 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, func(invoice *Invoice) ( *InvoiceUpdateDesc, error) { - updateDesc, err := updateInvoice(invoice) + updateDesc, err := updateInvoice(invoice, setID) if err != nil { return nil, err } @@ -756,8 +737,12 @@ func (i *InvoiceRegistry) cancelSingleHtlc(invoiceRef InvoiceRef, key, int32(htlc.AcceptHeight), result, ) + log.Debugf("Cancelling htlc (%v) of invoice(%v) with "+ + "resolution: %v", key, invoiceRef, result) + i.notifyHodlSubscribers(resolution) } + return nil }