sweep+contractcourt: track best height in UtxoSweeper

Thus we can use shorter method signatures. In doing so we also remove an
old TODO in one use case of `CreateSweepTx`.
This commit is contained in:
yyforyongyu 2023-10-24 07:14:55 +08:00
parent ca0813b1bf
commit 84a6fdcda3
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
5 changed files with 39 additions and 49 deletions

View File

@ -140,8 +140,8 @@ func (s *mockSweeper) SweepInput(input input.Input, params sweep.Params) (
return result, nil return result, nil
} }
func (s *mockSweeper) CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, func (s *mockSweeper) CreateSweepTx(inputs []input.Input,
currentBlockHeight uint32) (*wire.MsgTx, error) { feePref sweep.FeePreference) (*wire.MsgTx, error) {
// We will wait for the test to supply the sweep tx to return. // We will wait for the test to supply the sweep tx to return.
sweepTx := <-s.createSweepTxChan sweepTx := <-s.createSweepTxChan

View File

@ -432,17 +432,13 @@ func (h *htlcSuccessResolver) resolveRemoteCommitOutput() (
// transaction, that we'll use to move these coins back into // transaction, that we'll use to move these coins back into
// the backing wallet. // the backing wallet.
// //
// TODO: Set tx lock time to current block height instead of
// zero. Will be taken care of once sweeper implementation is
// complete.
//
// TODO: Use time-based sweeper and result chan. // TODO: Use time-based sweeper and result chan.
var err error var err error
h.sweepTx, err = h.Sweeper.CreateSweepTx( h.sweepTx, err = h.Sweeper.CreateSweepTx(
[]input.Input{inp}, []input.Input{inp},
sweep.FeePreference{ sweep.FeePreference{
ConfTarget: sweepConfTarget, ConfTarget: sweepConfTarget,
}, 0, },
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -53,8 +53,8 @@ type UtxoSweeper interface {
// CreateSweepTx accepts a list of inputs and signs and generates a txn // CreateSweepTx accepts a list of inputs and signs and generates a txn
// that spends from them. This method also makes an accurate fee // that spends from them. This method also makes an accurate fee
// estimate before generating the required witnesses. // estimate before generating the required witnesses.
CreateSweepTx(inputs []input.Input, feePref sweep.FeePreference, CreateSweepTx(inputs []input.Input,
currentBlockHeight uint32) (*wire.MsgTx, error) feePref sweep.FeePreference) (*wire.MsgTx, error)
// RelayFeePerKW returns the minimum fee rate required for transactions // RelayFeePerKW returns the minimum fee rate required for transactions
// to be relayed. // to be relayed.

View File

@ -236,6 +236,10 @@ type UtxoSweeper struct {
quit chan struct{} quit chan struct{}
wg sync.WaitGroup wg sync.WaitGroup
// currentHeight is the best known height of the main chain. This is
// updated whenever a new block epoch is received.
currentHeight int32
} }
// feeDeterminer defines an alias to the function signature of // feeDeterminer defines an alias to the function signature of
@ -596,10 +600,9 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// We registered for the block epochs with a nil request. The notifier // We registered for the block epochs with a nil request. The notifier
// should send us the current best block immediately. So we need to wait // should send us the current best block immediately. So we need to wait
// for it here because we need to know the current best height. // for it here because we need to know the current best height.
var bestHeight int32
select { select {
case bestBlock := <-blockEpochs: case bestBlock := <-blockEpochs:
bestHeight = bestBlock.Height s.currentHeight = bestBlock.Height
case <-s.quit: case <-s.quit:
return return
@ -617,7 +620,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// we are already trying to sweep this input and if not, set up // we are already trying to sweep this input and if not, set up
// a listener to spend and schedule a sweep. // a listener to spend and schedule a sweep.
case input := <-s.newInputs: case input := <-s.newInputs:
s.handleNewInput(input, bestHeight) s.handleNewInput(input)
// A spend of one of our inputs is detected. Signal sweep // A spend of one of our inputs is detected. Signal sweep
// results to the caller(s). // results to the caller(s).
@ -632,7 +635,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// A new external request has been received to bump the fee rate // A new external request has been received to bump the fee rate
// of a given input. // of a given input.
case req := <-s.updateReqs: case req := <-s.updateReqs:
resultChan, err := s.handleUpdateReq(req, bestHeight) resultChan, err := s.handleUpdateReq(req)
req.responseChan <- &updateResp{ req.responseChan <- &updateResp{
resultChan: resultChan, resultChan: resultChan,
err: err, err: err,
@ -641,7 +644,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
// The timer expires and we are going to (re)sweep. // The timer expires and we are going to (re)sweep.
case <-ticker.C: case <-ticker.C:
log.Debugf("Sweep ticker ticks, attempt sweeping...") log.Debugf("Sweep ticker ticks, attempt sweeping...")
s.handleSweep(bestHeight) s.handleSweep()
// A new block comes in, update the bestHeight. // A new block comes in, update the bestHeight.
case epoch, ok := <-blockEpochs: case epoch, ok := <-blockEpochs:
@ -649,7 +652,7 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) {
return return
} }
bestHeight = epoch.Height s.currentHeight = epoch.Height
log.Debugf("New block: height=%v, sha=%v", log.Debugf("New block: height=%v, sha=%v",
epoch.Height, epoch.Hash) epoch.Height, epoch.Hash)
@ -698,15 +701,13 @@ func (s *UtxoSweeper) removeExclusiveGroup(group uint64) {
} }
// sweepCluster tries to sweep the given input cluster. // sweepCluster tries to sweep the given input cluster.
func (s *UtxoSweeper) sweepCluster(cluster inputCluster, func (s *UtxoSweeper) sweepCluster(cluster inputCluster) error {
currentHeight int32) error {
// Execute the sweep within a coin select lock. Otherwise the coins // Execute the sweep within a coin select lock. Otherwise the coins
// that we are going to spend may be selected for other transactions // that we are going to spend may be selected for other transactions
// like funding of a channel. // like funding of a channel.
return s.cfg.Wallet.WithCoinSelectLock(func() error { return s.cfg.Wallet.WithCoinSelectLock(func() error {
// Examine pending inputs and try to construct lists of inputs. // Examine pending inputs and try to construct lists of inputs.
allSets, newSets, err := s.getInputLists(cluster, currentHeight) allSets, newSets, err := s.getInputLists(cluster)
if err != nil { if err != nil {
return fmt.Errorf("examine pending inputs: %w", err) return fmt.Errorf("examine pending inputs: %w", err)
} }
@ -719,9 +720,7 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster,
// creating an RBF for the new inputs, we'd sweep this set // creating an RBF for the new inputs, we'd sweep this set
// first. // first.
for _, inputs := range allSets { for _, inputs := range allSets {
errAllSets = s.sweep( errAllSets = s.sweep(inputs, cluster.sweepFeeRate)
inputs, cluster.sweepFeeRate, currentHeight,
)
// TODO(yy): we should also find out which set created // TODO(yy): we should also find out which set created
// this error. If there are new inputs in this set, we // this error. If there are new inputs in this set, we
// should give it a second chance by sweeping them // should give it a second chance by sweeping them
@ -754,9 +753,7 @@ func (s *UtxoSweeper) sweepCluster(cluster inputCluster,
// when sweeping a given set, we'd log the error and sweep the // when sweeping a given set, we'd log the error and sweep the
// next set. // next set.
for _, inputs := range newSets { for _, inputs := range newSets {
err := s.sweep( err := s.sweep(inputs, cluster.sweepFeeRate)
inputs, cluster.sweepFeeRate, currentHeight,
)
if err != nil { if err != nil {
log.Errorf("sweep new inputs: %w", err) log.Errorf("sweep new inputs: %w", err)
} }
@ -1079,8 +1076,8 @@ func (s *UtxoSweeper) signalAndRemove(outpoint *wire.OutPoint, result Result) {
// and will be bundled with future inputs if possible. It returns two list - // and will be bundled with future inputs if possible. It returns two list -
// one containing all inputs and the other containing only the new inputs. If // one containing all inputs and the other containing only the new inputs. If
// there's no retried inputs, the first set returned will be empty. // there's no retried inputs, the first set returned will be empty.
func (s *UtxoSweeper) getInputLists(cluster inputCluster, func (s *UtxoSweeper) getInputLists(
currentHeight int32) ([]inputSet, []inputSet, error) { cluster inputCluster) ([]inputSet, []inputSet, error) {
// Filter for inputs that need to be swept. Create two lists: all // Filter for inputs that need to be swept. Create two lists: all
// sweepable inputs and a list containing only the new, never tried // sweepable inputs and a list containing only the new, never tried
@ -1102,7 +1099,7 @@ func (s *UtxoSweeper) getInputLists(cluster inputCluster,
for _, input := range cluster.inputs { for _, input := range cluster.inputs {
// Skip inputs that have a minimum publish height that is not // Skip inputs that have a minimum publish height that is not
// yet reached. // yet reached.
if input.minPublishHeight > currentHeight { if input.minPublishHeight > s.currentHeight {
continue continue
} }
@ -1143,15 +1140,15 @@ func (s *UtxoSweeper) getInputLists(cluster inputCluster,
} }
log.Debugf("Sweep candidates at height=%v: total_num_pending=%v, "+ log.Debugf("Sweep candidates at height=%v: total_num_pending=%v, "+
"total_num_new=%v", currentHeight, len(allSets), len(newSets)) "total_num_new=%v", s.currentHeight, len(allSets), len(newSets))
return allSets, newSets, nil return allSets, newSets, nil
} }
// sweep takes a set of preselected inputs, creates a sweep tx and publishes the // sweep takes a set of preselected inputs, creates a sweep tx and publishes the
// tx. The output address is only marked as used if the publish succeeds. // tx. The output address is only marked as used if the publish succeeds.
func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight, func (s *UtxoSweeper) sweep(inputs inputSet,
currentHeight int32) error { feeRate chainfee.SatPerKWeight) error {
// Generate an output script if there isn't an unused script available. // Generate an output script if there isn't an unused script available.
if s.currentOutputScript == nil { if s.currentOutputScript == nil {
@ -1164,7 +1161,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight,
// Create sweep tx. // Create sweep tx.
tx, fee, err := createSweepTx( tx, fee, err := createSweepTx(
inputs, nil, s.currentOutputScript, uint32(currentHeight), inputs, nil, s.currentOutputScript, uint32(s.currentHeight),
feeRate, s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, feeRate, s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer,
) )
if err != nil { if err != nil {
@ -1190,10 +1187,10 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight,
// Reschedule the inputs that we just tried to sweep. This is done in // Reschedule the inputs that we just tried to sweep. This is done in
// case the following publish fails, we'd like to update the inputs' // case the following publish fails, we'd like to update the inputs'
// publish attempts and rescue them in the next sweep. // publish attempts and rescue them in the next sweep.
s.rescheduleInputs(tx.TxIn, currentHeight) s.rescheduleInputs(tx.TxIn)
log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v",
tx.TxHash(), len(tx.TxIn), currentHeight) tx.TxHash(), len(tx.TxIn), s.currentHeight)
// Publish the sweeping tx with customized label. // Publish the sweeping tx with customized label.
err = s.cfg.Wallet.PublishTransaction( err = s.cfg.Wallet.PublishTransaction(
@ -1225,9 +1222,7 @@ func (s *UtxoSweeper) sweep(inputs inputSet, feeRate chainfee.SatPerKWeight,
// increments the `publishAttempts` and calculates the next broadcast height // increments the `publishAttempts` and calculates the next broadcast height
// for each input. When the publishAttempts exceeds MaxSweepAttemps(10), this // for each input. When the publishAttempts exceeds MaxSweepAttemps(10), this
// input will be removed. // input will be removed.
func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn, func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) {
currentHeight int32) {
// Reschedule sweep. // Reschedule sweep.
for _, input := range inputs { for _, input := range inputs {
pi, ok := s.pendingInputs[input.PreviousOutPoint] pi, ok := s.pendingInputs[input.PreviousOutPoint]
@ -1251,7 +1246,7 @@ func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn,
pi.publishAttempts, pi.publishAttempts,
) )
pi.minPublishHeight = currentHeight + nextAttemptDelta pi.minPublishHeight = s.currentHeight + nextAttemptDelta
log.Debugf("Rescheduling input %v after %v attempts at "+ log.Debugf("Rescheduling input %v after %v attempts at "+
"height %v (delta %v)", input.PreviousOutPoint, "height %v (delta %v)", input.PreviousOutPoint,
@ -1412,7 +1407,7 @@ func (s *UtxoSweeper) UpdateParams(input wire.OutPoint,
// - Ensure we don't combine this input with any other unconfirmed inputs that // - Ensure we don't combine this input with any other unconfirmed inputs that
// did not exist in the original sweep transaction, resulting in an invalid // did not exist in the original sweep transaction, resulting in an invalid
// replacement transaction. // replacement transaction.
func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) ( func (s *UtxoSweeper) handleUpdateReq(req *updateReq) (
chan Result, error) { chan Result, error) {
// If the UtxoSweeper is already trying to sweep this input, then we can // If the UtxoSweeper is already trying to sweep this input, then we can
@ -1445,7 +1440,7 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) (
// NOTE: The UtxoSweeper is not yet offered time-locked inputs, so the // NOTE: The UtxoSweeper is not yet offered time-locked inputs, so the
// check for broadcast attempts is redundant at the moment. // check for broadcast attempts is redundant at the moment.
if pendingInput.publishAttempts > 0 { if pendingInput.publishAttempts > 0 {
pendingInput.minPublishHeight = bestHeight pendingInput.minPublishHeight = s.currentHeight
} }
resultChan := make(chan Result, 1) resultChan := make(chan Result, 1)
@ -1469,8 +1464,8 @@ func (s *UtxoSweeper) handleUpdateReq(req *updateReq, bestHeight int32) (
// - Make handling re-orgs easier. // - Make handling re-orgs easier.
// - Thwart future possible fee sniping attempts. // - Thwart future possible fee sniping attempts.
// - Make us blend in with the bitcoind wallet. // - Make us blend in with the bitcoind wallet.
func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference, func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input,
currentBlockHeight uint32) (*wire.MsgTx, error) { feePref FeePreference) (*wire.MsgTx, error) {
feePerKw, err := s.cfg.DetermineFeePerKw(s.cfg.FeeEstimator, feePref) feePerKw, err := s.cfg.DetermineFeePerKw(s.cfg.FeeEstimator, feePref)
if err != nil { if err != nil {
@ -1484,7 +1479,7 @@ func (s *UtxoSweeper) CreateSweepTx(inputs []input.Input, feePref FeePreference,
} }
tx, _, err := createSweepTx( tx, _, err := createSweepTx(
inputs, nil, pkScript, currentBlockHeight, feePerKw, inputs, nil, pkScript, uint32(s.currentHeight), feePerKw,
s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer, s.cfg.MaxFeeRate.FeePerKWeight(), s.cfg.Signer,
) )
@ -1506,8 +1501,7 @@ func (s *UtxoSweeper) ListSweeps() ([]chainhash.Hash, error) {
// handleNewInput processes a new input by registering spend notification and // handleNewInput processes a new input by registering spend notification and
// scheduling sweeping for it. // scheduling sweeping for it.
func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage, func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) {
bestHeight int32) {
outpoint := *input.input.OutPoint() outpoint := *input.input.OutPoint()
pendInput, pending := s.pendingInputs[outpoint] pendInput, pending := s.pendingInputs[outpoint]
@ -1525,7 +1519,7 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage,
pendInput = &pendingInput{ pendInput = &pendingInput{
listeners: []chan Result{input.resultChan}, listeners: []chan Result{input.resultChan},
Input: input.input, Input: input.input,
minPublishHeight: bestHeight, minPublishHeight: s.currentHeight,
params: input.params, params: input.params,
} }
s.pendingInputs[outpoint] = pendInput s.pendingInputs[outpoint] = pendInput
@ -1668,7 +1662,7 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) {
// handleSweep is called when the ticker fires. It will create clusters and // handleSweep is called when the ticker fires. It will create clusters and
// attempt to create and publish the sweeping transactions. // attempt to create and publish the sweeping transactions.
func (s *UtxoSweeper) handleSweep(bestHeight int32) { func (s *UtxoSweeper) handleSweep() {
// We'll attempt to cluster all of our inputs with similar fee rates. // We'll attempt to cluster all of our inputs with similar fee rates.
// Before attempting to sweep them, we'll sort them in descending fee // Before attempting to sweep them, we'll sort them in descending fee
// rate order. We do this to ensure any inputs which have had their fee // rate order. We do this to ensure any inputs which have had their fee
@ -1680,7 +1674,7 @@ func (s *UtxoSweeper) handleSweep(bestHeight int32) {
}) })
for _, cluster := range inputClusters { for _, cluster := range inputClusters {
err := s.sweepCluster(cluster, bestHeight) err := s.sweepCluster(cluster)
if err != nil { if err != nil {
log.Errorf("input cluster sweep: %v", err) log.Errorf("input cluster sweep: %v", err)
} }

View File

@ -2530,7 +2530,7 @@ func TestGetInputLists(t *testing.T) {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
allSets, newSets, err := s.getInputLists(tc.cluster, 0) allSets, newSets, err := s.getInputLists(tc.cluster)
require.NoError(t, err) require.NoError(t, err)
if tc.expectNilNewSet { if tc.expectNilNewSet {