diff --git a/watchtower/wtclient/backup_task.go b/watchtower/wtclient/backup_task.go index dd1d4de81..17c4c5a7b 100644 --- a/watchtower/wtclient/backup_task.go +++ b/watchtower/wtclient/backup_task.go @@ -43,6 +43,7 @@ type backupTask struct { toLocalInput input.Input toRemoteInput input.Input totalAmt btcutil.Amount + sweepPkScript []byte // session-dependent variables @@ -53,7 +54,8 @@ type backupTask struct { // newBackupTask initializes a new backupTask and populates all state-dependent // variables. func newBackupTask(chanID *lnwire.ChannelID, - breachInfo *lnwallet.BreachRetribution) *backupTask { + breachInfo *lnwallet.BreachRetribution, + sweepPkScript []byte) *backupTask { // Parse the non-dust outputs from the breach transaction, // simultaneously computing the total amount contained in the inputs @@ -100,6 +102,7 @@ func newBackupTask(chanID *lnwire.ChannelID, toLocalInput: toLocalInput, toRemoteInput: toRemoteInput, totalAmt: btcutil.Amount(totalAmt), + sweepPkScript: sweepPkScript, } } @@ -122,17 +125,13 @@ func (t *backupTask) inputs() map[wire.OutPoint]input.Input { // SessionInfo's policy. If no error is returned, the task has been bound to the // session and can be queued to upload to the tower. Otherwise, the bind failed // and should be rescheduled with a different session. -func (t *backupTask) bindSession(session *wtdb.SessionInfo, - sweepPkScript []byte) error { +func (t *backupTask) bindSession(session *wtdb.SessionInfo) error { // First we'll begin by deriving a weight estimate for the justice // transaction. The final weight can be different depending on whether // the watchtower is taking a reward. var weightEstimate input.TxWeightEstimator - // All justice transactions have a p2wkh output paying to the victim. - weightEstimate.AddP2WKHOutput() - // Next, add the contribution from the inputs that are present on this // breach transaction. if t.toLocalInput != nil { @@ -142,18 +141,27 @@ func (t *backupTask) bindSession(session *wtdb.SessionInfo, weightEstimate.AddWitnessInput(input.P2WKHWitnessSize) } + // All justice transactions have a p2wkh output paying to the victim. + weightEstimate.AddP2WKHOutput() + + // If the justice transaction has a reward output, add the output's + // contribution to the weight estimate. + if session.Policy.BlobType.Has(blob.FlagReward) { + weightEstimate.AddP2WKHOutput() + } + // Now, compute the output values depending on whether FlagReward is set // in the current session's policy. outputs, err := session.Policy.ComputeJusticeTxOuts( t.totalAmt, int64(weightEstimate.Weight()), - sweepPkScript, session.RewardAddress, + t.sweepPkScript, session.RewardAddress, ) if err != nil { return err } - t.outputs = outputs t.blobType = session.Policy.BlobType + t.outputs = outputs return nil } @@ -164,7 +172,7 @@ func (t *backupTask) bindSession(session *wtdb.SessionInfo, // session-dependent variables, and signs the resulting transaction. The // required pieces from signatures, witness scripts, etc are then packaged into // a JusticeKit and encrypted using the breach transaction's key. -func (t *backupTask) craftSessionPayload(sweepPkScript []byte, +func (t *backupTask) craftSessionPayload( signer input.Signer) (wtdb.BreachHint, []byte, error) { var hint wtdb.BreachHint @@ -173,7 +181,7 @@ func (t *backupTask) craftSessionPayload(sweepPkScript []byte, // to-local script, and the remote CSV delay. keyRing := t.breachInfo.KeyRing justiceKit := &blob.JusticeKit{ - SweepAddress: sweepPkScript, + SweepAddress: t.sweepPkScript, RevocationPubKey: toBlobPubKey(keyRing.RevocationKey), LocalDelayPubKey: toBlobPubKey(keyRing.DelayKey), CSVDelay: t.breachInfo.RemoteDelay, diff --git a/watchtower/wtclient/backup_task_internal_test.go b/watchtower/wtclient/backup_task_internal_test.go index 2d39c6005..c38e5c974 100644 --- a/watchtower/wtclient/backup_task_internal_test.go +++ b/watchtower/wtclient/backup_task_internal_test.go @@ -312,7 +312,7 @@ var backupTaskTests = []backupTaskTest{ blobTypeCommitReward, // blobType 1000, // sweepFeeRate addrScript, // rewardScript - 296241, // expSweepAmt + 296117, // expSweepAmt 3000, // expRewardAmt nil, // bindErr ), @@ -324,7 +324,7 @@ var backupTaskTests = []backupTaskTest{ blobTypeCommitReward, // blobType 1000, // sweepFeeRate addrScript, // rewardScript - 197514, // expSweepAmt + 197390, // expSweepAmt 2000, // expRewardAmt nil, // bindErr ), @@ -336,7 +336,7 @@ var backupTaskTests = []backupTaskTest{ blobTypeCommitReward, // blobType 1000, // sweepFeeRate addrScript, // rewardScript - 98561, // expSweepAmt + 98437, // expSweepAmt 1000, // expRewardAmt nil, // bindErr ), @@ -346,7 +346,7 @@ var backupTaskTests = []backupTaskTest{ 0, // toLocalAmt 100000, // toRemoteAmt blobTypeCommitReward, // blobType - 225000, // sweepFeeRate + 175000, // sweepFeeRate addrScript, // rewardScript 0, // expSweepAmt 0, // expRewardAmt @@ -397,7 +397,7 @@ func TestBackupTask(t *testing.T) { func testBackupTask(t *testing.T, test backupTaskTest) { // Create a new backupTask from the channel id and breach info. - task := newBackupTask(&test.chanID, test.breachInfo) + task := newBackupTask(&test.chanID, test.breachInfo, test.expSweepScript) // Assert that all parameters set during initialization are properly // populated. @@ -452,7 +452,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Now, bind the session to the task. If successful, this locks in the // session's negotiated parameters and allows the backup task to derive // the final free variables in the justice transaction. - err := task.bindSession(test.session, test.expSweepScript) + err := task.bindSession(test.session) if err != test.bindErr { t.Fatalf("expected: %v when binding session, got: %v", test.bindErr, err) @@ -509,9 +509,7 @@ func testBackupTask(t *testing.T, test backupTaskTest) { // Now, we'll construct, sign, and encrypt the blob containing the parts // needed to reconstruct the justice transaction. - hint, encBlob, err := task.craftSessionPayload( - test.expSweepScript, test.signer, - ) + hint, encBlob, err := task.craftSessionPayload(test.signer) if err != nil { t.Fatalf("unable to craft session payload: %v", err) }