diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 69f3403e2..16439b1bc 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -88,12 +88,68 @@ func (p Params) String() string { p.Fee, p.Force) } +// SweepState represents the current state of a pending input. +// +//nolint:revive +type SweepState uint8 + +const ( + // StateInit is the initial state of a pending input. This is set when + // a new sweeping request for a given input is made. + StateInit SweepState = iota + + // StatePendingPublish specifies an input's state where it's already + // been included in a sweeping tx but the tx is not published yet. + // Inputs in this state should not be used for grouping again. + StatePendingPublish + + // StatePublished is the state where the input's sweeping tx has + // successfully been published. Inputs in this state can only be + // updated via RBF. + StatePublished + + // StatePublishFailed is the state when an error is returned from + // publishing the sweeping tx. Inputs in this state can be re-grouped + // in to a new sweeping tx. + StatePublishFailed + + // StateSwept is the final state of a pending input. This is set when + // the input has been successfully swept. + StateSwept +) + +// String gives a human readable text for the sweep states. +func (s SweepState) String() string { + switch s { + case StateInit: + return "Init" + + case StatePendingPublish: + return "PendingPublish" + + case StatePublished: + return "Published" + + case StatePublishFailed: + return "PublishFailed" + + case StateSwept: + return "Swept" + + default: + return "Unknown" + } +} + // pendingInput is created when an input reaches the main loop for the first // time. It wraps the input and tracks all relevant state that is needed for // sweeping. type pendingInput struct { input.Input + // state tracks the current state of the input. + state SweepState + // listeners is a list of channels over which the final outcome of the // sweep needs to be broadcasted. listeners []chan Result @@ -403,6 +459,8 @@ func (s *UtxoSweeper) Stop() error { // NOTE: Extreme care needs to be taken that input isn't changed externally. // Because it is an interface and we don't know what is exactly behind it, we // cannot make a local copy in sweeper. +// +// TODO(yy): make sure the caller is using the Result chan. func (s *UtxoSweeper) SweepInput(input input.Input, params Params) (chan Result, error) { @@ -836,20 +894,13 @@ func (s *UtxoSweeper) sweep(inputs inputSet, Fee: uint64(fee), } - // Add tx before publication, so that we will always know that a spend - // by this tx is ours. Otherwise if the publish doesn't return, but did - // publish, we loose track of this tx. Even republication on startup - // doesn't prevent this, because that call returns a double spend error - // then and would also not add the hash to the store. - err = s.cfg.Store.StoreTx(tr) - if err != nil { - return fmt.Errorf("store tx: %w", err) - } - // Reschedule the inputs that we just tried to sweep. This is done in // case the following publish fails, we'd like to update the inputs' // publish attempts and rescue them in the next sweep. - s.rescheduleInputs(tx.TxIn) + err = s.markInputsPendingPublish(tr, tx) + if err != nil { + return err + } log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", tx.TxHash(), len(tx.TxIn), s.currentHeight) @@ -859,17 +910,16 @@ func (s *UtxoSweeper) sweep(inputs inputSet, tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { + // TODO(yy): find out which input is causing the failure. + s.markInputsPublishFailed(tx.TxIn) + return err } - // Mark this tx in db once successfully published. - // - // NOTE: this will behave as an overwrite, which is fine as the record - // is small. - tr.Published = true - err = s.cfg.Store.StoreTx(tr) + // Inputs have been successfully published so we update their states. + err = s.markInputsPublished(tr, tx.TxIn) if err != nil { - return fmt.Errorf("store tx: %w", err) + return err } // If there's no error, remove the output script. Otherwise keep it so @@ -880,13 +930,27 @@ func (s *UtxoSweeper) sweep(inputs inputSet, return nil } -// rescheduleInputs updates the pending inputs with the given tx inputs. It -// increments the `publishAttempts` and calculates the next broadcast height -// for each input. When the publishAttempts exceeds MaxSweepAttemps(10), this -// input will be removed. -func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) { +// markInputsPendingPublish saves the sweeping tx to db and updates the pending +// inputs with the given tx inputs. It increments the `publishAttempts` and +// calculates the next broadcast height for each input. When the +// publishAttempts exceeds MaxSweepAttemps(10), this input will be removed. +// +// TODO(yy): add unit test once done refactoring. +func (s *UtxoSweeper) markInputsPendingPublish(tr *TxRecord, + tx *wire.MsgTx) error { + + // Add tx to db before publication, so that we will always know that a + // spend by this tx is ours. Otherwise if the publish doesn't return, + // but did publish, we'd lose track of this tx. Even republication on + // startup doesn't prevent this, because that call returns a double + // spend error then and would also not add the hash to the store. + err := s.cfg.Store.StoreTx(tr) + if err != nil { + return fmt.Errorf("store tx: %w", err) + } + // Reschedule sweep. - for _, input := range inputs { + for _, input := range tx.TxIn { pi, ok := s.pendingInputs[input.PreviousOutPoint] if !ok { // It can be that the input has been removed because it @@ -897,6 +961,12 @@ func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) { continue } + // Update the input's state. + // + // TODO: also calculate the fees and fee rate of this tx to + // prepare possible RBF. + pi.state = StatePendingPublish + // Record another publish attempt. pi.publishAttempts++ @@ -927,6 +997,89 @@ func (s *UtxoSweeper) rescheduleInputs(inputs []*wire.TxIn) { }) } } + + return nil +} + +// markInputsPublished updates the sweeping tx in db and marks the list of +// inputs as published. +func (s *UtxoSweeper) markInputsPublished(tr *TxRecord, + inputs []*wire.TxIn) error { + + // Mark this tx in db once successfully published. + // + // NOTE: this will behave as an overwrite, which is fine as the record + // is small. + tr.Published = true + err := s.cfg.Store.StoreTx(tr) + if err != nil { + return fmt.Errorf("store tx: %w", err) + } + + // Reschedule sweep. + for _, input := range inputs { + pi, ok := s.pendingInputs[input.PreviousOutPoint] + if !ok { + // It can be that the input has been removed because it + // exceed the maximum number of attempts in a previous + // input set. It could also be that this input is an + // additional wallet input that was attached. In that + // case there also isn't a pending input to update. + log.Debugf("Skipped marking input as published: %v "+ + "not found in pending inputs", + input.PreviousOutPoint) + + continue + } + + // Valdiate that the input is in an expected state. + if pi.state != StatePendingPublish { + log.Errorf("Expect input %v to have %v, instead it "+ + "has %v", input.PreviousOutPoint, + StatePendingPublish, pi.state) + + continue + } + + // Update the input's state. + pi.state = StatePublished + } + + return nil +} + +// markInputsPublishFailed marks the list of inputs as failed to be published. +func (s *UtxoSweeper) markInputsPublishFailed(inputs []*wire.TxIn) { + // Reschedule sweep. + for _, input := range inputs { + pi, ok := s.pendingInputs[input.PreviousOutPoint] + if !ok { + // It can be that the input has been removed because it + // exceed the maximum number of attempts in a previous + // input set. It could also be that this input is an + // additional wallet input that was attached. In that + // case there also isn't a pending input to update. + log.Debugf("Skipped marking input as publish failed: "+ + "%v not found in pending inputs", + input.PreviousOutPoint) + + continue + } + + // Valdiate that the input is in an expected state. + if pi.state != StatePendingPublish { + log.Errorf("Expect input %v to have %v, instead it "+ + "has %v", input.PreviousOutPoint, + StatePendingPublish, pi.state) + + continue + } + + log.Warnf("Failed to publish input %v", input.PreviousOutPoint) + + // Update the input's state. + pi.state = StatePublishFailed + } } // monitorSpend registers a spend notification with the chain notifier. It @@ -956,8 +1109,8 @@ func (s *UtxoSweeper) monitorSpend(outpoint wire.OutPoint, return } - log.Debugf("Delivering spend ntfn for %v", - outpoint) + log.Debugf("Delivering spend ntfn for %v", outpoint) + select { case s.spendChan <- spend: log.Debugf("Delivered spend ntfn for %v", @@ -1183,6 +1336,7 @@ func (s *UtxoSweeper) handleNewInput(input *sweepInputMessage) { // the passed in result channel. If this input is offered for sweep // again, the result channel will be appended to this slice. pendInput = &pendingInput{ + state: StateInit, listeners: []chan Result{input.resultChan}, Input: input.input, minPublishHeight: s.currentHeight, @@ -1294,8 +1448,14 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) { ) } - // Signal sweep results for inputs in this confirmed tx. - for _, txIn := range spend.SpendingTx.TxIn { + // We now use the spending tx to update the state of the inputs. + s.markInputsSwept(spend.SpendingTx, isOurTx) +} + +// markInputsSwept marks all inputs swept by the spending transaction as swept. +// It will also notify all the subscribers of this input. +func (s *UtxoSweeper) markInputsSwept(tx *wire.MsgTx, isOurTx bool) error { + for _, txIn := range tx.TxIn { outpoint := txIn.PreviousOutPoint // Check if this input is known to us. It could probably be @@ -1307,6 +1467,16 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) { continue } + // This input may already been marked as swept by a previous + // spend notification, which is likely to happen as one sweep + // transaction usually sweeps multiple inputs. + if input.state == StateSwept { + log.Tracef("input %v already swept", outpoint) + continue + } + + input.state = StateSwept + // Return either a nil or a remote spend result. var err error if !isOurTx { @@ -1314,8 +1484,10 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) { } // Signal result channels. + // + // TODO(yy): don't remove it here. s.signalAndRemove(&outpoint, Result{ - Tx: spend.SpendingTx, + Tx: tx, Err: err, }) @@ -1324,6 +1496,8 @@ func (s *UtxoSweeper) handleInputSpent(spend *chainntnfs.SpendDetail) { s.removeExclusiveGroup(*input.params.ExclusiveGroup) } } + + return nil } // handleSweep is called when the ticker fires. It will create clusters and diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index c12b04aae..5f1fed593 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -1,6 +1,7 @@ package sweep import ( + "errors" "os" "runtime/pprof" "testing" @@ -2025,3 +2026,151 @@ func TestGetInputLists(t *testing.T) { }) } } + +// TestMarkInputsPublished checks that given a list of inputs with different +// states, only the state `StatePendingPublish` will be marked as `Published`. +func TestMarkInputsPublished(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock sweeper store. + mockStore := NewMockSweeperStore() + + // Create a test TxRecord and a dummy error. + dummyTR := &TxRecord{} + dummyErr := errors.New("dummy error") + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + }) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + } + + // First, check that when an error is returned from db, it's properly + // returned here. + mockStore.On("StoreTx", dummyTR).Return(dummyErr).Once() + err := s.markInputsPublished(dummyTR, nil) + require.ErrorIs(err, dummyErr) + + // We also expect the record has been marked as published. + require.True(dummyTR.Published) + + // Then, check that the target input has will be correctly marked as + // published. + // + // Mock the store to return nil + mockStore.On("StoreTx", dummyTR).Return(nil).Once() + + // Mark the test inputs. We expect the non-exist input and the + // inputInit to be skipped, and the final input to be marked as + // published. + err = s.markInputsPublished(dummyTR, []*wire.TxIn{ + inputNotExist, inputInit, inputPendingPublish, + }) + require.NoError(err) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 2) + + // We expect the init input's state to stay unchanged. + require.Equal(StateInit, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish input's is now marked as published. + require.Equal(StatePublished, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // Assert mocked statements are executed as expected. + mockStore.AssertExpectations(t) +} + +// TestMarkInputsPublishFailed checks that given a list of inputs with +// different states, only the state `StatePendingPublish` will be marked as +// `PublishFailed`. +func TestMarkInputsPublishFailed(t *testing.T) { + t.Parallel() + + require := require.New(t) + + // Create a mock sweeper store. + mockStore := NewMockSweeperStore() + + // Create a test sweeper. + s := New(&UtxoSweeperConfig{ + Store: mockStore, + }) + + // Create three testing inputs. + // + // inputNotExist specifies an input that's not found in the sweeper's + // `pendingInputs` map. + inputNotExist := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 1}, + } + + // inputInit specifies a newly created input. When marking this as + // published, we should see an error log as this input hasn't been + // published yet. + inputInit := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 2}, + } + s.pendingInputs[inputInit.PreviousOutPoint] = &pendingInput{ + state: StateInit, + } + + // inputPendingPublish specifies an input that's about to be published. + inputPendingPublish := &wire.TxIn{ + PreviousOutPoint: wire.OutPoint{Index: 3}, + } + s.pendingInputs[inputPendingPublish.PreviousOutPoint] = &pendingInput{ + state: StatePendingPublish, + } + + // Mark the test inputs. We expect the non-exist input and the + // inputInit to be skipped, and the final input to be marked as + // published. + s.markInputsPublishFailed([]*wire.TxIn{ + inputNotExist, inputInit, inputPendingPublish, + }) + + // We expect unchanged number of pending inputs. + require.Len(s.pendingInputs, 2) + + // We expect the init input's state to stay unchanged. + require.Equal(StateInit, + s.pendingInputs[inputInit.PreviousOutPoint].state) + + // We expect the pending-publish input's is now marked as publish + // failed. + require.Equal(StatePublishFailed, + s.pendingInputs[inputPendingPublish.PreviousOutPoint].state) + + // Assert mocked statements are executed as expected. + mockStore.AssertExpectations(t) +}