diff --git a/channeldb/invoice_test.go b/channeldb/invoice_test.go index d7bbc5b5c..d6d1eaf7a 100644 --- a/channeldb/invoice_test.go +++ b/channeldb/invoice_test.go @@ -2997,7 +2997,14 @@ func TestUpdateHTLC(t *testing.T) { func testUpdateHTLC(t *testing.T, test updateHTLCTest) { htlc := test.input.Copy() - _, err := updateHtlc(testNow, htlc, test.invState, test.setID) + stateChanged, state, err := getUpdatedHtlcState( + htlc, test.invState, test.setID, + ) + if stateChanged { + htlc.State = state + htlc.ResolveTime = testNow + } + require.Equal(t, test.expErr, err) require.Equal(t, test.output, *htlc) } diff --git a/channeldb/invoices.go b/channeldb/invoices.go index 24ca9c805..74b4681d5 100644 --- a/channeldb/invoices.go +++ b/channeldb/invoices.go @@ -2148,13 +2148,22 @@ func (d *DB) addHTLCs(invoices, settleIndex, //nolint:funlen if settleEligibleAMP { htlcContextState = invpkg.ContractSettled } - htlcSettled, err := updateHtlc( - timestamp, htlc, htlcContextState, setID, + + htlcStateChanged, htlcState, err := getUpdatedHtlcState( + htlc, htlcContextState, setID, ) if err != nil { return nil, err } + if htlcStateChanged { + htlc.State = htlcState + htlc.ResolveTime = timestamp + } + + htlcSettled := htlcStateChanged && + htlcState == invpkg.HtlcStateSettled + // If the HTLC has being settled for the first time, and this // is an AMP invoice, then we'll need to update some additional // meta data state. @@ -2334,14 +2343,16 @@ func (d *DB) settleHodlInvoice(invoices, settleIndex kvdb.RwBucket, // TODO(positiveblue): this logic can be further simplified. var amtPaid lnwire.MilliSatoshi for _, htlc := range invoice.Htlcs { - _, err := updateHtlc( - timestamp, htlc, invpkg.ContractSettled, nil, + settled, _, err := getUpdatedHtlcState( + htlc, invpkg.ContractSettled, nil, ) if err != nil { return nil, err } - if htlc.State == invpkg.HtlcStateSettled { + if settled { + htlc.State = invpkg.HtlcStateSettled + htlc.ResolveTime = timestamp amtPaid += htlc.Amt } } @@ -2419,12 +2430,17 @@ func (d *DB) cancelInvoice(invoices kvdb.RwBucket, invoiceNum []byte, // TODO(positiveblue): this logic can be simplified. for _, htlc := range invoice.Htlcs { - _, err := updateHtlc( - timestamp, htlc, invpkg.ContractCanceled, setID, + canceled, _, err := getUpdatedHtlcState( + htlc, invpkg.ContractCanceled, setID, ) if err != nil { return nil, err } + + if canceled { + htlc.State = invpkg.HtlcStateCanceled + htlc.ResolveTime = timestamp + } } err = d.cancelInvoiceStoreUpdate(invoices, invoiceNum, invoice) @@ -2602,21 +2618,23 @@ func cancelSingleHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, return nil } -// updateHtlc aligns the state of an htlc with the given invoice state. A -// boolean is returned if the HTLC was settled. -func updateHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, - invState invpkg.ContractState, setID *[32]byte) (bool, error) { +// getUpdatedHtlcState aligns the state of an htlc with the given invoice state. +// A boolean indicating whether the HTLCs state need to be updated, along with +// the new state (or old state if no change is needed) is returned. +func getUpdatedHtlcState(htlc *invpkg.InvoiceHTLC, + invoiceState invpkg.ContractState, setID *[32]byte) ( + bool, invpkg.HtlcState, error) { - trySettle := func(persist bool) (bool, error) { + trySettle := func(persist bool) (bool, invpkg.HtlcState, error) { if htlc.State != invpkg.HtlcStateAccepted { - return false, nil + return false, htlc.State, nil } // Settle the HTLC if it matches the settled set id. If // there're other HTLCs with distinct setIDs, then we'll leave // them, as they may eventually be settled as we permit // multiple settles to a single pay_addr for AMP. - var htlcState invpkg.HtlcState + settled := false if htlc.IsInHTLCSet(setID) { // Non-AMP HTLCs can be settled immediately since we // already know the preimage is valid due to checks at @@ -2634,29 +2652,31 @@ func updateHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, // // Fail if an accepted AMP HTLC has no preimage. case htlc.AMP.Preimage == nil: - return false, invpkg.ErrHTLCPreimageMissing + return false, htlc.State, + invpkg.ErrHTLCPreimageMissing // Fail if the accepted AMP HTLC has an invalid // preimage. case !htlc.AMP.Preimage.Matches(htlc.AMP.Hash): - return false, invpkg.ErrHTLCPreimageMismatch + return false, htlc.State, + invpkg.ErrHTLCPreimageMismatch } - htlcState = invpkg.HtlcStateSettled + settled = true } // Only persist the changes if the invoice is moving to the // settled state, and we're actually updating the state to // settled. - if persist && htlcState == invpkg.HtlcStateSettled { - htlc.State = htlcState - htlc.ResolveTime = resolveTime + newState := htlc.State + if settled { + newState = invpkg.HtlcStateSettled } - return persist && htlcState == invpkg.HtlcStateSettled, nil + return persist && settled, newState, nil } - if invState == invpkg.ContractSettled { + if invoiceState == invpkg.ContractSettled { // Check that we can settle the HTLCs. For legacy and MPP HTLCs // this will be a NOP, but for AMP HTLCs this asserts that we // have a valid hash/preimage pair. Passing true permits the @@ -2667,16 +2687,13 @@ func updateHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, // We should never find a settled HTLC on an invoice that isn't in // ContractSettled. if htlc.State == invpkg.HtlcStateSettled { - return false, invpkg.ErrHTLCAlreadySettled + return false, htlc.State, invpkg.ErrHTLCAlreadySettled } - switch invState { + switch invoiceState { case invpkg.ContractCanceled: - if htlc.State == invpkg.HtlcStateAccepted { - htlc.State = invpkg.HtlcStateCanceled - htlc.ResolveTime = resolveTime - } - return false, nil + htlcAlreadyCanceled := htlc.State == invpkg.HtlcStateCanceled + return !htlcAlreadyCanceled, invpkg.HtlcStateCanceled, nil // TODO(roasbeef): never fully passed thru now? case invpkg.ContractAccepted: @@ -2688,10 +2705,10 @@ func updateHtlc(resolveTime time.Time, htlc *invpkg.InvoiceHTLC, return trySettle(false) case invpkg.ContractOpen: - return false, nil + return false, htlc.State, nil default: - return false, errors.New("unknown state transition") + return false, htlc.State, errors.New("unknown state transition") } }