diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 0f37a46c8..41b807ef1 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -2494,57 +2494,6 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, skipUs := make(map[uint64]struct{}) skipThem := make(map[uint64]struct{}) - // fetchParentEntry is a helper method that will fetch the parent of - // entry from the corresponding update log. - fetchParentEntry := func(entry *PaymentDescriptor, - remoteLog bool) (*PaymentDescriptor, error) { - - var ( - updateLog *updateLog - logName string - ) - - if remoteLog { - updateLog = lc.remoteUpdateLog - logName = "remote" - } else { - updateLog = lc.localUpdateLog - logName = "local" - } - - addEntry := updateLog.lookupHtlc(entry.ParentIndex) - - switch { - // We check if the parent entry is not found at this point. - // This could happen for old versions of lnd, and we return an - // error to gracefully shut down the state machine if such an - // entry is still in the logs. - case addEntry == nil: - return nil, fmt.Errorf("unable to find parent entry "+ - "%d in %v update log: %v\nUpdatelog: %v", - entry.ParentIndex, logName, - newLogClosure(func() string { - return spew.Sdump(entry) - }), newLogClosure(func() string { - return spew.Sdump(updateLog) - }), - ) - - // The parent add height should never be zero at this point. If - // that's the case we probably forgot to send a new commitment. - case remoteChain && addEntry.addCommitHeightRemote == 0: - return nil, fmt.Errorf("parent entry %d for update %d "+ - "had zero remote add height", entry.ParentIndex, - entry.LogIndex) - case !remoteChain && addEntry.addCommitHeightLocal == 0: - return nil, fmt.Errorf("parent entry %d for update %d "+ - "had zero local add height", entry.ParentIndex, - entry.LogIndex) - } - - return addEntry, nil - } - // First we run through non-add entries in both logs, populating the // skip sets and mutating the current chain state (crediting balances, // etc) to reflect the settle/timeout entry encountered. @@ -2571,7 +2520,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, lc.channelState.TotalMSatReceived += entry.Amount } - addEntry, err := fetchParentEntry(entry, true) + addEntry, err := lc.fetchParent(entry, remoteChain, true) if err != nil { return nil, err } @@ -2604,7 +2553,7 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, lc.channelState.TotalMSatSent += entry.Amount } - addEntry, err := fetchParentEntry(entry, false) + addEntry, err := lc.fetchParent(entry, remoteChain, false) if err != nil { return nil, err } @@ -2641,6 +2590,57 @@ func (lc *LightningChannel) evaluateHTLCView(view *htlcView, ourBalance, return newView, nil } +// getFetchParent is a helper that looks up update log parent entries in the +// appropriate log. +func (lc *LightningChannel) fetchParent(entry *PaymentDescriptor, + remoteChain, remoteLog bool) (*PaymentDescriptor, error) { + + var ( + updateLog *updateLog + logName string + ) + + if remoteLog { + updateLog = lc.remoteUpdateLog + logName = "remote" + } else { + updateLog = lc.localUpdateLog + logName = "local" + } + + addEntry := updateLog.lookupHtlc(entry.ParentIndex) + + switch { + // We check if the parent entry is not found at this point. + // This could happen for old versions of lnd, and we return an + // error to gracefully shut down the state machine if such an + // entry is still in the logs. + case addEntry == nil: + return nil, fmt.Errorf("unable to find parent entry "+ + "%d in %v update log: %v\nUpdatelog: %v", + entry.ParentIndex, logName, + newLogClosure(func() string { + return spew.Sdump(entry) + }), newLogClosure(func() string { + return spew.Sdump(updateLog) + }), + ) + + // The parent add height should never be zero at this point. If + // that's the case we probably forgot to send a new commitment. + case remoteChain && addEntry.addCommitHeightRemote == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ + "had zero remote add height", entry.ParentIndex, + entry.LogIndex) + case !remoteChain && addEntry.addCommitHeightLocal == 0: + return nil, fmt.Errorf("parent entry %d for update %d "+ + "had zero local add height", entry.ParentIndex, + entry.LogIndex) + } + + return addEntry, nil +} + // processAddEntry evaluates the effect of an add entry within the HTLC log. // If the HTLC hasn't yet been committed in either chain, then the height it // was committed is updated. Keeping track of this inclusion height allows us to diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 9f10be160..5de3d64e9 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5,7 +5,6 @@ import ( "container/list" "crypto/sha256" "fmt" - "reflect" "runtime" "testing" @@ -7676,3 +7675,1239 @@ func TestChannelFeeRateFloor(t *testing.T) { err) } } + +// TestFetchParent tests lookup of an entry's parent in the appropriate log. +func TestFetchParent(t *testing.T) { + tests := []struct { + name string + remoteChain bool + remoteLog bool + localEntries []*PaymentDescriptor + remoteEntries []*PaymentDescriptor + + // parentIndex is the parent index of the entry that we will + // lookup with fetch parent. + parentIndex uint64 + + // expectErr indicates that we expect fetch parent to fail. + expectErr bool + + // expectedIndex is the htlc index that we expect the parent + // to have. + expectedIndex uint64 + }{ + { + name: "not found in remote log", + localEntries: nil, + remoteEntries: nil, + remoteChain: true, + remoteLog: true, + parentIndex: 0, + expectErr: true, + }, + { + name: "not found in local log", + localEntries: nil, + remoteEntries: nil, + remoteChain: false, + remoteLog: false, + parentIndex: 0, + expectErr: true, + }, + { + name: "remote log + chain, remote add height 0", + localEntries: nil, + remoteEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 100, + addCommitHeightRemote: 0, + }, + }, + remoteChain: true, + remoteLog: true, + parentIndex: 1, + expectErr: true, + }, + { + name: "remote log, local chain, local add height 0", + remoteEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 0, + addCommitHeightRemote: 100, + }, + }, + localEntries: nil, + remoteChain: false, + remoteLog: true, + parentIndex: 1, + expectErr: true, + }, + { + name: "local log + chain, local add height 0", + localEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 0, + addCommitHeightRemote: 100, + }, + }, + remoteEntries: nil, + remoteChain: false, + remoteLog: false, + parentIndex: 1, + expectErr: true, + }, + + { + name: "local log + remote chain, remote add height 0", + localEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 100, + addCommitHeightRemote: 0, + }, + }, + remoteEntries: nil, + remoteChain: true, + remoteLog: false, + parentIndex: 1, + expectErr: true, + }, + { + name: "remote log found", + localEntries: nil, + remoteEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 100, + addCommitHeightRemote: 0, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + }, + remoteChain: true, + remoteLog: true, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, + }, + { + name: "local log found", + localEntries: []*PaymentDescriptor{ + // This entry will be added at log index =0. + { + HtlcIndex: 1, + addCommitHeightLocal: 0, + addCommitHeightRemote: 100, + }, + // This entry will be added at log index =1, it + // is the parent entry we are looking for. + { + HtlcIndex: 2, + addCommitHeightLocal: 100, + addCommitHeightRemote: 100, + }, + }, + remoteEntries: nil, + remoteChain: false, + remoteLog: false, + parentIndex: 1, + expectErr: false, + expectedIndex: 2, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Create a lightning channel with newly initialized + // local and remote logs. + lc := LightningChannel{ + localUpdateLog: newUpdateLog(0, 0), + remoteUpdateLog: newUpdateLog(0, 0), + } + + // Add the local and remote entries to update logs. + for _, entry := range test.localEntries { + lc.localUpdateLog.appendHtlc(entry) + } + for _, entry := range test.remoteEntries { + lc.remoteUpdateLog.appendHtlc(entry) + } + + parent, err := lc.fetchParent( + &PaymentDescriptor{ + ParentIndex: test.parentIndex, + }, + test.remoteChain, + test.remoteLog, + ) + gotErr := err != nil + if test.expectErr != gotErr { + t.Fatalf("expected error: %v, got: %v, "+ + "error:%v", test.expectErr, gotErr, err) + } + + // If our lookup failed, we do not need to check parent + // index. + if err != nil { + return + } + + if parent.HtlcIndex != test.expectedIndex { + t.Fatalf("expected parent index: %v, got: %v", + test.parentIndex, parent.HtlcIndex) + } + }) + + } +} + +// TestEvaluateView tests the creation of a htlc view and the opt in mutation of +// send and receive balances. This test does not check htlc mutation on a htlc +// level. +func TestEvaluateView(t *testing.T) { + const ( + // addHeight is a non-zero height that is used for htlc adds. + addHeight = 200 + + // nextHeight is a constant that we use for the next height in + // all unit tests. + nextHeight = 400 + + // feePerKw is the fee we start all of our unit tests with. + feePerKw = 1 + + // htlcAddAmount is the amount for htlc adds in tests. + htlcAddAmount = 15 + + // ourFeeUpdateAmt is an amount that we update fees to + // expressed in msat. + ourFeeUpdateAmt = 20000 + + // ourFeeUpdatePerSat is the fee rate *in satoshis* that we + // expect if we update to ourFeeUpdateAmt. + ourFeeUpdatePerSat = chainfee.SatPerKWeight(20) + + // theirFeeUpdateAmt iis an amount that they update fees to + // expressed in msat. + theirFeeUpdateAmt = 10000 + + // theirFeeUpdatePerSat is the fee rate *in satoshis* that we + // expect if we update to ourFeeUpdateAmt. + theirFeeUpdatePerSat = chainfee.SatPerKWeight(10) + ) + + tests := []struct { + name string + ourHtlcs []*PaymentDescriptor + theirHtlcs []*PaymentDescriptor + remoteChain bool + mutateState bool + + // ourExpectedHtlcs is the set of our htlcs that we expect in + // the htlc view once it has been evaluated. We just store + // htlc index -> bool for brevity, because we only check the + // presence of the htlc in the returned set. + ourExpectedHtlcs map[uint64]bool + + // theirExpectedHtlcs is the set of their htlcs that we expect + // in the htlc view once it has been evaluated. We just store + // htlc index -> bool for brevity, because we only check the + // presence of the htlc in the returned set. + theirExpectedHtlcs map[uint64]bool + + // expectedFee is the fee we expect to be set after evaluating + // the htlc view. + expectedFee chainfee.SatPerKWeight + + // expectReceived is the amount we expect the channel to have + // tracked as our receive total. + expectReceived lnwire.MilliSatoshi + + // expectSent is the amount we expect the channel to have + // tracked as our send total. + expectSent lnwire.MilliSatoshi + }{ + { + name: "our fee update is applied", + remoteChain: false, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{ + { + Amount: ourFeeUpdateAmt, + EntryType: FeeUpdate, + }, + }, + theirHtlcs: nil, + expectedFee: ourFeeUpdatePerSat, + ourExpectedHtlcs: nil, + theirExpectedHtlcs: nil, + expectReceived: 0, + expectSent: 0, + }, + { + name: "their fee update is applied", + remoteChain: false, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{}, + theirHtlcs: []*PaymentDescriptor{ + { + Amount: theirFeeUpdateAmt, + EntryType: FeeUpdate, + }, + }, + expectedFee: theirFeeUpdatePerSat, + ourExpectedHtlcs: nil, + theirExpectedHtlcs: nil, + expectReceived: 0, + expectSent: 0, + }, + { + // We expect unresolved htlcs to to remain in the view. + name: "htlcs adds without settles", + remoteChain: false, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + }, + theirHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Add, + }, + }, + expectedFee: feePerKw, + ourExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + theirExpectedHtlcs: map[uint64]bool{ + 0: true, + 1: true, + }, + expectReceived: 0, + expectSent: 0, + }, + { + name: "our htlc settled, state mutated", + remoteChain: false, + mutateState: true, + ourHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeightLocal: addHeight, + }, + }, + theirHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Settle, + // Map their htlc settle update to our + // htlc add (0). + ParentIndex: 0, + }, + }, + expectedFee: feePerKw, + ourExpectedHtlcs: nil, + theirExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + expectReceived: 0, + expectSent: htlcAddAmount, + }, + { + name: "our htlc settled, state not mutated", + remoteChain: false, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeightLocal: addHeight, + }, + }, + theirHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Settle, + // Map their htlc settle update to our + // htlc add (0). + ParentIndex: 0, + }, + }, + expectedFee: feePerKw, + ourExpectedHtlcs: nil, + theirExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + expectReceived: 0, + expectSent: 0, + }, + { + name: "their htlc settled, state mutated", + remoteChain: false, + mutateState: true, + ourHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Settle, + // Map our htlc settle update to their + // htlc add (1). + ParentIndex: 1, + }, + }, + theirHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeightLocal: addHeight, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeightLocal: addHeight, + }, + }, + expectedFee: feePerKw, + ourExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + theirExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + expectReceived: htlcAddAmount, + expectSent: 0, + }, + { + name: "their htlc settled, state not mutated", + remoteChain: false, + mutateState: false, + ourHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + }, + { + HtlcIndex: 1, + Amount: htlcAddAmount, + EntryType: Settle, + // Map our htlc settle update to their + // htlc add (0). + ParentIndex: 0, + }, + }, + theirHtlcs: []*PaymentDescriptor{ + { + HtlcIndex: 0, + Amount: htlcAddAmount, + EntryType: Add, + addCommitHeightLocal: addHeight, + }, + }, + expectedFee: feePerKw, + ourExpectedHtlcs: map[uint64]bool{ + 0: true, + }, + theirExpectedHtlcs: nil, + expectReceived: 0, + expectSent: 0, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + lc := LightningChannel{ + channelState: &channeldb.OpenChannel{ + TotalMSatSent: 0, + TotalMSatReceived: 0, + }, + + // Create update logs for local and remote. + localUpdateLog: newUpdateLog(0, 0), + remoteUpdateLog: newUpdateLog(0, 0), + } + + for _, htlc := range test.ourHtlcs { + if htlc.EntryType == Add { + lc.localUpdateLog.appendHtlc(htlc) + } else { + lc.localUpdateLog.appendUpdate(htlc) + } + } + + for _, htlc := range test.theirHtlcs { + if htlc.EntryType == Add { + lc.remoteUpdateLog.appendHtlc(htlc) + } else { + lc.remoteUpdateLog.appendUpdate(htlc) + } + } + + view := &htlcView{ + ourUpdates: test.ourHtlcs, + theirUpdates: test.theirHtlcs, + feePerKw: feePerKw, + } + + var ( + // Create vars to store balance changes. We do + // not check these values in this test because + // balance modification happens on the htlc + // processing level. + ourBalance lnwire.MilliSatoshi + theirBalance lnwire.MilliSatoshi + ) + + // Evaluate the htlc view, mutate as test expects. + result, err := lc.evaluateHTLCView( + view, &ourBalance, &theirBalance, nextHeight, + test.remoteChain, test.mutateState, + ) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result.feePerKw != test.expectedFee { + t.Fatalf("expected fee: %v, got: %v", + test.expectedFee, result.feePerKw) + } + + checkExpectedHtlcs( + t, result.ourUpdates, test.ourExpectedHtlcs, + ) + + checkExpectedHtlcs( + t, result.theirUpdates, test.theirExpectedHtlcs, + ) + + if lc.channelState.TotalMSatSent != test.expectSent { + t.Fatalf("expected sent: %v, got: %v", + test.expectSent, + lc.channelState.TotalMSatSent) + } + + if lc.channelState.TotalMSatReceived != + test.expectReceived { + + t.Fatalf("expected received: %v, got: %v", + test.expectReceived, + lc.channelState.TotalMSatReceived) + } + }) + } +} + +// checkExpectedHtlcs checks that a set of htlcs that we have contains all the +// htlcs we expect. +func checkExpectedHtlcs(t *testing.T, actual []*PaymentDescriptor, + expected map[uint64]bool) { + + if len(expected) != len(actual) { + t.Fatalf("expected: %v htlcs, got: %v", + len(expected), len(actual)) + } + + for _, htlc := range actual { + _, ok := expected[htlc.HtlcIndex] + if !ok { + t.Fatalf("htlc with index: %v not "+ + "expected in set", htlc.HtlcIndex) + } + } +} + +// heights represents the heights on a payment descriptor. +type heights struct { + localAdd uint64 + localRemove uint64 + remoteAdd uint64 + remoteRemove uint64 +} + +// TestProcessFeeUpdate tests the applying of fee updates and mutation of +// local and remote add and remove heights on update messages. +func TestProcessFeeUpdate(t *testing.T) { + const ( + // height is a non-zero height that can be used for htlcs + // heights. + height = 200 + + // nextHeight is a constant that we use for the next height in + // all unit tests. + nextHeight = 400 + + // feePerKw is the fee we start all of our unit tests with. + feePerKw = 1 + + // ourFeeUpdateAmt is an amount that we update fees to expressed + // in msat. + ourFeeUpdateAmt = 20000 + + // ourFeeUpdatePerSat is the fee rate *in satoshis* that we + // expect if we update to ourFeeUpdateAmt. + ourFeeUpdatePerSat = chainfee.SatPerKWeight(20) + ) + + tests := []struct { + name string + startHeights heights + expectedHeights heights + remoteChain bool + mutate bool + expectedFee chainfee.SatPerKWeight + }{ + { + // Looking at local chain, local add is non-zero so + // the update has been applied already; no fee change. + name: "non-zero local height, fee unchanged", + startHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: height, + }, + expectedHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: height, + }, + remoteChain: false, + mutate: false, + expectedFee: feePerKw, + }, + { + // Looking at local chain, local add is zero so the + // update has not been applied yet; we expect a fee + // update. + name: "zero local height, fee changed", + startHeights: heights{ + localAdd: 0, + localRemove: 0, + remoteAdd: height, + remoteRemove: 0, + }, + expectedHeights: heights{ + localAdd: 0, + localRemove: 0, + remoteAdd: height, + remoteRemove: 0, + }, + remoteChain: false, + mutate: false, + expectedFee: ourFeeUpdatePerSat, + }, + { + // Looking at remote chain, the remote add height is + // zero, so the update has not been applied so we expect + // a fee change. + name: "zero remote height, fee changed", + startHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: 0, + }, + expectedHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: 0, + }, + remoteChain: true, + mutate: false, + expectedFee: ourFeeUpdatePerSat, + }, + { + // Looking at remote chain, the remote add height is + // non-zero, so the update has been applied so we expect + // no fee change. + name: "non-zero remote height, no fee change", + startHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: height, + remoteRemove: 0, + }, + expectedHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: height, + remoteRemove: 0, + }, + remoteChain: true, + mutate: false, + expectedFee: feePerKw, + }, + { + // Local add height is non-zero, so the update has + // already been applied; we do not expect fee to + // change or any mutations to be applied. + name: "non-zero local height, mutation not applied", + startHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: height, + }, + expectedHeights: heights{ + localAdd: height, + localRemove: 0, + remoteAdd: 0, + remoteRemove: height, + }, + remoteChain: false, + mutate: true, + expectedFee: feePerKw, + }, + { + // Local add is zero and we are looking at our local + // chain, so the update has not been applied yet. We + // expect the local add and remote heights to be + // mutated. + name: "zero height, fee changed, mutation applied", + startHeights: heights{ + localAdd: 0, + localRemove: 0, + remoteAdd: 0, + remoteRemove: 0, + }, + expectedHeights: heights{ + localAdd: nextHeight, + localRemove: nextHeight, + remoteAdd: 0, + remoteRemove: 0, + }, + remoteChain: false, + mutate: true, + expectedFee: ourFeeUpdatePerSat, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + // Create a fee update with add and remove heights as + // set in the test. + heights := test.startHeights + update := &PaymentDescriptor{ + Amount: ourFeeUpdateAmt, + addCommitHeightRemote: heights.remoteAdd, + addCommitHeightLocal: heights.localAdd, + removeCommitHeightRemote: heights.remoteRemove, + removeCommitHeightLocal: heights.localRemove, + EntryType: FeeUpdate, + } + + view := &htlcView{ + feePerKw: chainfee.SatPerKWeight(feePerKw), + } + processFeeUpdate( + update, nextHeight, test.remoteChain, + test.mutate, view, + ) + + if view.feePerKw != test.expectedFee { + t.Fatalf("expected fee: %v, got: %v", + test.expectedFee, feePerKw) + } + + checkHeights(t, update, test.expectedHeights) + }) + } +} + +func checkHeights(t *testing.T, update *PaymentDescriptor, expected heights) { + updateHeights := heights{ + localAdd: update.addCommitHeightLocal, + localRemove: update.removeCommitHeightLocal, + remoteAdd: update.addCommitHeightRemote, + remoteRemove: update.removeCommitHeightRemote, + } + + if !reflect.DeepEqual(updateHeights, expected) { + t.Fatalf("expected: %v, got: %v", expected, updateHeights) + } +} + +// TestProcessAddRemoveEntry tests the updating of our and their balances when +// we process adds, settles and fails. It also tests the mutating of add and +// remove heights. +func TestProcessAddRemoveEntry(t *testing.T) { + const ( + // addHeight is a non-zero addHeight that is used for htlc + // add heights. + addHeight = 100 + + // removeHeight is a non-zero removeHeight that is used for + // htlc remove heights. + removeHeight = 200 + + // nextHeight is a constant that we use for the nextHeight in + // all unit tests. + nextHeight = 400 + + // updateAmount is the amount that the update is set to. + updateAmount = lnwire.MilliSatoshi(10) + + // startBalance is a balance we start both sides out with + // so that balances can be incremented. + startBalance = lnwire.MilliSatoshi(100) + ) + + tests := []struct { + name string + startHeights heights + remoteChain bool + isIncoming bool + mutateState bool + ourExpectedBalance lnwire.MilliSatoshi + theirExpectedBalance lnwire.MilliSatoshi + expectedHeights heights + updateType updateType + }{ + { + name: "add, remote chain, already processed", + startHeights: heights{ + localAdd: 0, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: 0, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + { + name: "add, local chain, already processed", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: false, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + { + name: "incoming add, local chain, not mutated", + startHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: false, + isIncoming: true, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance - updateAmount, + expectedHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + { + name: "incoming add, local chain, mutated", + startHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: false, + isIncoming: true, + mutateState: true, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance - updateAmount, + expectedHeights: heights{ + localAdd: nextHeight, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + + { + name: "outgoing add, remote chain, not mutated", + startHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance - updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + { + name: "outgoing add, remote chain, mutated", + startHeights: heights{ + localAdd: 0, + remoteAdd: 0, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: false, + mutateState: true, + ourExpectedBalance: startBalance - updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: 0, + remoteAdd: nextHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Add, + }, + { + name: "settle, remote chain, already processed", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: removeHeight, + }, + remoteChain: true, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: removeHeight, + }, + updateType: Settle, + }, + { + name: "settle, local chain, already processed", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: removeHeight, + remoteRemove: 0, + }, + remoteChain: false, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: removeHeight, + remoteRemove: 0, + }, + updateType: Settle, + }, + { + // Remote chain, and not processed yet. Incoming settle, + // so we expect our balance to increase. + name: "incoming settle", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: true, + mutateState: false, + ourExpectedBalance: startBalance + updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Settle, + }, + { + // Remote chain, and not processed yet. Incoming settle, + // so we expect our balance to increase. + name: "outgoing settle", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance + updateAmount, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Settle, + }, + { + // Remote chain, and not processed yet. Incoming fail, + // so we expect their balance to increase. + name: "incoming fail", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: true, + mutateState: false, + ourExpectedBalance: startBalance, + theirExpectedBalance: startBalance + updateAmount, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Fail, + }, + { + // Remote chain, and not processed yet. Outgoing fail, + // so we expect our balance to increase. + name: "outgoing fail", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: false, + mutateState: false, + ourExpectedBalance: startBalance + updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + updateType: Fail, + }, + { + // Local chain, and not processed yet. Incoming settle, + // so we expect our balance to increase. Mutate is + // true, so we expect our remove removeHeight to have + // changed. + name: "fail, our remove height mutated", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: false, + isIncoming: true, + mutateState: true, + ourExpectedBalance: startBalance + updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: nextHeight, + remoteRemove: 0, + }, + updateType: Settle, + }, + { + // Remote chain, and not processed yet. Incoming settle, + // so we expect our balance to increase. Mutate is + // true, so we expect their remove removeHeight to have + // changed. + name: "fail, their remove height mutated", + startHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: 0, + }, + remoteChain: true, + isIncoming: true, + mutateState: true, + ourExpectedBalance: startBalance + updateAmount, + theirExpectedBalance: startBalance, + expectedHeights: heights{ + localAdd: addHeight, + remoteAdd: addHeight, + localRemove: 0, + remoteRemove: nextHeight, + }, + updateType: Settle, + }, + } + + for _, test := range tests { + test := test + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + heights := test.startHeights + update := &PaymentDescriptor{ + Amount: updateAmount, + addCommitHeightLocal: heights.localAdd, + addCommitHeightRemote: heights.remoteAdd, + removeCommitHeightLocal: heights.localRemove, + removeCommitHeightRemote: heights.remoteRemove, + EntryType: test.updateType, + } + + var ( + // Start both parties off with an initial + // balance. Copy by value here so that we do + // not mutate the startBalance constant. + ourBalance, theirBalance = startBalance, + startBalance + ) + + // Choose the processing function we need based on the + // update type. Process remove is used for settles, + // fails and malformed htlcs. + process := processRemoveEntry + if test.updateType == Add { + process = processAddEntry + } + + process( + update, &ourBalance, &theirBalance, nextHeight, + test.remoteChain, test.isIncoming, + test.mutateState, + ) + + // Check that balances were updated as expected. + if ourBalance != test.ourExpectedBalance { + t.Fatalf("expected our balance: %v, got: %v", + test.ourExpectedBalance, ourBalance) + } + + if theirBalance != test.theirExpectedBalance { + t.Fatalf("expected their balance: %v, got: %v", + test.theirExpectedBalance, theirBalance) + } + + // Check that heights on the update are as expected. + checkHeights(t, update, test.expectedHeights) + }) + } +}