diff --git a/contractcourt/commit_sweep_resolver_test.go b/contractcourt/commit_sweep_resolver_test.go index bbcc1e1b8..e864fb608 100644 --- a/contractcourt/commit_sweep_resolver_test.go +++ b/contractcourt/commit_sweep_resolver_test.go @@ -1,6 +1,7 @@ package contractcourt import ( + "fmt" "testing" "time" @@ -127,9 +128,15 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) ( s.sweptInputs <- input + // TODO(yy): use `mock.Mock` to avoid the conversion. + fee, ok := params.Fee.(sweep.FeeEstimateInfo) + if !ok { + return nil, fmt.Errorf("unexpected fee type: %T", params.Fee) + } + // Update the deadlines used if it's set. - if params.Fee.ConfTarget != 0 { - s.deadlines = append(s.deadlines, int(params.Fee.ConfTarget)) + if fee.ConfTarget != 0 { + s.deadlines = append(s.deadlines, int(fee.ConfTarget)) } result := make(chan sweep.Result, 1) diff --git a/lnrpc/walletrpc/walletkit_server.go b/lnrpc/walletrpc/walletkit_server.go index b6ad08cde..20ee40a9e 100644 --- a/lnrpc/walletrpc/walletkit_server.go +++ b/lnrpc/walletrpc/walletkit_server.go @@ -884,7 +884,13 @@ func (w *WalletKit) PendingSweeps(ctx context.Context, broadcastAttempts := uint32(pendingInput.BroadcastAttempts) nextBroadcastHeight := uint32(pendingInput.NextBroadcastHeight) - requestedFee := pendingInput.Params.Fee + feePref := pendingInput.Params.Fee + requestedFee, ok := feePref.(sweep.FeeEstimateInfo) + if !ok { + return nil, fmt.Errorf("unknown fee preference type: "+ + "%v", feePref) + } + requestedFeeRate := uint64(requestedFee.FeeRate.FeePerVByte()) rpcPendingSweeps = append(rpcPendingSweeps, &PendingSweep{ diff --git a/sweep/mocks.go b/sweep/mocks.go new file mode 100644 index 000000000..516e35837 --- /dev/null +++ b/sweep/mocks.go @@ -0,0 +1,29 @@ +package sweep + +import ( + "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/stretchr/testify/mock" +) + +type MockFeePreference struct { + mock.Mock +} + +// Compile-time constraint to ensure MockFeePreference implements FeePreference. +var _ FeePreference = (*MockFeePreference)(nil) + +func (m *MockFeePreference) String() string { + return "mock fee preference" +} + +func (m *MockFeePreference) Estimate(estimator chainfee.Estimator, + maxFeeRate chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) { + + args := m.Called(estimator, maxFeeRate) + + if args.Get(0) == nil { + return 0, args.Error(1) + } + + return args.Get(0).(chainfee.SatPerKWeight), args.Error(1) +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 97028facc..186b2684c 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -68,7 +68,7 @@ type Params struct { // Fee is the fee preference of the client who requested the input to be // swept. If a confirmation target is specified, then we'll map it into // a fee rate whenever we attempt to cluster inputs for a sweep. - Fee FeeEstimateInfo + Fee FeePreference // Force indicates whether the input should be swept regardless of // whether it is economical to do so. @@ -84,7 +84,7 @@ type ParamsUpdate struct { // Fee is the fee preference of the client who requested the input to be // swept. If a confirmation target is specified, then we'll map it into // a fee rate whenever we attempt to cluster inputs for a sweep. - Fee FeeEstimateInfo + Fee FeePreference // Force indicates whether the input should be swept regardless of // whether it is economical to do so. diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 854d430fd..e14e776a9 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,6 +1,7 @@ package sweep import ( + "errors" "os" "reflect" "runtime/pprof" @@ -335,7 +336,9 @@ func TestSuccess(t *testing.T) { ctx := createSweeperTestContext(t) // Sweeping an input without a fee preference should result in an error. - _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{}) + _, err := ctx.sweeper.SweepInput(spendableInputs[0], Params{ + Fee: &FeeEstimateInfo{}, + }) if err != ErrNoFeePreference { t.Fatalf("expected ErrNoFeePreference, got %v", err) } @@ -1100,7 +1103,9 @@ func TestBumpFeeRBF(t *testing.T) { ctx.estimator.blocksToFee[highFeePref.ConfTarget] = highFeeRate // We should expect to see an error if a fee preference isn't provided. - _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{}) + _, err = ctx.sweeper.UpdateParams(*input.OutPoint(), ParamsUpdate{ + Fee: &FeeEstimateInfo{}, + }) if err != ErrNoFeePreference { t.Fatalf("expected ErrNoFeePreference, got %v", err) } @@ -2141,9 +2146,12 @@ func TestSweeperShutdownHandling(t *testing.T) { func TestClusterByLockTime(t *testing.T) { t.Parallel() + // Create a mock FeePreference. + mockFeePref := &MockFeePreference{} + // Create a test param with a dummy fee preference. This is needed so // `feeRateForPreference` won't throw an error. - param := Params{Fee: FeeEstimateInfo{ConfTarget: 1}} + param := Params{Fee: mockFeePref} // We begin the test by creating three clusters of inputs, the first // cluster has a locktime of 1, the second has a locktime of 2, and the @@ -2222,15 +2230,11 @@ func TestClusterByLockTime(t *testing.T) { // minFeeRate will cause an error to be returned. s.relayFeeRate = minFeeRate - // applyFeeRate takes a testing fee rate and makes a mocker over - // DetermineFeePerKw that always return the testing fee rate. This - // mocked method is then attached to the sweeper. - applyFeeRate := func(feeRate chainfee.SatPerKWeight) { - // TODO(yy): fix the test here. - } - testCases := []struct { - name string + name string + // setupMocker takes a testing fee rate and makes a mocker over + // `Estimate` that always return the testing fee rate. + setupMocker func() testFeeRate chainfee.SatPerKWeight expectedClusters []inputCluster expectedRemainingInputs pendingInputs @@ -2240,6 +2244,14 @@ func TestClusterByLockTime(t *testing.T) { // are created and the no-locktime cluster is returned // as the remaining inputs. name: "successfully create clusters", + setupMocker: func() { + mockFeePref.On("Estimate", + s.cfg.FeeEstimator, + s.cfg.MaxFeeRate.FeePerKWeight(), + // Expect the four inputs with locktime to call + // this method. + ).Return(minFeeRate+1, nil).Times(4) + }, // Use a fee rate above the min value so we don't hit // an error when performing fee estimation. // @@ -2267,6 +2279,14 @@ func TestClusterByLockTime(t *testing.T) { // Test that when the input is skipped when the fee // estimation returns an error. name: "error from fee estimation", + setupMocker: func() { + mockFeePref.On("Estimate", + s.cfg.FeeEstimator, + s.cfg.MaxFeeRate.FeePerKWeight(), + ).Return(chainfee.SatPerKWeight(0), + errors.New("dummy")).Times(4) + }, + // Use a fee rate below the min value so we hit an // error when performing fee estimation. testFeeRate: minFeeRate - 1, @@ -2283,7 +2303,10 @@ func TestClusterByLockTime(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Apply the test fee rate so `feeRateForPreference` is // mocked to return the specified value. - applyFeeRate(tc.testFeeRate) + tc.setupMocker() + + // Assert the mocked methods are called as expeceted. + defer mockFeePref.AssertExpectations(t) // Call the method under test. clusters, remainingInputs := s.clusterByLockTime(inputs) diff --git a/sweep/walletsweep.go b/sweep/walletsweep.go index be3b69399..5328ae508 100644 --- a/sweep/walletsweep.go +++ b/sweep/walletsweep.go @@ -33,6 +33,24 @@ var ( ErrFeePreferenceConflict = errors.New("fee preference conflict") ) +// FeePreference defines an interface that allows the caller to specify how the +// fee rate should be handled. Depending on the implementation, the fee rate +// can either be specified directly, or via a conf target which relies on the +// chain backend(`bitcoind`) to give a fee estimation, or a customized fee +// function which handles fee calculation based on the specified +// urgency(deadline). +type FeePreference interface { + // String returns a human-readable string of the fee preference. + String() string + + // Estimate takes a fee estimator and a max allowed fee rate and + // returns a fee rate for the given fee preference. It ensures that the + // fee rate respects the bounds of the relay fee and the specified max + // fee rates. + Estimate(chainfee.Estimator, + chainfee.SatPerKWeight) (chainfee.SatPerKWeight, error) +} + // FeeEstimateInfo allows callers to express their time value for inclusion of // a transaction into a block via either a confirmation target, or a fee rate. type FeeEstimateInfo struct { @@ -45,12 +63,16 @@ type FeeEstimateInfo struct { FeeRate chainfee.SatPerKWeight } +// Compile-time constraint to ensure FeeEstimateInfo implements FeePreference. +var _ FeePreference = (*FeeEstimateInfo)(nil) + // String returns a human-readable string of the fee preference. -func (p FeeEstimateInfo) String() string { - if p.ConfTarget != 0 { - return fmt.Sprintf("%v blocks", p.ConfTarget) +func (f FeeEstimateInfo) String() string { + if f.ConfTarget != 0 { + return fmt.Sprintf("%v blocks", f.ConfTarget) } - return p.FeeRate.String() + + return f.FeeRate.String() } // Estimate returns a fee rate for the given fee preference. It ensures that