diff --git a/sweep/sweeper.go b/sweep/sweeper.go index c3ce504ec..549095868 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -631,12 +631,6 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { return } - // Create a ticker based on the config duration. - ticker := time.NewTicker(s.cfg.TickerDuration) - defer ticker.Stop() - - log.Debugf("Sweep ticker started") - for { // Clean inputs, which will remove inputs that are swept, // failed, or excluded from the sweeper and return inputs that @@ -651,6 +645,13 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { case input := <-s.newInputs: s.handleNewInput(input) + // If this input is forced, we perform an sweep + // immediately. + if input.params.Force { + inputs = s.updateSweeperInputs() + s.sweepPendingInputs(inputs) + } + // A spend of one of our inputs is detected. Signal sweep // results to the caller(s). case spend := <-s.spendChan: @@ -670,14 +671,6 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { err: err, } - // The timer expires and we are going to (re)sweep. - case <-ticker.C: - log.Debugf("Sweep ticker ticks, attempt sweeping %d "+ - "inputs", len(inputs)) - - // Sweep the remaining pending inputs. - s.sweepPendingInputs(inputs) - // A new block comes in, update the bestHeight. // // TODO(yy): this is where we check our published transactions @@ -685,13 +678,22 @@ func (s *UtxoSweeper) collector(blockEpochs <-chan *chainntnfs.BlockEpoch) { // bumper to get an updated fee rate. case epoch, ok := <-blockEpochs: if !ok { + // We should stop the sweeper before stopping + // the chain service. Otherwise it indicates an + // error. + log.Error("Block epoch channel closed") + return } + // Update the sweeper to the best height. s.currentHeight = epoch.Height - log.Debugf("New block: height=%v, sha=%v", - epoch.Height, epoch.Hash) + log.Debugf("Received new block: height=%v, attempt "+ + "sweeping %d inputs", epoch.Height, len(inputs)) + + // Attempt to sweep any pending inputs. + s.sweepPendingInputs(inputs) case <-s.quit: return diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index 0168d9f08..0a914c4c6 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -43,7 +43,8 @@ type sweeperTestContext struct { backend *mockBackend store SweeperStore - publishChan chan wire.MsgTx + publishChan chan wire.MsgTx + currentHeight int32 } var ( @@ -125,12 +126,13 @@ func createSweeperTestContext(t *testing.T) *sweeperTestContext { ) ctx := &sweeperTestContext{ - notifier: notifier, - publishChan: backend.publishChan, - t: t, - estimator: estimator, - backend: backend, - store: store, + notifier: notifier, + publishChan: backend.publishChan, + t: t, + estimator: estimator, + backend: backend, + store: store, + currentHeight: mockChainHeight, } ctx.sweeper = New(&UtxoSweeperConfig{ @@ -214,6 +216,11 @@ func (ctx *sweeperTestContext) assertNoTx() { func (ctx *sweeperTestContext) receiveTx() wire.MsgTx { ctx.t.Helper() + + // Every time we want to receive a tx, we send a new block epoch to the + // sweeper to trigger a sweeping action. + ctx.notifier.NotifyEpochNonBlocking(ctx.currentHeight + 1) + var tx wire.MsgTx select { case tx = <-ctx.publishChan: @@ -1775,6 +1782,10 @@ func TestRequiredTxOuts(t *testing.T) { inputs[*op] = inp } + // Send a new block epoch to trigger the sweeper to + // sweep the inputs. + ctx.notifier.NotifyEpoch(ctx.sweeper.currentHeight + 1) + // Check the sweeps transactions, ensuring all inputs // are there, and all the locktimes are satisfied. var sweeps []*wire.MsgTx diff --git a/sweep/test_utils.go b/sweep/test_utils.go index e36b56a6b..bd4b91bee 100644 --- a/sweep/test_utils.go +++ b/sweep/test_utils.go @@ -40,6 +40,27 @@ func NewMockNotifier(t *testing.T) *MockNotifier { } } +// NotifyEpochNonBlocking simulates a new epoch arriving without blocking when +// the epochChan is not read. +func (m *MockNotifier) NotifyEpochNonBlocking(height int32) { + m.t.Helper() + + for epochChan, chanHeight := range m.epochChan { + // Only send notifications if the height is greater than the + // height the caller passed into the register call. + if chanHeight >= height { + continue + } + + log.Debugf("Notifying height %v to listener", height) + + select { + case epochChan <- &chainntnfs.BlockEpoch{Height: height}: + default: + } + } +} + // NotifyEpoch simulates a new epoch arriving. func (m *MockNotifier) NotifyEpoch(height int32) { m.t.Helper()