diff --git a/sweep/aggregator.go b/sweep/aggregator.go index 7676a52ee..fdc4a27f1 100644 --- a/sweep/aggregator.go +++ b/sweep/aggregator.go @@ -551,9 +551,14 @@ func (b *BudgetAggregator) ClusterInputs(inputs InputsMap, // Sort the inputs by their economical value. sortedInputs := b.sortInputs(cluster) + // Split on locktimes if they are different. + splitClusters := splitOnLocktime(sortedInputs) + // Create input sets from the cluster. - sets := b.createInputSets(sortedInputs, height) - inputSets = append(inputSets, sets...) + for _, cluster := range splitClusters { + sets := b.createInputSets(cluster, height) + inputSets = append(inputSets, sets...) + } } // Create input sets from the exclusive inputs. @@ -742,6 +747,62 @@ func (b *BudgetAggregator) sortInputs(inputs []SweeperInput) []SweeperInput { return sortedInputs } +// splitOnLocktime splits the list of inputs based on their locktime. +// +// TODO(yy): this is a temporary hack as the blocks are not synced among the +// contractcourt and the sweeper. +func splitOnLocktime(inputs []SweeperInput) map[uint32][]SweeperInput { + result := make(map[uint32][]SweeperInput) + noLocktimeInputs := make([]SweeperInput, 0, len(inputs)) + + // mergeLocktime is the locktime that we use to merge all the + // nolocktime inputs into. + var mergeLocktime uint32 + + // Iterate all inputs and split them based on their locktimes. + for _, inp := range inputs { + locktime, required := inp.RequiredLockTime() + if !required { + log.Tracef("No locktime required for input=%v", + inp.OutPoint()) + + noLocktimeInputs = append(noLocktimeInputs, inp) + + continue + } + + log.Tracef("Split input=%v on locktime=%v", inp.OutPoint(), + locktime) + + // Get the slice - the slice will be initialized if not found. + inputList := result[locktime] + + // Add the input to the list. + inputList = append(inputList, inp) + + // Update the map. + result[locktime] = inputList + + // Update the merge locktime. + mergeLocktime = locktime + } + + // If there are locktime inputs, we will merge the no locktime inputs + // to the last locktime group found. + if len(result) > 0 { + log.Tracef("No locktime inputs has been merged to locktime=%v", + mergeLocktime) + result[mergeLocktime] = append( + result[mergeLocktime], noLocktimeInputs..., + ) + } else { + // Otherwise just return the no locktime inputs. + result[mergeLocktime] = noLocktimeInputs + } + + return result +} + // isDustOutput checks if the given output is considered as dust. func isDustOutput(output *wire.TxOut) bool { // Fetch the dust limit for this output. diff --git a/sweep/aggregator_test.go b/sweep/aggregator_test.go index e9b832938..e002a14b6 100644 --- a/sweep/aggregator_test.go +++ b/sweep/aggregator_test.go @@ -839,7 +839,7 @@ func TestBudgetInputSetClusterInputs(t *testing.T) { // 3. when assigning the input to the exclusiveInputs. // 4. when iterating the exclusiveInputs. opExclusive := wire.OutPoint{Hash: chainhash.Hash{1, 2, 3, 4, 5}} - inpExclusive.On("OutPoint").Return(opExclusive).Times(4) + inpExclusive.On("OutPoint").Return(opExclusive).Maybe() // Mock the `WitnessType` method to return the witness type. inpExclusive.On("WitnessType").Return(wt) @@ -895,11 +895,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) { // `filterInputs`. inpLow.On("OutPoint").Return(opLow).Once() - // We expect the high budget input to call this method three - // times, one in `filterInputs` and one in `createInputSet`, - // and one in `NewBudgetInputSet`. - inpHigh1.On("OutPoint").Return(opHigh1).Times(3) - inpHigh2.On("OutPoint").Return(opHigh2).Times(3) + // The number of times this method is called is dependent on + // the log level. + inpHigh1.On("OutPoint").Return(opHigh1).Maybe() + inpHigh2.On("OutPoint").Return(opHigh2).Maybe() // Mock the `WitnessType` method to return the witness type. inpLow.On("WitnessType").Return(wt) @@ -910,6 +909,10 @@ func TestBudgetInputSetClusterInputs(t *testing.T) { inpHigh1.On("RequiredTxOut").Return(nil) inpHigh2.On("RequiredTxOut").Return(nil) + // Mock the `RequiredLockTime` to return 0. + inpHigh1.On("RequiredLockTime").Return(uint32(0), false) + inpHigh2.On("RequiredLockTime").Return(uint32(0), false) + // Add the low input, which should be filtered out. inputs[opLow] = &SweeperInput{ Input: inpLow, @@ -969,3 +972,72 @@ func TestBudgetInputSetClusterInputs(t *testing.T) { require.Contains(t, deadlines, deadline1.UnwrapOrFail(t)) require.Contains(t, deadlines, deadline2.UnwrapOrFail(t)) } + +// TestSplitOnLocktime asserts `splitOnLocktime` works as expected. +func TestSplitOnLocktime(t *testing.T) { + t.Parallel() + + // Create two locktimes. + lockTime1 := uint32(1) + lockTime2 := uint32(2) + + // Create cluster one, which has a locktime of 1. + input1LockTime1 := &input.MockInput{} + input2LockTime1 := &input.MockInput{} + input1LockTime1.On("RequiredLockTime").Return(lockTime1, true) + input2LockTime1.On("RequiredLockTime").Return(lockTime1, true) + + // Create cluster two, which has a locktime of 2. + input3LockTime2 := &input.MockInput{} + input4LockTime2 := &input.MockInput{} + input3LockTime2.On("RequiredLockTime").Return(lockTime2, true) + input4LockTime2.On("RequiredLockTime").Return(lockTime2, true) + + // Create cluster three, which has no locktime. + // Create cluster three, which has no locktime. + input5NoLockTime := &input.MockInput{} + input6NoLockTime := &input.MockInput{} + input5NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + input6NoLockTime.On("RequiredLockTime").Return(uint32(0), false) + + // Mock `OutPoint` - it may or may not be called due to log settings. + input1LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 1}).Maybe() + input2LockTime1.On("OutPoint").Return(wire.OutPoint{Index: 2}).Maybe() + input3LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 3}).Maybe() + input4LockTime2.On("OutPoint").Return(wire.OutPoint{Index: 4}).Maybe() + input5NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 5}).Maybe() + input6NoLockTime.On("OutPoint").Return(wire.OutPoint{Index: 6}).Maybe() + + // With the inner Input being mocked, we can now create the pending + // inputs. + input1 := SweeperInput{Input: input1LockTime1} + input2 := SweeperInput{Input: input2LockTime1} + input3 := SweeperInput{Input: input3LockTime2} + input4 := SweeperInput{Input: input4LockTime2} + input5 := SweeperInput{Input: input5NoLockTime} + input6 := SweeperInput{Input: input6NoLockTime} + + // Call the method under test. + inputs := []SweeperInput{input1, input2, input3, input4, input5, input6} + result := splitOnLocktime(inputs) + + // We expect the no locktime inputs to be grouped with locktime2. + expectedResult := map[uint32][]SweeperInput{ + lockTime1: {input1, input2}, + lockTime2: {input3, input4, input5, input6}, + } + require.Len(t, result[lockTime1], 2) + require.Len(t, result[lockTime2], 4) + require.Equal(t, expectedResult, result) + + // Test the case where there are no locktime inputs. + inputs = []SweeperInput{input5, input6} + result = splitOnLocktime(inputs) + + // We expect the no locktime inputs to be returned as is. + expectedResult = map[uint32][]SweeperInput{ + uint32(0): {input5, input6}, + } + require.Len(t, result[uint32(0)], 2) + require.Equal(t, expectedResult, result) +}