sweep: add requestID to monitorRecord

This way we can greatly simplify the method signatures, also paving the
upcoming changes where we wanna make it clear when updating the
monitorRecord, we only touch a portion of it.
This commit is contained in:
yyforyongyu 2025-01-21 01:19:02 +08:00
parent bde5124e1b
commit 7eea7a7e9a
No known key found for this signature in database
GPG Key ID: 9BCD95C4FF296868
2 changed files with 84 additions and 63 deletions

View File

@ -410,34 +410,35 @@ func (t *TxPublisher) Broadcast(req *BumpRequest) <-chan *BumpResult {
lnutils.SpewLogClosure(req))
// Store the request.
requestID, record := t.storeInitialRecord(req)
record := t.storeInitialRecord(req)
// Create a chan to send the result to the caller.
subscriber := make(chan *BumpResult, 1)
t.subscriberChans.Store(requestID, subscriber)
t.subscriberChans.Store(record.requestID, subscriber)
// Publish the tx immediately if specified.
if req.Immediate {
t.handleInitialBroadcast(record, requestID)
t.handleInitialBroadcast(record)
}
return subscriber
}
// storeInitialRecord initializes a monitor record and saves it in the map.
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) (
uint64, *monitorRecord) {
func (t *TxPublisher) storeInitialRecord(req *BumpRequest) *monitorRecord {
// Increase the request counter.
//
// NOTE: this is the only place where we increase the counter.
requestID := t.requestCounter.Add(1)
// Register the record.
record := &monitorRecord{req: req}
record := &monitorRecord{
requestID: requestID,
req: req,
}
t.records.Store(requestID, record)
return requestID, record
return record
}
// storeRecord stores the given record in the records map.
@ -446,6 +447,7 @@ func (t *TxPublisher) storeRecord(requestID uint64, sweepCtx *sweepTxCtx,
// Register the record.
t.records.Store(requestID, &monitorRecord{
requestID: requestID,
tx: sweepCtx.tx,
req: req,
feeFunction: f,
@ -461,16 +463,25 @@ func (t *TxPublisher) Name() string {
// initializeTx initializes a fee function and creates an RBF-compliant tx. If
// succeeded, the initial tx is stored in the records map.
func (t *TxPublisher) initializeTx(requestID uint64, req *BumpRequest) error {
func (t *TxPublisher) initializeTx(r *monitorRecord) error {
// Create a fee bumping algorithm to be used for future RBF.
feeAlgo, err := t.initializeFeeFunction(req)
feeAlgo, err := t.initializeFeeFunction(r.req)
if err != nil {
return fmt.Errorf("init fee function: %w", err)
}
// Attach the newly created fee function.
//
// TODO(yy): current we'd initialize a monitorRecord before creating the
// fee function, while we could instead create the fee function first
// then save it to the record. To make this happen we need to change the
// conf target calculation below since we would be initializing the fee
// function one block before.
r.feeFunction = feeAlgo
// Create the initial tx to be broadcasted. This tx is guaranteed to
// comply with the RBF restrictions.
err = t.createRBFCompliantTx(requestID, req, feeAlgo)
err = t.createRBFCompliantTx(r)
if err != nil {
return fmt.Errorf("create RBF-compliant tx: %w", err)
}
@ -511,24 +522,24 @@ func (t *TxPublisher) initializeFeeFunction(
// so by creating a tx, validate it using `TestMempoolAccept`, and bump its fee
// and redo the process until the tx is valid, or return an error when non-RBF
// related errors occur or the budget has been used up.
func (t *TxPublisher) createRBFCompliantTx(requestID uint64, req *BumpRequest,
f FeeFunction) error {
func (t *TxPublisher) createRBFCompliantTx(r *monitorRecord) error {
f := r.feeFunction
for {
// Create a new tx with the given fee rate and check its
// mempool acceptance.
sweepCtx, err := t.createAndCheckTx(req, f)
sweepCtx, err := t.createAndCheckTx(r.req, f)
switch {
case err == nil:
// The tx is valid, store it.
t.storeRecord(requestID, sweepCtx, req, f)
t.storeRecord(r.requestID, sweepCtx, r.req, f)
log.Infof("Created initial sweep tx=%v for %v inputs: "+
"feerate=%v, fee=%v, inputs:\n%v",
sweepCtx.tx.TxHash(), len(req.Inputs),
sweepCtx.tx.TxHash(), len(r.req.Inputs),
f.FeeRate(), sweepCtx.fee,
inputTypeSummary(req.Inputs))
inputTypeSummary(r.req.Inputs))
return nil
@ -773,6 +784,9 @@ func (t *TxPublisher) handleResult(result *BumpResult) {
// monitorRecord is used to keep track of the tx being monitored by the
// publisher internally.
type monitorRecord struct {
// requestID is the ID of the request that created this record.
requestID uint64
// tx is the tx being monitored.
tx *wire.MsgTx
@ -915,35 +929,35 @@ func (t *TxPublisher) processRecords() {
t.records.ForEach(visitor)
// Handle the initial broadcast.
for requestID, r := range initialRecords {
t.handleInitialBroadcast(r, requestID)
for _, r := range initialRecords {
t.handleInitialBroadcast(r)
}
// For records that are confirmed, we'll notify the caller about this
// result.
for requestID, r := range confirmedRecords {
for _, r := range confirmedRecords {
log.Debugf("Tx=%v is confirmed", r.tx.TxHash())
t.wg.Add(1)
go t.handleTxConfirmed(r, requestID)
go t.handleTxConfirmed(r)
}
// Get the current height to be used in the following goroutines.
currentHeight := t.currentHeight.Load()
// For records that are not confirmed, we perform a fee bump if needed.
for requestID, r := range feeBumpRecords {
for _, r := range feeBumpRecords {
log.Debugf("Attempting to fee bump Tx=%v", r.tx.TxHash())
t.wg.Add(1)
go t.handleFeeBumpTx(requestID, r, currentHeight)
go t.handleFeeBumpTx(r, currentHeight)
}
// For records that are failed, we'll notify the caller about this
// result.
for requestID, r := range failedRecords {
for _, r := range failedRecords {
log.Debugf("Tx=%v has inputs been spent by a third party, "+
"failing it now", r.tx.TxHash())
t.wg.Add(1)
go t.handleThirdPartySpent(r, requestID)
go t.handleThirdPartySpent(r)
}
}
@ -951,7 +965,7 @@ func (t *TxPublisher) processRecords() {
// notify the subscriber then remove the record from the maps .
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
func (t *TxPublisher) handleTxConfirmed(r *monitorRecord) {
defer t.wg.Done()
// Create a result that will be sent to the resultChan which is
@ -959,7 +973,7 @@ func (t *TxPublisher) handleTxConfirmed(r *monitorRecord, requestID uint64) {
result := &BumpResult{
Event: TxConfirmed,
Tx: r.tx,
requestID: requestID,
requestID: r.requestID,
Fee: r.fee,
FeeRate: r.feeFunction.FeeRate(),
}
@ -1017,10 +1031,8 @@ func (t *TxPublisher) handleInitialTxError(requestID uint64, err error) {
// 1. init a fee function based on the given strategy.
// 2. create an RBF-compliant tx and monitor it for confirmation.
// 3. notify the initial broadcast result back to the caller.
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
requestID uint64) {
log.Debugf("Initial broadcast for requestID=%v", requestID)
func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord) {
log.Debugf("Initial broadcast for requestID=%v", r.requestID)
var (
result *BumpResult
@ -1031,18 +1043,18 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
// RBF rules.
//
// Create the initial tx to be broadcasted.
err = t.initializeTx(requestID, r.req)
err = t.initializeTx(r)
if err != nil {
log.Errorf("Initial broadcast failed: %v", err)
// We now handle the initialization error and exit.
t.handleInitialTxError(requestID, err)
t.handleInitialTxError(r.requestID, err)
return
}
// Successfully created the first tx, now broadcast it.
result, err = t.broadcast(requestID)
result, err = t.broadcast(r.requestID)
if err != nil {
// The broadcast failed, which can only happen if the tx record
// cannot be found or the aux sweeper returns an error. In
@ -1051,7 +1063,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
result = &BumpResult{
Event: TxFailed,
Err: err,
requestID: requestID,
requestID: r.requestID,
}
}
@ -1062,9 +1074,7 @@ func (t *TxPublisher) handleInitialBroadcast(r *monitorRecord,
// attempt to bump the fee of the tx.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,
currentHeight int32) {
func (t *TxPublisher) handleFeeBumpTx(r *monitorRecord, currentHeight int32) {
defer t.wg.Done()
oldTxid := r.tx.TxHash()
@ -1095,7 +1105,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,
// The fee function now has a new fee rate, we will use it to bump the
// fee of the tx.
resultOpt := t.createAndPublishTx(requestID, r)
resultOpt := t.createAndPublishTx(r)
// If there's a result, we will notify the caller about the result.
resultOpt.WhenSome(func(result BumpResult) {
@ -1109,9 +1119,7 @@ func (t *TxPublisher) handleFeeBumpTx(requestID uint64, r *monitorRecord,
// and send a TxFailed event to the subscriber.
//
// NOTE: Must be run as a goroutine to avoid blocking on sending the result.
func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
requestID uint64) {
func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord) {
defer t.wg.Done()
// Create a result that will be sent to the resultChan which is
@ -1123,7 +1131,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
result := &BumpResult{
Event: TxFailed,
Tx: r.tx,
requestID: requestID,
requestID: r.requestID,
Err: ErrThirdPartySpent,
}
@ -1134,7 +1142,7 @@ func (t *TxPublisher) handleThirdPartySpent(r *monitorRecord,
// createAndPublishTx creates a new tx with a higher fee rate and publishes it
// to the network. It will update the record with the new tx and fee rate if
// successfully created, and return the result when published successfully.
func (t *TxPublisher) createAndPublishTx(requestID uint64,
func (t *TxPublisher) createAndPublishTx(
r *monitorRecord) fn.Option[BumpResult] {
// Fetch the old tx.
@ -1185,16 +1193,16 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64,
Event: TxFailed,
Tx: oldTx,
Err: err,
requestID: requestID,
requestID: r.requestID,
})
}
// The tx has been created without any errors, we now register a new
// record by overwriting the same requestID.
t.storeRecord(requestID, sweepCtx, r.req, r.feeFunction)
t.storeRecord(r.requestID, sweepCtx, r.req, r.feeFunction)
// Attempt to broadcast this new tx.
result, err := t.broadcast(requestID)
result, err := t.broadcast(r.requestID)
if err != nil {
log.Infof("Failed to broadcast replacement tx %v: %v",
sweepCtx.tx.TxHash(), err)

View File

@ -664,11 +664,19 @@ func TestCreateRBFCompliantTx(t *testing.T) {
tc := tc
rid := requestCounter.Add(1)
// Create a test record.
record := &monitorRecord{
requestID: rid,
req: req,
feeFunction: m.feeFunc,
}
t.Run(tc.name, func(t *testing.T) {
tc.setupMock()
// Call the method under test.
err := tp.createRBFCompliantTx(rid, req, m.feeFunc)
err := tp.createRBFCompliantTx(record)
// Check the result is as expected.
require.ErrorIs(t, err, tc.expectedErr)
@ -1082,6 +1090,7 @@ func TestCreateAnPublishFail(t *testing.T) {
// Overwrite the budget to make it smaller than the fee.
req.Budget = 100
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
tx: &wire.MsgTx{},
@ -1097,7 +1106,7 @@ func TestCreateAnPublishFail(t *testing.T) {
mock.Anything).Return(script, nil)
// Call the createAndPublish method.
resultOpt := tp.createAndPublishTx(requestID, record)
resultOpt := tp.createAndPublishTx(record)
result := resultOpt.UnwrapOrFail(t)
// We expect the result to be TxFailed and the error is set in the
@ -1116,7 +1125,7 @@ func TestCreateAnPublishFail(t *testing.T) {
mock.Anything).Return(lnwallet.ErrMempoolFee).Once()
// Call the createAndPublish method and expect a none option.
resultOpt = tp.createAndPublishTx(requestID, record)
resultOpt = tp.createAndPublishTx(record)
require.True(t, resultOpt.IsNone())
// Mock the testmempoolaccept to return a fee related error that should
@ -1125,7 +1134,7 @@ func TestCreateAnPublishFail(t *testing.T) {
mock.Anything).Return(chain.ErrInsufficientFee).Once()
// Call the createAndPublish method and expect a none option.
resultOpt = tp.createAndPublishTx(requestID, record)
resultOpt = tp.createAndPublishTx(record)
require.True(t, resultOpt.IsNone())
}
@ -1147,6 +1156,7 @@ func TestCreateAnPublishSuccess(t *testing.T) {
// Create a testing monitor record.
req := createTestBumpRequest()
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
tx: &wire.MsgTx{},
@ -1169,7 +1179,7 @@ func TestCreateAnPublishSuccess(t *testing.T) {
mock.Anything, mock.Anything).Return(errDummy).Once()
// Call the createAndPublish method and expect a failure result.
resultOpt := tp.createAndPublishTx(requestID, record)
resultOpt := tp.createAndPublishTx(record)
result := resultOpt.UnwrapOrFail(t)
// We expect the result to be TxFailed and the error is set.
@ -1190,7 +1200,7 @@ func TestCreateAnPublishSuccess(t *testing.T) {
mock.Anything, mock.Anything).Return(nil).Once()
// Call the createAndPublish method and expect a success result.
resultOpt = tp.createAndPublishTx(requestID, record)
resultOpt = tp.createAndPublishTx(record)
result = resultOpt.UnwrapOrFail(t)
require.True(t, resultOpt.IsSome())
@ -1258,7 +1268,7 @@ func TestHandleTxConfirmed(t *testing.T) {
tp.wg.Add(1)
done := make(chan struct{})
go func() {
tp.handleTxConfirmed(record, requestID)
tp.handleTxConfirmed(record)
close(done)
}()
@ -1304,7 +1314,11 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Create a testing monitor record.
req := createTestBumpRequest()
// Create a testing record and put it in the map.
requestID := uint64(1)
record := &monitorRecord{
requestID: requestID,
req: req,
feeFunction: m.feeFunc,
tx: tx,
@ -1317,10 +1331,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
utxoIndex := map[wire.OutPoint]int{
op: 0,
}
// Create a testing record and put it in the map.
fee := btcutil.Amount(1000)
requestID := uint64(1)
// Create a sweepTxCtx.
sweepCtx := &sweepTxCtx{
@ -1345,7 +1356,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Call the method and expect no result received.
tp.wg.Add(1)
go tp.handleFeeBumpTx(requestID, record, testHeight)
go tp.handleFeeBumpTx(record, testHeight)
// Check there's no result sent back.
select {
@ -1359,7 +1370,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
// Call the method and expect no result received.
tp.wg.Add(1)
go tp.handleFeeBumpTx(requestID, record, testHeight)
go tp.handleFeeBumpTx(record, testHeight)
// Check there's no result sent back.
select {
@ -1391,7 +1402,7 @@ func TestHandleFeeBumpTx(t *testing.T) {
//
// NOTE: must be called in a goroutine in case it blocks.
tp.wg.Add(1)
go tp.handleFeeBumpTx(requestID, record, testHeight)
go tp.handleFeeBumpTx(record, testHeight)
select {
case <-time.After(time.Second):
@ -1437,6 +1448,7 @@ func TestProcessRecords(t *testing.T) {
// Create a monitor record that's confirmed.
recordConfirmed := &monitorRecord{
requestID: requestID1,
req: req1,
feeFunction: m.feeFunc,
tx: tx1,
@ -1450,6 +1462,7 @@ func TestProcessRecords(t *testing.T) {
// Create a monitor record that's not confirmed. We know it's not
// confirmed because the num of confirms is zero.
recordFeeBump := &monitorRecord{
requestID: requestID2,
req: req2,
feeFunction: m.feeFunc,
tx: tx2,
@ -1588,7 +1601,7 @@ func TestHandleInitialBroadcastSuccess(t *testing.T) {
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
tp.handleInitialBroadcast(rec)
// Check the result is sent back.
select {
@ -1659,7 +1672,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) {
// Call the method under test and expect an error returned.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
tp.handleInitialBroadcast(rec)
// Check the result is sent back.
select {
@ -1692,7 +1705,7 @@ func TestHandleInitialBroadcastFail(t *testing.T) {
// Call the method under test.
tp.wg.Add(1)
tp.handleInitialBroadcast(rec, rid)
tp.handleInitialBroadcast(rec)
// Check the result is sent back.
select {