From 47f70dae3a9ae4913606ae245aed51a79acf1faf Mon Sep 17 00:00:00 2001
From: Olaoluwa Osuntokun <laolu32@gmail.com>
Date: Wed, 1 Mar 2023 22:15:17 -0800
Subject: [PATCH] contractcourt: update commitSweepResolver for taproot chans

---
 contractcourt/commit_sweep_resolver.go | 80 +++++++++++++++++++-------
 1 file changed, 58 insertions(+), 22 deletions(-)

diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go
index ea78bbbac..839a0025e 100644
--- a/contractcourt/commit_sweep_resolver.go
+++ b/contractcourt/commit_sweep_resolver.go
@@ -57,6 +57,9 @@ type commitSweepResolver struct {
 	// leased channel.
 	leaseExpiry uint32
 
+	// chanType denotes the type of channel the contract belongs to.
+	chanType channeldb.ChannelType
+
 	// currentReport stores the current state of the resolver for reporting
 	// over the rpc interface.
 	currentReport ContractReport
@@ -148,13 +151,15 @@ func waitForSpend(op *wire.OutPoint, pkScript []byte, heightHint uint32,
 	}
 }
 
-// getCommitTxConfHeight waits for confirmation of the commitment tx and returns
-// the confirmation height.
+// getCommitTxConfHeight waits for confirmation of the commitment tx and
+// returns the confirmation height.
 func (c *commitSweepResolver) getCommitTxConfHeight() (uint32, error) {
 	txID := c.commitResolution.SelfOutPoint.Hash
 	signDesc := c.commitResolution.SelfOutputSignDesc
 	pkScript := signDesc.Output.PkScript
+
 	const confDepth = 1
+
 	confChan, err := c.Notifier.RegisterConfirmationsNtfn(
 		&txID, pkScript, confDepth, c.broadcastHeight,
 	)
@@ -245,23 +250,50 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
 		}
 	}
 
+	var (
+		isLocalCommitTx bool
+
+		signDesc = c.commitResolution.SelfOutputSignDesc
+	)
+	switch {
+	// For taproot channels, we'll know if this is the local commit based
+	// on the witness script. For local channels, the witness script has an
+	// OP_DROP value.
+	//
+	// TODO(roasbeef): revisit this after the script changes
+	//  * otherwise need to base off the key in script or the CSV value
+	//  (script num encode)
+	case c.chanType.IsTaproot():
+		scriptLen := len(signDesc.WitnessScript)
+		isLocalCommitTx = signDesc.WitnessScript[scriptLen-1] ==
+			txscript.OP_DROP
+
 	// The output is on our local commitment if the script starts with
 	// OP_IF for the revocation clause. On the remote commitment it will
 	// either be a regular P2WKH or a simple sig spend with a CSV delay.
-	isLocalCommitTx := c.commitResolution.SelfOutputSignDesc.WitnessScript[0] == txscript.OP_IF
+	default:
+		isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF
+	}
 	isDelayedOutput := c.commitResolution.MaturityDelay != 0
 
 	c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput,
 		isLocalCommitTx)
 
-	// There're three types of commitments, those that have tweaks
-	// for the remote key (us in this case), those that don't, and a third
-	// where there is no tweak and the output is delayed. On the local
-	// commitment our output will always be delayed. We'll rely on the
-	// presence of the commitment tweak to to discern which type of
-	// commitment this is.
+	// There're three types of commitments, those that have tweaks for the
+	// remote key (us in this case), those that don't, and a third where
+	// there is no tweak and the output is delayed. On the local commitment
+	// our output will always be delayed. We'll rely on the presence of the
+	// commitment tweak to to discern which type of commitment this is.
 	var witnessType input.WitnessType
 	switch {
+	// The local delayed output for a taproot channel.
+	case isLocalCommitTx && c.chanType.IsTaproot():
+		witnessType = input.TaprootLocalCommitSpend
+
+	// The CSV 1 delayed output for a taproot channel.
+	case !isLocalCommitTx && c.chanType.IsTaproot():
+		witnessType = input.TaprootRemoteCommitSpend
+
 	// Delayed output to us on our local commitment for a channel lease in
 	// which we are the initiator.
 	case isLocalCommitTx && c.hasCLTV():
@@ -293,9 +325,9 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
 
 	c.log.Infof("Sweeping with witness type: %v", witnessType)
 
-	// We'll craft an input with all the information required for
-	// the sweeper to create a fully valid sweeping transaction to
-	// recover these coins.
+	// We'll craft an input with all the information required for the
+	// sweeper to create a fully valid sweeping transaction to recover
+	// these coins.
 	var inp *input.BaseInput
 	if c.hasCLTV() {
 		inp = input.NewCsvInputWithCltv(
@@ -312,6 +344,9 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
 		)
 	}
 
+	// TODO(roasbeef): instead of ading ctrl block to the sign desc, make
+	// new input type, have sweeper set it?
+
 	// With our input constructed, we'll now offer it to the
 	// sweeper.
 	c.log.Infof("sweeping commit output")
@@ -326,28 +361,28 @@ func (c *commitSweepResolver) Resolve() (ContractResolver, error) {
 
 	var sweepTxID chainhash.Hash
 
-	// Sweeper is going to join this input with other inputs if
-	// possible and publish the sweep tx. When the sweep tx
-	// confirms, it signals us through the result channel with the
-	// outcome. Wait for this to happen.
+	// Sweeper is going to join this input with other inputs if possible
+	// and publish the sweep tx. When the sweep tx confirms, it signals us
+	// through the result channel with the outcome. Wait for this to
+	// happen.
 	outcome := channeldb.ResolverOutcomeClaimed
 	select {
 	case sweepResult := <-resultChan:
 		switch sweepResult.Err {
+		// If the remote party was able to sweep this output it's
+		// likely what we sent was actually a revoked commitment.
+		// Report the error and continue to wrap up the contract.
 		case sweep.ErrRemoteSpend:
-			// If the remote party was able to sweep this output
-			// it's likely what we sent was actually a revoked
-			// commitment. Report the error and continue to wrap up
-			// the contract.
 			c.log.Warnf("local commitment output was swept by "+
 				"remote party via %v", sweepResult.Tx.TxHash())
 			outcome = channeldb.ResolverOutcomeUnclaimed
+
+		// No errors, therefore continue processing.
 		case nil:
-			// No errors, therefore continue processing.
 			c.log.Infof("local commitment output fully resolved by "+
 				"sweep tx: %v", sweepResult.Tx.TxHash())
+		// Unknown errors.
 		default:
-			// Unknown errors.
 			c.log.Errorf("unable to sweep input: %v",
 				sweepResult.Err)
 
@@ -404,6 +439,7 @@ func (c *commitSweepResolver) SupplementState(state *channeldb.OpenChannel) {
 		c.leaseExpiry = state.ThawHeight
 	}
 	c.channelInitiator = state.IsInitiator
+	c.chanType = state.ChanType
 }
 
 // hasCLTV denotes whether the resolver must wait for an additional CLTV to