wtclient: only fetch retribution info when needed.

Only construct the retribution info at the time that the backup task is
being bound to a session.
This commit is contained in:
Elle Mouton 2023-02-02 12:12:36 +02:00
parent 458ac32146
commit 2371bbf09a
No known key found for this signature in database
GPG Key ID: D7D916376026F177
4 changed files with 128 additions and 114 deletions

View File

@ -10,7 +10,6 @@ import (
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/txscript"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/input"
"github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
@ -40,7 +39,6 @@ import (
type backupTask struct { type backupTask struct {
id wtdb.BackupID id wtdb.BackupID
breachInfo *lnwallet.BreachRetribution breachInfo *lnwallet.BreachRetribution
chanType channeldb.ChannelType
// state-dependent variables // state-dependent variables
@ -55,11 +53,79 @@ type backupTask struct {
outputs []*wire.TxOut outputs []*wire.TxOut
} }
// newBackupTask initializes a new backupTask and populates all state-dependent // newBackupTask initializes a new backupTask.
// variables. func newBackupTask(id wtdb.BackupID, sweepPkScript []byte) *backupTask {
func newBackupTask(chanID *lnwire.ChannelID, return &backupTask{
breachInfo *lnwallet.BreachRetribution, id: id,
sweepPkScript []byte, chanType channeldb.ChannelType) *backupTask { sweepPkScript: sweepPkScript,
}
}
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69 in a later stage.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) {
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
// bindSession first populates all state-dependent variables of the task. Then
// it determines if the backupTask is compatible with the passed SessionInfo's
// policy. If no error is returned, the task has been bound to the session and
// can be queued to upload to the tower. Otherwise, the bind failed and should
// be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody,
newBreachRetribution BreachRetributionBuilder) error {
breachInfo, chanType, err := newBreachRetribution(
t.id.ChanID, t.id.CommitHeight,
)
if err != nil {
return err
}
// Parse the non-dust outputs from the breach transaction, // Parse the non-dust outputs from the breach transaction,
// simultaneously computing the total amount contained in the inputs // simultaneously computing the total amount contained in the inputs
@ -123,76 +189,11 @@ func newBackupTask(chanID *lnwire.ChannelID,
totalAmt += breachInfo.LocalOutputSignDesc.Output.Value totalAmt += breachInfo.LocalOutputSignDesc.Output.Value
} }
return &backupTask{ t.breachInfo = breachInfo
id: wtdb.BackupID{ t.toLocalInput = toLocalInput
ChanID: *chanID, t.toRemoteInput = toRemoteInput
CommitHeight: breachInfo.RevokedStateNum, t.totalAmt = btcutil.Amount(totalAmt)
},
breachInfo: breachInfo,
chanType: chanType,
toLocalInput: toLocalInput,
toRemoteInput: toRemoteInput,
totalAmt: btcutil.Amount(totalAmt),
sweepPkScript: sweepPkScript,
}
}
// inputs returns all non-dust inputs that we will attempt to spend from.
//
// NOTE: Ordering of the inputs is not critical as we sort the transaction with
// BIP69.
func (t *backupTask) inputs() map[wire.OutPoint]input.Input {
inputs := make(map[wire.OutPoint]input.Input)
if t.toLocalInput != nil {
inputs[*t.toLocalInput.OutPoint()] = t.toLocalInput
}
if t.toRemoteInput != nil {
inputs[*t.toRemoteInput.OutPoint()] = t.toRemoteInput
}
return inputs
}
// addrType returns the type of an address after parsing it and matching it to
// the set of known script templates.
func addrType(pkScript []byte) txscript.ScriptClass {
// We pass in a set of dummy chain params here as they're only needed
// to make the address struct, which we're ignoring anyway (scripts are
// always the same, it's addresses that change across chains).
scriptClass, _, _, _ := txscript.ExtractPkScriptAddrs(
pkScript, &chaincfg.MainNetParams,
)
return scriptClass
}
// addScriptWeight parses the passed pkScript and adds the computed weight cost
// were the script to be added to the justice transaction.
func addScriptWeight(weightEstimate *input.TxWeightEstimator,
pkScript []byte) error {
switch addrType(pkScript) { //nolint: whitespace
case txscript.WitnessV0PubKeyHashTy:
weightEstimate.AddP2WKHOutput()
case txscript.WitnessV0ScriptHashTy:
weightEstimate.AddP2WSHOutput()
case txscript.WitnessV1TaprootTy:
weightEstimate.AddP2TROutput()
default:
return fmt.Errorf("invalid addr type: %v", addrType(pkScript))
}
return nil
}
// bindSession determines if the backupTask is compatible with the passed
// SessionInfo's policy. If no error is returned, the task has been bound to the
// session and can be queued to upload to the tower. Otherwise, the bind failed
// and should be rescheduled with a different session.
func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// First we'll begin by deriving a weight estimate for the justice // First we'll begin by deriving a weight estimate for the justice
// transaction. The final weight can be different depending on whether // transaction. The final weight can be different depending on whether
// the watchtower is taking a reward. // the watchtower is taking a reward.
@ -208,7 +209,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// original weight estimate. For anchor channels we'll go ahead // original weight estimate. For anchor channels we'll go ahead
// an use the correct penalty witness when signing our justice // an use the correct penalty witness when signing our justice
// transactions. // transactions.
if t.chanType.HasAnchors() { if chanType.HasAnchors() {
weightEstimate.AddWitnessInput( weightEstimate.AddWitnessInput(
input.ToLocalPenaltyWitnessSize, input.ToLocalPenaltyWitnessSize,
) )
@ -222,7 +223,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// Legacy channels (both tweaked and non-tweaked) spend from // Legacy channels (both tweaked and non-tweaked) spend from
// P2WKH output. Anchor channels spend a to-remote confirmed // P2WKH output. Anchor channels spend a to-remote confirmed
// P2WSH output. // P2WSH output.
if t.chanType.HasAnchors() { if chanType.HasAnchors() {
weightEstimate.AddWitnessInput( weightEstimate.AddWitnessInput(
input.ToRemoteConfirmedWitnessSize, input.ToRemoteConfirmedWitnessSize,
) )
@ -233,7 +234,7 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
// All justice transactions will either use segwit v0 (p2wkh + p2wsh) // All justice transactions will either use segwit v0 (p2wkh + p2wsh)
// or segwit v1 (p2tr). // or segwit v1 (p2tr).
err := addScriptWeight(&weightEstimate, t.sweepPkScript) err = addScriptWeight(&weightEstimate, t.sweepPkScript)
if err != nil { if err != nil {
return err return err
} }
@ -247,9 +248,9 @@ func (t *backupTask) bindSession(session *wtdb.ClientSessionBody) error {
} }
} }
if t.chanType.HasAnchors() != session.Policy.IsAnchorChannel() { if chanType.HasAnchors() != session.Policy.IsAnchorChannel() {
log.Criticalf("Invalid task (has_anchors=%t) for session "+ log.Criticalf("Invalid task (has_anchors=%t) for session "+
"(has_anchors=%t)", t.chanType.HasAnchors(), "(has_anchors=%t)", chanType.HasAnchors(),
session.Policy.IsAnchorChannel()) session.Policy.IsAnchorChannel())
} }

View File

@ -483,19 +483,19 @@ func TestBackupTask(t *testing.T) {
func testBackupTask(t *testing.T, test backupTaskTest) { func testBackupTask(t *testing.T, test backupTaskTest) {
// Create a new backupTask from the channel id and breach info. // Create a new backupTask from the channel id and breach info.
task := newBackupTask( id := wtdb.BackupID{
&test.chanID, test.breachInfo, test.expSweepScript, ChanID: test.chanID,
test.chanType, CommitHeight: test.breachInfo.RevokedStateNum,
) }
task := newBackupTask(id, test.expSweepScript)
// Assert that all parameters set during initialization are properly // getBreachInfo is a helper closure that returns the breach retribution
// populated. // info and channel type for the given channel and commit height.
require.Equal(t, test.chanID, task.id.ChanID) getBreachInfo := func(id lnwire.ChannelID, commitHeight uint64) (
require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight) *lnwallet.BreachRetribution, channeldb.ChannelType, error) {
require.Equal(t, test.expTotalAmt, task.totalAmt)
require.Equal(t, test.breachInfo, task.breachInfo) return test.breachInfo, test.chanType, nil
require.Equal(t, test.expToLocalInput, task.toLocalInput) }
require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
// Reconstruct the expected input.Inputs that will be returned by the // Reconstruct the expected input.Inputs that will be returned by the
// task's inputs() method. // task's inputs() method.
@ -515,9 +515,18 @@ func testBackupTask(t *testing.T, test backupTaskTest) {
// Now, bind the session to the task. If successful, this locks in the // Now, bind the session to the task. If successful, this locks in the
// session's negotiated parameters and allows the backup task to derive // session's negotiated parameters and allows the backup task to derive
// the final free variables in the justice transaction. // the final free variables in the justice transaction.
err := task.bindSession(test.session) err := task.bindSession(test.session, getBreachInfo)
require.ErrorIs(t, err, test.bindErr) require.ErrorIs(t, err, test.bindErr)
// Assert that all parameters set during after binding the backup task
// are properly populated.
require.Equal(t, test.chanID, task.id.ChanID)
require.Equal(t, test.breachInfo.RevokedStateNum, task.id.CommitHeight)
require.Equal(t, test.expTotalAmt, task.totalAmt)
require.Equal(t, test.breachInfo, task.breachInfo)
require.Equal(t, test.expToLocalInput, task.toLocalInput)
require.Equal(t, test.expToRemoteInput, task.toRemoteInput)
// Exit early if the bind was supposed to fail. But first, we check that // Exit early if the bind was supposed to fail. But first, we check that
// all fields set during a bind are still unset. This ensure that a // all fields set during a bind are still unset. This ensure that a
// failed bind doesn't have side-effects if the task is retried with a // failed bind doesn't have side-effects if the task is retried with a

View File

@ -829,17 +829,12 @@ func (c *TowerClient) BackupState(chanID *lnwire.ChannelID,
c.chanCommitHeights[*chanID] = stateNum c.chanCommitHeights[*chanID] = stateNum
c.backupMu.Unlock() c.backupMu.Unlock()
// Fetch the breach retribution info and channel type. id := wtdb.BackupID{
breachInfo, chanType, err := c.cfg.BuildBreachRetribution( ChanID: *chanID,
*chanID, stateNum, CommitHeight: stateNum,
)
if err != nil {
return err
} }
task := newBackupTask( task := newBackupTask(id, summary.SweepPkScript)
chanID, breachInfo, summary.SweepPkScript, chanType,
)
return c.pipeline.QueueBackupTask(task) return c.pipeline.QueueBackupTask(task)
} }
@ -1543,16 +1538,17 @@ func (c *TowerClient) newSessionQueue(s *ClientSession,
updates []wtdb.CommittedUpdate) *sessionQueue { updates []wtdb.CommittedUpdate) *sessionQueue {
return newSessionQueue(&sessionQueueConfig{ return newSessionQueue(&sessionQueueConfig{
ClientSession: s, ClientSession: s,
ChainHash: c.cfg.ChainHash, ChainHash: c.cfg.ChainHash,
Dial: c.dial, Dial: c.dial,
ReadMessage: c.readMessage, ReadMessage: c.readMessage,
SendMessage: c.sendMessage, SendMessage: c.sendMessage,
Signer: c.cfg.Signer, Signer: c.cfg.Signer,
DB: c.cfg.DB, DB: c.cfg.DB,
MinBackoff: c.cfg.MinBackoff, MinBackoff: c.cfg.MinBackoff,
MaxBackoff: c.cfg.MaxBackoff, MaxBackoff: c.cfg.MaxBackoff,
Log: c.log, Log: c.log,
BuildBreachRetribution: c.cfg.BuildBreachRetribution,
}, updates) }, updates)
} }

View File

@ -57,6 +57,11 @@ type sessionQueueConfig struct {
// for justice transaction inputs. // for justice transaction inputs.
Signer input.Signer Signer input.Signer
// BuildBreachRetribution is a function closure that allows the client
// to fetch the breach retribution info for a certain channel at a
// certain revoked commitment height.
BuildBreachRetribution BreachRetributionBuilder
// DB provides access to the client's stable storage. // DB provides access to the client's stable storage.
DB DB DB DB
@ -220,7 +225,10 @@ func (q *sessionQueue) AcceptTask(task *backupTask) (reserveStatus, bool) {
// //
// TODO(conner): queue backups and retry with different session params. // TODO(conner): queue backups and retry with different session params.
case reserveAvailable: case reserveAvailable:
err := task.bindSession(&q.cfg.ClientSession.ClientSessionBody) err := task.bindSession(
&q.cfg.ClientSession.ClientSessionBody,
q.cfg.BuildBreachRetribution,
)
if err != nil { if err != nil {
q.queueCond.L.Unlock() q.queueCond.L.Unlock()
q.log.Debugf("SessionQueue(%s) rejected %v: %v ", q.log.Debugf("SessionQueue(%s) rejected %v: %v ",