diff --git a/config_builder.go b/config_builder.go index 3e98075c9..d34b217df 100644 --- a/config_builder.go +++ b/config_builder.go @@ -50,6 +50,7 @@ import ( "github.com/lightningnetwork/lnd/rpcperms" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sqldb" + "github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/walletunlocker" "github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower/wtclient" @@ -149,6 +150,7 @@ type ImplementationCfg struct { // ChainControlBuilder is a type that can provide a custom wallet // implementation. ChainControlBuilder + // AuxComponents is a set of auxiliary components that can be used by // lnd for certain custom channel types. AuxComponents @@ -186,6 +188,14 @@ type AuxComponents struct { // AuxChanCloser is an optional channel closer that can be used to // modify the way a coop-close transaction is constructed. AuxChanCloser fn.Option[chancloser.AuxChanCloser] + + // AuxSweeper is an optional interface that can be used to modify the + // way sweep transaction are generated. + AuxSweeper fn.Option[sweep.AuxSweeper] + + // AuxContractResolver is an optional interface that can be used to + // modify the way contracts are resolved. + AuxContractResolver fn.Option[lnwallet.AuxContractResolver] } // DefaultWalletImpl is the default implementation of our normal, btcwallet diff --git a/contractcourt/breach_arbitrator.go b/contractcourt/breach_arbitrator.go index 96fa689c3..f86eab397 100644 --- a/contractcourt/breach_arbitrator.go +++ b/contractcourt/breach_arbitrator.go @@ -16,12 +16,14 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/labels" "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -1074,6 +1076,10 @@ type breachedOutput struct { secondLevelTapTweak [32]byte witnessFunc input.WitnessGenerator + + resolutionBlob fn.Option[tlv.Blob] + + // TODO(roasbeef): function opt and hook into brar } // makeBreachedOutput assembles a new breachedOutput that can be used by the @@ -1181,6 +1187,11 @@ func (bo *breachedOutput) UnconfParent() *input.TxInfo { return nil } +// ResolutionBlob... +func (bo *breachedOutput) ResolutionBlob() fn.Option[tlv.Blob] { + return bo.resolutionBlob +} + // Add compile-time constraint ensuring breachedOutput implements the Input // interface. var _ input.Input = (*breachedOutput)(nil) @@ -1629,13 +1640,13 @@ func taprootBriefcaseFromRetInfo(retInfo *retributionInfo) *taprootBriefcase { // commitment, we'll need to stash the control block. case input.TaprootRemoteCommitSpend: //nolint:lll - tapCase.CtrlBlocks.CommitSweepCtrlBlock = bo.signDesc.ControlBlock + tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = bo.signDesc.ControlBlock // To spend the revoked output again, we'll store the same // control block value as above, but in a different place. case input.TaprootCommitmentRevoke: //nolint:lll - tapCase.CtrlBlocks.RevokeSweepCtrlBlock = bo.signDesc.ControlBlock + tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock = bo.signDesc.ControlBlock // For spending the HTLC outputs, we'll store the first and // second level tweak values. @@ -1649,10 +1660,10 @@ func taprootBriefcaseFromRetInfo(retInfo *retributionInfo) *taprootBriefcase { secondLevelTweak := bo.secondLevelTapTweak //nolint:lll - tapCase.TapTweaks.BreachedHtlcTweaks[resID] = firstLevelTweak + tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID] = firstLevelTweak //nolint:lll - tapCase.TapTweaks.BreachedSecondLevelHltcTweaks[resID] = secondLevelTweak + tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID] = secondLevelTweak } } @@ -1672,13 +1683,13 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase, // commitment, we'll apply the control block. case input.TaprootRemoteCommitSpend: //nolint:lll - bo.signDesc.ControlBlock = tapCase.CtrlBlocks.CommitSweepCtrlBlock + bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock // To spend the revoked output again, we'll apply the same // control block value as above, but to a different place. case input.TaprootCommitmentRevoke: //nolint:lll - bo.signDesc.ControlBlock = tapCase.CtrlBlocks.RevokeSweepCtrlBlock + bo.signDesc.ControlBlock = tapCase.CtrlBlocks.Val.RevokeSweepCtrlBlock // For spending the HTLC outputs, we'll apply the first and // second level tweak values. @@ -1687,7 +1698,8 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase, case input.TaprootHtlcOfferedRevoke: resID := newResolverID(bo.OutPoint()) - tap1, ok := tapCase.TapTweaks.BreachedHtlcTweaks[resID] + //nolint:lll + tap1, ok := tapCase.TapTweaks.Val.BreachedHtlcTweaks[resID] if !ok { return fmt.Errorf("unable to find taproot "+ "tweak for: %v", bo.OutPoint()) @@ -1695,7 +1707,7 @@ func applyTaprootRetInfo(tapCase *taprootBriefcase, bo.signDesc.TapTweak = tap1[:] //nolint:lll - tap2, ok := tapCase.TapTweaks.BreachedSecondLevelHltcTweaks[resID] + tap2, ok := tapCase.TapTweaks.Val.BreachedSecondLevelHltcTweaks[resID] if !ok { return fmt.Errorf("unable to find taproot "+ "tweak for: %v", bo.OutPoint()) diff --git a/contractcourt/breach_arbitrator_test.go b/contractcourt/breach_arbitrator_test.go index babb427ea..896342d42 100644 --- a/contractcourt/breach_arbitrator_test.go +++ b/contractcourt/breach_arbitrator_test.go @@ -1592,6 +1592,7 @@ func testBreachSpends(t *testing.T, test breachTest) { retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, 1, forceCloseTx, fn.None[lnwallet.AuxLeafStore](), + fn.None[lnwallet.AuxContractResolver](), ) require.NoError(t, err, "unable to create breach retribution") @@ -1802,6 +1803,7 @@ func TestBreachDelayedJusticeConfirmation(t *testing.T) { retribution, err := lnwallet.NewBreachRetribution( alice.State(), height, uint32(blockHeight), forceCloseTx, fn.None[lnwallet.AuxLeafStore](), + fn.None[lnwallet.AuxContractResolver](), ) require.NoError(t, err, "unable to create breach retribution") diff --git a/contractcourt/briefcase.go b/contractcourt/briefcase.go index 95f6a933d..11ab83390 100644 --- a/contractcourt/briefcase.go +++ b/contractcourt/briefcase.go @@ -10,9 +10,11 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/lightningnetwork/lnd/channeldb" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/tlv" ) // ContractResolutions is a wrapper struct around the two forms of resolutions @@ -1553,7 +1555,13 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { commitResolution := c.CommitResolution commitSignDesc := commitResolution.SelfOutputSignDesc //nolint:lll - tapCase.CtrlBlocks.CommitSweepCtrlBlock = commitSignDesc.ControlBlock + tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock = commitSignDesc.ControlBlock + + c.CommitResolution.ResolutionBlob.WhenSome(func(b []byte) { + tapCase.CommitBlob = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType2](b), + ) + }) } for _, htlc := range c.HtlcResolutions.IncomingHTLCs { @@ -1571,7 +1579,7 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { htlc.SignedSuccessTx.TxIn[0].PreviousOutPoint, ) //nolint:lll - tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock + tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock // For HTLCs we need to go to the second level for, we // also need to store the control block needed to @@ -1580,12 +1588,12 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { //nolint:lll bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock //nolint:lll - tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock + tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = bridgeCtrlBlock } } else { resID := newResolverID(htlc.ClaimOutpoint) //nolint:lll - tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] = ctrlBlock + tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] = ctrlBlock } } for _, htlc := range c.HtlcResolutions.OutgoingHTLCs { @@ -1603,7 +1611,7 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { htlc.SignedTimeoutTx.TxIn[0].PreviousOutPoint, ) //nolint:lll - tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] = ctrlBlock + tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] = ctrlBlock // For HTLCs we need to go to the second level for, we // also need to store the control block needed to @@ -1614,18 +1622,18 @@ func encodeTaprootAuxData(w io.Writer, c *ContractResolutions) error { //nolint:lll bridgeCtrlBlock := htlc.SignDetails.SignDesc.ControlBlock //nolint:lll - tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock + tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = bridgeCtrlBlock } } else { resID := newResolverID(htlc.ClaimOutpoint) //nolint:lll - tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock + tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] = ctrlBlock } } if c.AnchorResolution != nil { anchorSignDesc := c.AnchorResolution.AnchorSignDescriptor - tapCase.TapTweaks.AnchorTweak = anchorSignDesc.TapTweak + tapCase.TapTweaks.Val.AnchorTweak = anchorSignDesc.TapTweak } return tapCase.Encode(w) @@ -1639,7 +1647,11 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error { if c.CommitResolution != nil { c.CommitResolution.SelfOutputSignDesc.ControlBlock = - tapCase.CtrlBlocks.CommitSweepCtrlBlock + tapCase.CtrlBlocks.Val.CommitSweepCtrlBlock + + tapCase.CommitBlob.WhenSomeV(func(b []byte) { + c.CommitResolution.ResolutionBlob = fn.Some(b) + }) } for i := range c.HtlcResolutions.IncomingHTLCs { @@ -1652,19 +1664,19 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error { ) //nolint:lll - ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] + ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] htlc.SweepSignDesc.ControlBlock = ctrlBlock //nolint:lll if htlc.SignDetails != nil { - bridgeCtrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] + bridgeCtrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock } } else { resID = newResolverID(htlc.ClaimOutpoint) //nolint:lll - ctrlBlock := tapCase.CtrlBlocks.IncomingHtlcCtrlBlocks[resID] + ctrlBlock := tapCase.CtrlBlocks.Val.IncomingHtlcCtrlBlocks[resID] htlc.SweepSignDesc.ControlBlock = ctrlBlock } @@ -1680,19 +1692,19 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error { ) //nolint:lll - ctrlBlock := tapCase.CtrlBlocks.SecondLevelCtrlBlocks[resID] + ctrlBlock := tapCase.CtrlBlocks.Val.SecondLevelCtrlBlocks[resID] htlc.SweepSignDesc.ControlBlock = ctrlBlock //nolint:lll if htlc.SignDetails != nil { - bridgeCtrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] + bridgeCtrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] htlc.SignDetails.SignDesc.ControlBlock = bridgeCtrlBlock } } else { resID = newResolverID(htlc.ClaimOutpoint) //nolint:lll - ctrlBlock := tapCase.CtrlBlocks.OutgoingHtlcCtrlBlocks[resID] + ctrlBlock := tapCase.CtrlBlocks.Val.OutgoingHtlcCtrlBlocks[resID] htlc.SweepSignDesc.ControlBlock = ctrlBlock } @@ -1701,7 +1713,7 @@ func decodeTapRootAuxData(r io.Reader, c *ContractResolutions) error { if c.AnchorResolution != nil { c.AnchorResolution.AnchorSignDescriptor.TapTweak = - tapCase.TapTweaks.AnchorTweak + tapCase.TapTweaks.Val.AnchorTweak } return nil diff --git a/contractcourt/chain_arbitrator.go b/contractcourt/chain_arbitrator.go index d61e47901..c29178b43 100644 --- a/contractcourt/chain_arbitrator.go +++ b/contractcourt/chain_arbitrator.go @@ -225,6 +225,10 @@ type ChainArbitratorConfig struct { // AuxSigner is an optional signer that can be used to sign auxiliary // leaves for certain custom channel types. AuxSigner fn.Option[lnwallet.AuxSigner] + + // AuxResolver is an optional interface that can be used to modify the + // way contracts are resolved. + AuxResolver fn.Option[lnwallet.AuxContractResolver] } // ChainArbitrator is a sub-system that oversees the on-chain resolution of all @@ -314,6 +318,9 @@ func (a *arbChannel) NewAnchorResolutions() (*lnwallet.AnchorResolutions, a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) { chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s)) }) + a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) { + chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) + }) chanMachine, err := lnwallet.NewLightningChannel( a.c.cfg.Signer, channel, nil, chanOpts..., @@ -367,6 +374,9 @@ func (a *arbChannel) ForceCloseChan() (*lnwallet.LocalForceCloseSummary, error) a.c.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) { chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s)) }) + a.c.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) { + chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) + }) // Finally, we'll force close the channel completing // the force close workflow. @@ -581,6 +591,8 @@ func (c *ChainArbitrator) Start() error { isOurAddr: c.cfg.IsOurAddress, contractBreach: breachClosure, extractStateNumHint: lnwallet.GetStateNumHint, + auxLeafStore: c.cfg.AuxLeafStore, + auxResolver: c.cfg.AuxResolver, }, ) if err != nil { @@ -1210,6 +1222,8 @@ func (c *ChainArbitrator) WatchNewChannel(newChan *channeldb.OpenChannel) error ) }, extractStateNumHint: lnwallet.GetStateNumHint, + auxLeafStore: c.cfg.AuxLeafStore, + auxResolver: c.cfg.AuxResolver, }, ) if err != nil { diff --git a/contractcourt/chain_watcher.go b/contractcourt/chain_watcher.go index dd8a62e6d..19bfd8b70 100644 --- a/contractcourt/chain_watcher.go +++ b/contractcourt/chain_watcher.go @@ -192,6 +192,9 @@ type chainWatcherConfig struct { // auxLeafStore can be used to fetch information for custom channels. auxLeafStore fn.Option[lnwallet.AuxLeafStore] + + // auxResolver is used to supplement contract resolution. + auxResolver fn.Option[lnwallet.AuxContractResolver] } // chainWatcher is a system that's assigned to every active channel. The duty @@ -889,7 +892,7 @@ func (c *chainWatcher) handlePossibleBreach(commitSpend *chainntnfs.SpendDetail, spendHeight := uint32(commitSpend.SpendingHeight) retribution, err := lnwallet.NewBreachRetribution( c.cfg.chanState, broadcastStateNum, spendHeight, - commitSpend.SpendingTx, c.cfg.auxLeafStore, + commitSpend.SpendingTx, c.cfg.auxLeafStore, c.cfg.auxResolver, ) switch { @@ -1101,7 +1104,7 @@ func (c *chainWatcher) dispatchLocalForceClose( forceClose, err := lnwallet.NewLocalForceCloseSummary( c.cfg.chanState, c.cfg.signer, commitSpend.SpendingTx, stateNum, - c.cfg.auxLeafStore, + c.cfg.auxLeafStore, c.cfg.auxResolver, ) if err != nil { return err @@ -1193,8 +1196,8 @@ func (c *chainWatcher) dispatchRemoteForceClose( // materials required to let each subscriber sweep the funds in the // channel on-chain. uniClose, err := lnwallet.NewUnilateralCloseSummary( - c.cfg.chanState, c.cfg.signer, commitSpend, - remoteCommit, commitPoint, c.cfg.auxLeafStore, + c.cfg.chanState, c.cfg.signer, commitSpend, remoteCommit, + commitPoint, c.cfg.auxLeafStore, c.cfg.auxResolver, ) if err != nil { return err diff --git a/contractcourt/commit_sweep_resolver.go b/contractcourt/commit_sweep_resolver.go index 296ea38e5..025d1312d 100644 --- a/contractcourt/commit_sweep_resolver.go +++ b/contractcourt/commit_sweep_resolver.go @@ -261,9 +261,7 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { // * otherwise need to base off the key in script or the CSV value // (script num encode) case c.chanType.IsTaproot(): - scriptLen := len(signDesc.WitnessScript) - isLocalCommitTx = signDesc.WitnessScript[scriptLen-1] == - txscript.OP_DROP + isLocalCommitTx = c.commitResolution.MaturityDelay != 1 // The output is on our local commitment if the script starts with // OP_IF for the revocation clause. On the remote commitment it will @@ -271,10 +269,8 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { default: isLocalCommitTx = signDesc.WitnessScript[0] == txscript.OP_IF } - isDelayedOutput := c.commitResolution.MaturityDelay != 0 - c.log.Debugf("isDelayedOutput=%v, isLocalCommitTx=%v", isDelayedOutput, - isLocalCommitTx) + isDelayedOutput := c.commitResolution.MaturityDelay != 0 // There're three types of commitments, those that have tweaks for the // remote key (us in this case), those that don't, and a third where @@ -332,12 +328,18 @@ func (c *commitSweepResolver) Resolve(_ bool) (ContractResolver, error) { &c.commitResolution.SelfOutputSignDesc, c.broadcastHeight, c.commitResolution.MaturityDelay, c.leaseExpiry, + input.WithResolutionBlob( + c.commitResolution.ResolutionBlob, + ), ) } else { inp = input.NewCsvInput( &c.commitResolution.SelfOutPoint, witnessType, &c.commitResolution.SelfOutputSignDesc, c.broadcastHeight, c.commitResolution.MaturityDelay, + input.WithResolutionBlob( + c.commitResolution.ResolutionBlob, + ), ) } diff --git a/contractcourt/htlc_incoming_contest_resolver.go b/contractcourt/htlc_incoming_contest_resolver.go index 6bda4e398..c46ec61af 100644 --- a/contractcourt/htlc_incoming_contest_resolver.go +++ b/contractcourt/htlc_incoming_contest_resolver.go @@ -99,6 +99,14 @@ func (h *htlcIncomingContestResolver) Resolve( return nil, nil } + // If the HTLC has custom records, then for now we'll pause resolution. + if len(h.htlc.CustomRecords) != 0 { + select { //nolint:gosimple + case <-h.quit: + return nil, errResolverShuttingDown + } + } + // First try to parse the payload. If that fails, we can stop resolution // now. payload, nextHopOnionBlob, err := h.decodePayload() diff --git a/contractcourt/htlc_outgoing_contest_resolver.go b/contractcourt/htlc_outgoing_contest_resolver.go index 2466544c9..a075b243f 100644 --- a/contractcourt/htlc_outgoing_contest_resolver.go +++ b/contractcourt/htlc_outgoing_contest_resolver.go @@ -58,6 +58,14 @@ func (h *htlcOutgoingContestResolver) Resolve( return nil, nil } + // If the HTLC has custom records, then for now we'll pause resolution. + if len(h.htlc.CustomRecords) != 0 { + select { //nolint:gosimple + case <-h.quit: + return nil, errResolverShuttingDown + } + } + // Otherwise, we'll watch for two external signals to decide if we'll // morph into another resolver, or fully resolve the contract. // diff --git a/contractcourt/htlc_success_resolver.go b/contractcourt/htlc_success_resolver.go index 6eee939ea..21557ef87 100644 --- a/contractcourt/htlc_success_resolver.go +++ b/contractcourt/htlc_success_resolver.go @@ -123,6 +123,14 @@ func (h *htlcSuccessResolver) Resolve( return nil, nil } + // If the HTLC has custom records, then for now we'll pause resolution. + if len(h.htlc.CustomRecords) != 0 { + select { //nolint:gosimple + case <-h.quit: + return nil, errResolverShuttingDown + } + } + // If we don't have a success transaction, then this means that this is // an output on the remote party's commitment transaction. if h.htlcResolution.SignedSuccessTx == nil { diff --git a/contractcourt/htlc_timeout_resolver.go b/contractcourt/htlc_timeout_resolver.go index 62ff83207..82353caa4 100644 --- a/contractcourt/htlc_timeout_resolver.go +++ b/contractcourt/htlc_timeout_resolver.go @@ -426,6 +426,14 @@ func (h *htlcTimeoutResolver) Resolve( return nil, nil } + // If the HTLC has custom records, then for now we'll pause resolution. + if len(h.htlc.CustomRecords) != 0 { + select { //nolint:gosimple + case <-h.quit: + return nil, errResolverShuttingDown + } + } + // Start by spending the HTLC output, either by broadcasting the // second-level timeout transaction, or directly if this is the remote // commitment. diff --git a/contractcourt/taproot_briefcase.go b/contractcourt/taproot_briefcase.go index 5931a4556..016dee5e0 100644 --- a/contractcourt/taproot_briefcase.go +++ b/contractcourt/taproot_briefcase.go @@ -8,9 +8,6 @@ import ( ) const ( - taprootCtrlBlockType tlv.Type = 0 - taprootTapTweakType tlv.Type = 1 - commitCtrlBlockType tlv.Type = 0 revokeCtrlBlockType tlv.Type = 1 outgoingHtlcCtrlBlockType tlv.Type = 2 @@ -26,36 +23,47 @@ const ( // information we need to sweep taproot outputs. type taprootBriefcase struct { // CtrlBlock is the set of control block for the taproot outputs. - CtrlBlocks *ctrlBlocks + CtrlBlocks tlv.RecordT[tlv.TlvType0, ctrlBlocks] // TapTweaks is the set of taproot tweaks for the taproot outputs that // are to be spent via a keyspend path. This includes anchors, and any // revocation paths. - TapTweaks *tapTweaks + TapTweaks tlv.RecordT[tlv.TlvType1, tapTweaks] + + // CommitBlob is an optional record that contains an opaque blob that + // may be used to properly sweep commitment outputs on a force close + // transaction. + CommitBlob tlv.OptionalRecordT[tlv.TlvType2, tlv.Blob] } +// TODO(roasbeef): morph into new tlv record + // newTaprootBriefcase returns a new instance of the taproot specific briefcase // variant. func newTaprootBriefcase() *taprootBriefcase { return &taprootBriefcase{ - CtrlBlocks: newCtrlBlocks(), - TapTweaks: newTapTweaks(), + CtrlBlocks: tlv.NewRecordT[tlv.TlvType0](newCtrlBlocks()), + TapTweaks: tlv.NewRecordT[tlv.TlvType1](newTapTweaks()), } } // EncodeRecords returns a slice of TLV records that should be encoded. func (t *taprootBriefcase) EncodeRecords() []tlv.Record { - return []tlv.Record{ - newCtrlBlocksRecord(&t.CtrlBlocks), - newTapTweaksRecord(&t.TapTweaks), + records := []tlv.Record{ + t.CtrlBlocks.Record(), t.TapTweaks.Record(), } + + t.CommitBlob.WhenSome(func(r tlv.RecordT[tlv.TlvType2, tlv.Blob]) { + records = append(records, r.Record()) + }) + + return records } // DecodeRecords returns a slice of TLV records that should be decoded. func (t *taprootBriefcase) DecodeRecords() []tlv.Record { return []tlv.Record{ - newCtrlBlocksRecord(&t.CtrlBlocks), - newTapTweaksRecord(&t.TapTweaks), + t.CtrlBlocks.Record(), t.TapTweaks.Record(), } } @@ -71,12 +79,23 @@ func (t *taprootBriefcase) Encode(w io.Writer) error { // Decode decodes the given reader into the target struct. func (t *taprootBriefcase) Decode(r io.Reader) error { - stream, err := tlv.NewStream(t.DecodeRecords()...) + commitBlob := t.CommitBlob.Zero() + records := append(t.DecodeRecords(), commitBlob.Record()) + stream, err := tlv.NewStream(records...) if err != nil { return err } - return stream.Decode(r) + typeMap, err := stream.DecodeWithParsedTypes(r) + if err != nil { + return err + } + + if val, ok := typeMap[t.CommitBlob.TlvType()]; ok && val == nil { + t.CommitBlob = tlv.SomeRecordT(commitBlob) + } + + return nil } // resolverCtrlBlocks is a map of resolver IDs to their corresponding control @@ -216,8 +235,8 @@ type ctrlBlocks struct { } // newCtrlBlocks returns a new instance of the ctrlBlocks struct. -func newCtrlBlocks() *ctrlBlocks { - return &ctrlBlocks{ +func newCtrlBlocks() ctrlBlocks { + return ctrlBlocks{ OutgoingHtlcCtrlBlocks: newResolverCtrlBlocks(), IncomingHtlcCtrlBlocks: newResolverCtrlBlocks(), SecondLevelCtrlBlocks: newResolverCtrlBlocks(), @@ -260,7 +279,7 @@ func varBytesDecoder(r io.Reader, val any, buf *[8]byte, l uint64) error { // ctrlBlockEncoder is a custom TLV encoder for the ctrlBlocks struct. func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error { - if t, ok := val.(**ctrlBlocks); ok { + if t, ok := val.(*ctrlBlocks); ok { return (*t).Encode(w) } @@ -269,7 +288,7 @@ func ctrlBlockEncoder(w io.Writer, val any, _ *[8]byte) error { // ctrlBlockDecoder is a custom TLV decoder for the ctrlBlocks struct. func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { - if typ, ok := val.(**ctrlBlocks); ok { + if typ, ok := val.(*ctrlBlocks); ok { ctrlReader := io.LimitReader(r, int64(l)) var ctrlBlocks ctrlBlocks @@ -278,7 +297,7 @@ func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { return err } - *typ = &ctrlBlocks + *typ = ctrlBlocks return nil } @@ -286,28 +305,6 @@ func ctrlBlockDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "ctrlBlocks", l, l) } -// newCtrlBlocksRecord returns a new TLV record that can be used to -// encode/decode the set of cotrol blocks for the taproot outputs for a -// channel. -func newCtrlBlocksRecord(blks **ctrlBlocks) tlv.Record { - recordSize := func() uint64 { - var ( - b bytes.Buffer - buf [8]byte - ) - if err := ctrlBlockEncoder(&b, blks, &buf); err != nil { - panic(err) - } - - return uint64(len(b.Bytes())) - } - - return tlv.MakeDynamicRecord( - taprootCtrlBlockType, blks, recordSize, ctrlBlockEncoder, - ctrlBlockDecoder, - ) -} - // EncodeRecords returns the set of TLV records that encode the control block // for the commitment transaction. func (c *ctrlBlocks) EncodeRecords() []tlv.Record { @@ -382,7 +379,21 @@ func (c *ctrlBlocks) DecodeRecords() []tlv.Record { // Record returns a TLV record that can be used to encode/decode the control // blocks. type from a given TLV stream. func (c *ctrlBlocks) Record() tlv.Record { - return tlv.MakePrimitiveRecord(commitCtrlBlockType, c) + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := ctrlBlockEncoder(&b, c, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, c, recordSize, ctrlBlockEncoder, ctrlBlockDecoder, + ) } // Encode encodes the set of control blocks. @@ -530,8 +541,8 @@ type tapTweaks struct { } // newTapTweaks returns a new tapTweaks struct. -func newTapTweaks() *tapTweaks { - return &tapTweaks{ +func newTapTweaks() tapTweaks { + return tapTweaks{ BreachedHtlcTweaks: make(htlcTapTweaks), BreachedSecondLevelHltcTweaks: make(htlcTapTweaks), } @@ -539,7 +550,7 @@ func newTapTweaks() *tapTweaks { // tapTweaksEncoder is a custom TLV encoder for the tapTweaks struct. func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error { - if t, ok := val.(**tapTweaks); ok { + if t, ok := val.(*tapTweaks); ok { return (*t).Encode(w) } @@ -548,7 +559,7 @@ func tapTweaksEncoder(w io.Writer, val any, _ *[8]byte) error { // tapTweaksDecoder is a custom TLV decoder for the tapTweaks struct. func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { - if typ, ok := val.(**tapTweaks); ok { + if typ, ok := val.(*tapTweaks); ok { tweakReader := io.LimitReader(r, int64(l)) var tapTweaks tapTweaks @@ -557,7 +568,7 @@ func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { return err } - *typ = &tapTweaks + *typ = tapTweaks return nil } @@ -565,27 +576,6 @@ func tapTweaksDecoder(r io.Reader, val any, _ *[8]byte, l uint64) error { return tlv.NewTypeForDecodingErr(val, "tapTweaks", l, l) } -// newTapTweaksRecord returns a new TLV record that can be used to -// encode/decode the tap tweak structs. -func newTapTweaksRecord(tweaks **tapTweaks) tlv.Record { - recordSize := func() uint64 { - var ( - b bytes.Buffer - buf [8]byte - ) - if err := tapTweaksEncoder(&b, tweaks, &buf); err != nil { - panic(err) - } - - return uint64(len(b.Bytes())) - } - - return tlv.MakeDynamicRecord( - taprootTapTweakType, tweaks, recordSize, tapTweaksEncoder, - tapTweaksDecoder, - ) -} - // EncodeRecords returns the set of TLV records that encode the tweaks. func (t *tapTweaks) EncodeRecords() []tlv.Record { var records []tlv.Record @@ -637,7 +627,21 @@ func (t *tapTweaks) DecodeRecords() []tlv.Record { // Record returns a TLV record that can be used to encode/decode the tap // tweaks. func (t *tapTweaks) Record() tlv.Record { - return tlv.MakePrimitiveRecord(taprootTapTweakType, t) + recordSize := func() uint64 { + var ( + b bytes.Buffer + buf [8]byte + ) + if err := tapTweaksEncoder(&b, t, &buf); err != nil { + panic(err) + } + + return uint64(len(b.Bytes())) + } + + return tlv.MakeDynamicRecord( + 0, t, recordSize, tapTweaksEncoder, tapTweaksDecoder, + ) } // Encode encodes the set of tap tweaks. diff --git a/contractcourt/taproot_briefcase_test.go b/contractcourt/taproot_briefcase_test.go index 38471ed74..66adc3c07 100644 --- a/contractcourt/taproot_briefcase_test.go +++ b/contractcourt/taproot_briefcase_test.go @@ -5,6 +5,7 @@ import ( "math/rand" "testing" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -69,19 +70,26 @@ func TestTaprootBriefcase(t *testing.T) { _, err = rand.Read(anchorTweak[:]) require.NoError(t, err) + var commitBlob [100]byte + _, err = rand.Read(commitBlob[:]) + require.NoError(t, err) + testCase := &taprootBriefcase{ - CtrlBlocks: &ctrlBlocks{ + CtrlBlocks: tlv.NewRecordT[tlv.TlvType0](ctrlBlocks{ CommitSweepCtrlBlock: sweepCtrlBlock[:], RevokeSweepCtrlBlock: revokeCtrlBlock[:], OutgoingHtlcCtrlBlocks: randResolverCtrlBlocks(t), IncomingHtlcCtrlBlocks: randResolverCtrlBlocks(t), SecondLevelCtrlBlocks: randResolverCtrlBlocks(t), - }, - TapTweaks: &tapTweaks{ + }), + TapTweaks: tlv.NewRecordT[tlv.TlvType1](tapTweaks{ AnchorTweak: anchorTweak[:], BreachedHtlcTweaks: randHtlcTweaks(t), BreachedSecondLevelHltcTweaks: randHtlcTweaks(t), - }, + }), + CommitBlob: tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType2](commitBlob[:]), + ), } var b bytes.Buffer diff --git a/funding/manager.go b/funding/manager.go index 77fba3072..d679d40b0 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -558,6 +558,10 @@ type Config struct { // AuxSigner is an optional signer that can be used to sign auxiliary // leaves for certain custom channel types. AuxSigner fn.Option[lnwallet.AuxSigner] + + // AuxResolver is an optional interface that can be used to modify the + // way contracts are resolved. + AuxResolver fn.Option[lnwallet.AuxContractResolver] } // Manager acts as an orchestrator/bridge between the wallet's @@ -1090,6 +1094,9 @@ func (f *Manager) advanceFundingState(channel *channeldb.OpenChannel, f.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) { chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s)) }) + f.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) { + chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) + }) // We create the state-machine object which wraps the database state. lnChannel, err := lnwallet.NewLightningChannel( diff --git a/go.mod b/go.mod index 786307df1..62c5b14a6 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( github.com/lightningnetwork/lightning-onion v1.2.1-0.20230823005744-06182b1d7d2f github.com/lightningnetwork/lnd/cert v1.2.2 github.com/lightningnetwork/lnd/clock v1.1.1 - github.com/lightningnetwork/lnd/fn v1.0.8 + github.com/lightningnetwork/lnd/fn v1.1.0 github.com/lightningnetwork/lnd/healthcheck v1.2.4 github.com/lightningnetwork/lnd/kvdb v1.4.8 github.com/lightningnetwork/lnd/queue v1.1.1 diff --git a/go.sum b/go.sum index 0067ffd7f..d792bcc91 100644 --- a/go.sum +++ b/go.sum @@ -448,8 +448,8 @@ github.com/lightningnetwork/lnd/cert v1.2.2 h1:71YK6hogeJtxSxw2teq3eGeuy4rHGKcFf github.com/lightningnetwork/lnd/cert v1.2.2/go.mod h1:jQmFn/Ez4zhDgq2hnYSw8r35bqGVxViXhX6Cd7HXM6U= github.com/lightningnetwork/lnd/clock v1.1.1 h1:OfR3/zcJd2RhH0RU+zX/77c0ZiOnIMsDIBjgjWdZgA0= github.com/lightningnetwork/lnd/clock v1.1.1/go.mod h1:mGnAhPyjYZQJmebS7aevElXKTFDuO+uNFFfMXK1W8xQ= -github.com/lightningnetwork/lnd/fn v1.0.8 h1:gwzzcUyeDXVIm5S6KgJ9iCQ9wLQGf367k7O3bn/BEvs= -github.com/lightningnetwork/lnd/fn v1.0.8/go.mod h1:P027+0CyELd92H9gnReUkGGAqbFA1HwjHWdfaDFD51U= +github.com/lightningnetwork/lnd/fn v1.1.0 h1:W1p/bUXMgAh5YlmawdQYaNgmLaLMT77BilepzWOSZ2A= +github.com/lightningnetwork/lnd/fn v1.1.0/go.mod h1:P027+0CyELd92H9gnReUkGGAqbFA1HwjHWdfaDFD51U= github.com/lightningnetwork/lnd/healthcheck v1.2.4 h1:lLPLac+p/TllByxGSlkCwkJlkddqMP5UCoawCj3mgFQ= github.com/lightningnetwork/lnd/healthcheck v1.2.4/go.mod h1:G7Tst2tVvWo7cx6mSBEToQC5L1XOGxzZTPB29g9Rv2I= github.com/lightningnetwork/lnd/kvdb v1.4.8 h1:xH0a5Vi1yrcZ5BEeF2ba3vlKBRxrL9uYXlWTjOjbNTY= diff --git a/input/input.go b/input/input.go index aef524a4c..00312900f 100644 --- a/input/input.go +++ b/input/input.go @@ -6,7 +6,9 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/tlv" ) // EmptyOutPoint is a zeroed outpoint. @@ -63,6 +65,10 @@ type Input interface { // UnconfParent returns information about a possibly unconfirmed parent // tx. UnconfParent() *TxInfo + + // ResolutionBlob returns a special opaque blob to be used to + // sweep/resolve this input. + ResolutionBlob() fn.Option[tlv.Blob] } // TxInfo describes properties of a parent tx that are relevant for CPFP. @@ -106,6 +112,8 @@ type inputKit struct { // unconfParent contains information about a potential unconfirmed // parent transaction. unconfParent *TxInfo + + resolutionBlob fn.Option[tlv.Blob] } // OutPoint returns the breached output's identifier that is to be included as @@ -156,6 +164,38 @@ func (i *inputKit) UnconfParent() *TxInfo { return i.unconfParent } +// ResolutionBlob returns a special opaque blob to be used to sweep/resolve +// this input. +func (i *inputKit) ResolutionBlob() fn.Option[tlv.Blob] { + return i.resolutionBlob +} + +// inputOpts contains options for constructing a new input. +type inputOpts struct { + // resolutionBlob is an optional blob that can be used to resolve an + // input. + resolutionBlob fn.Option[tlv.Blob] +} + +// defaultInputOpts returns a new inputOpts with default values. +func defaultInputOpts() *inputOpts { + return &inputOpts{} +} + +// InputOpt is a functional option that can be used to modify the default input +// options. +// +// TODO(roasbeef): make rest of args to input kit func opt? +type InputOpt func(*inputOpts) //nolint:revive + +// WithResolutionBlob is an option that can be used to set a resolution blob on +// for an input. +func WithResolutionBlob(b fn.Option[tlv.Blob]) InputOpt { + return func(o *inputOpts) { + o.resolutionBlob = b + } +} + // BaseInput contains all the information needed to sweep a basic // output (CSV/CLTV/no time lock). type BaseInput struct { @@ -166,15 +206,21 @@ type BaseInput struct { // sweep transaction. func MakeBaseInput(outpoint *wire.OutPoint, witnessType WitnessType, signDescriptor *SignDescriptor, heightHint uint32, - unconfParent *TxInfo) BaseInput { + unconfParent *TxInfo, opts ...InputOpt) BaseInput { + + opt := defaultInputOpts() + for _, optF := range opts { + optF(opt) + } return BaseInput{ inputKit{ - outpoint: *outpoint, - witnessType: witnessType, - signDesc: *signDescriptor, - heightHint: heightHint, - unconfParent: unconfParent, + outpoint: *outpoint, + witnessType: witnessType, + signDesc: *signDescriptor, + heightHint: heightHint, + unconfParent: unconfParent, + resolutionBlob: opt.resolutionBlob, }, } } @@ -182,10 +228,11 @@ func MakeBaseInput(outpoint *wire.OutPoint, witnessType WitnessType, // NewBaseInput allocates and assembles a new *BaseInput that can be used to // construct a sweep transaction. func NewBaseInput(outpoint *wire.OutPoint, witnessType WitnessType, - signDescriptor *SignDescriptor, heightHint uint32) *BaseInput { + signDescriptor *SignDescriptor, heightHint uint32, + opts ...InputOpt) *BaseInput { input := MakeBaseInput( - outpoint, witnessType, signDescriptor, heightHint, nil, + outpoint, witnessType, signDescriptor, heightHint, nil, opts..., ) return &input @@ -195,36 +242,31 @@ func NewBaseInput(outpoint *wire.OutPoint, witnessType WitnessType, // construct a sweep transaction. func NewCsvInput(outpoint *wire.OutPoint, witnessType WitnessType, signDescriptor *SignDescriptor, heightHint uint32, - blockToMaturity uint32) *BaseInput { + blockToMaturity uint32, opts ...InputOpt) *BaseInput { - return &BaseInput{ - inputKit{ - outpoint: *outpoint, - witnessType: witnessType, - signDesc: *signDescriptor, - heightHint: heightHint, - blockToMaturity: blockToMaturity, - }, - } + input := MakeBaseInput( + outpoint, witnessType, signDescriptor, heightHint, nil, opts..., + ) + + input.blockToMaturity = blockToMaturity + + return &input } // NewCsvInputWithCltv assembles a new csv and cltv locked input that can be // used to construct a sweep transaction. func NewCsvInputWithCltv(outpoint *wire.OutPoint, witnessType WitnessType, signDescriptor *SignDescriptor, heightHint uint32, - csvDelay uint32, cltvExpiry uint32) *BaseInput { + csvDelay uint32, cltvExpiry uint32, opts ...InputOpt) *BaseInput { - return &BaseInput{ - inputKit{ - outpoint: *outpoint, - witnessType: witnessType, - signDesc: *signDescriptor, - heightHint: heightHint, - blockToMaturity: csvDelay, - cltvExpiry: cltvExpiry, - unconfParent: nil, - }, - } + input := MakeBaseInput( + outpoint, witnessType, signDescriptor, heightHint, nil, opts..., + ) + + input.blockToMaturity = csvDelay + input.cltvExpiry = cltvExpiry + + return &input } // CraftInputScript returns a valid set of input scripts allowing this output @@ -256,16 +298,16 @@ type HtlcSucceedInput struct { // construct a sweep transaction. func MakeHtlcSucceedInput(outpoint *wire.OutPoint, signDescriptor *SignDescriptor, preimage []byte, heightHint, - blocksToMaturity uint32) HtlcSucceedInput { + blocksToMaturity uint32, opts ...InputOpt) HtlcSucceedInput { + + input := MakeBaseInput( + outpoint, HtlcAcceptedRemoteSuccess, signDescriptor, + heightHint, nil, opts..., + ) + input.blockToMaturity = blocksToMaturity return HtlcSucceedInput{ - inputKit: inputKit{ - outpoint: *outpoint, - witnessType: HtlcAcceptedRemoteSuccess, - signDesc: *signDescriptor, - heightHint: heightHint, - blockToMaturity: blocksToMaturity, - }, + inputKit: input.inputKit, preimage: preimage, } } @@ -274,16 +316,17 @@ func MakeHtlcSucceedInput(outpoint *wire.OutPoint, // to spend an HTLC output for a taproot channel on the remote party's // commitment transaction. func MakeTaprootHtlcSucceedInput(op *wire.OutPoint, signDesc *SignDescriptor, - preimage []byte, heightHint, blocksToMaturity uint32) HtlcSucceedInput { + preimage []byte, heightHint, blocksToMaturity uint32, + opts ...InputOpt) HtlcSucceedInput { + + input := MakeBaseInput( + op, TaprootHtlcAcceptedRemoteSuccess, signDesc, + heightHint, nil, opts..., + ) + input.blockToMaturity = blocksToMaturity return HtlcSucceedInput{ - inputKit: inputKit{ - outpoint: *op, - witnessType: TaprootHtlcAcceptedRemoteSuccess, - signDesc: *signDesc, - heightHint: heightHint, - blockToMaturity: blocksToMaturity, - }, + inputKit: input.inputKit, preimage: preimage, } } @@ -388,7 +431,8 @@ func (i *HtlcSecondLevelAnchorInput) CraftInputScript(signer Signer, // to spend the HTLC output on our commit using the second level timeout // transaction. func MakeHtlcSecondLevelTimeoutAnchorInput(signedTx *wire.MsgTx, - signDetails *SignDetails, heightHint uint32) HtlcSecondLevelAnchorInput { + signDetails *SignDetails, heightHint uint32, + opts ...InputOpt) HtlcSecondLevelAnchorInput { // Spend an HTLC output on our local commitment tx using the // 2nd timeout transaction. @@ -408,16 +452,15 @@ func MakeHtlcSecondLevelTimeoutAnchorInput(signedTx *wire.MsgTx, ) } - return HtlcSecondLevelAnchorInput{ - inputKit: inputKit{ - outpoint: signedTx.TxIn[0].PreviousOutPoint, - witnessType: HtlcOfferedTimeoutSecondLevelInputConfirmed, - signDesc: signDetails.SignDesc, - heightHint: heightHint, + input := MakeBaseInput( + &signedTx.TxIn[0].PreviousOutPoint, + HtlcOfferedTimeoutSecondLevelInputConfirmed, + &signDetails.SignDesc, heightHint, nil, opts..., + ) + input.blockToMaturity = 1 - // CSV delay is always 1 for these inputs. - blockToMaturity: 1, - }, + return HtlcSecondLevelAnchorInput{ + inputKit: input.inputKit, SignedTx: signedTx, createWitness: createWitness, } @@ -429,7 +472,7 @@ func MakeHtlcSecondLevelTimeoutAnchorInput(signedTx *wire.MsgTx, // sweep the second level HTLC aggregated with other transactions. func MakeHtlcSecondLevelTimeoutTaprootInput(signedTx *wire.MsgTx, signDetails *SignDetails, - heightHint uint32) HtlcSecondLevelAnchorInput { + heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput { createWitness := func(signer Signer, txn *wire.MsgTx, hashCache *txscript.TxSigHashes, @@ -453,16 +496,15 @@ func MakeHtlcSecondLevelTimeoutTaprootInput(signedTx *wire.MsgTx, ) } - return HtlcSecondLevelAnchorInput{ - inputKit: inputKit{ - outpoint: signedTx.TxIn[0].PreviousOutPoint, - witnessType: TaprootHtlcLocalOfferedTimeout, - signDesc: signDetails.SignDesc, - heightHint: heightHint, + input := MakeBaseInput( + &signedTx.TxIn[0].PreviousOutPoint, + TaprootHtlcLocalOfferedTimeout, + &signDetails.SignDesc, heightHint, nil, opts..., + ) + input.blockToMaturity = 1 - // CSV delay is always 1 for these inputs. - blockToMaturity: 1, - }, + return HtlcSecondLevelAnchorInput{ + inputKit: input.inputKit, SignedTx: signedTx, createWitness: createWitness, } @@ -473,7 +515,7 @@ func MakeHtlcSecondLevelTimeoutTaprootInput(signedTx *wire.MsgTx, // transaction. func MakeHtlcSecondLevelSuccessAnchorInput(signedTx *wire.MsgTx, signDetails *SignDetails, preimage lntypes.Preimage, - heightHint uint32) HtlcSecondLevelAnchorInput { + heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput { // Spend an HTLC output on our local commitment tx using the 2nd // success transaction. @@ -492,18 +534,16 @@ func MakeHtlcSecondLevelSuccessAnchorInput(signedTx *wire.MsgTx, preimage[:], signer, &desc, txn, ) } + input := MakeBaseInput( + &signedTx.TxIn[0].PreviousOutPoint, + HtlcAcceptedSuccessSecondLevelInputConfirmed, + &signDetails.SignDesc, heightHint, nil, opts..., + ) + input.blockToMaturity = 1 return HtlcSecondLevelAnchorInput{ - inputKit: inputKit{ - outpoint: signedTx.TxIn[0].PreviousOutPoint, - witnessType: HtlcAcceptedSuccessSecondLevelInputConfirmed, - signDesc: signDetails.SignDesc, - heightHint: heightHint, - - // CSV delay is always 1 for these inputs. - blockToMaturity: 1, - }, SignedTx: signedTx, + inputKit: input.inputKit, createWitness: createWitness, } } @@ -513,7 +553,7 @@ func MakeHtlcSecondLevelSuccessAnchorInput(signedTx *wire.MsgTx, // commitment transaction. func MakeHtlcSecondLevelSuccessTaprootInput(signedTx *wire.MsgTx, signDetails *SignDetails, preimage lntypes.Preimage, - heightHint uint32) HtlcSecondLevelAnchorInput { + heightHint uint32, opts ...InputOpt) HtlcSecondLevelAnchorInput { createWitness := func(signer Signer, txn *wire.MsgTx, hashCache *txscript.TxSigHashes, @@ -537,16 +577,15 @@ func MakeHtlcSecondLevelSuccessTaprootInput(signedTx *wire.MsgTx, ) } - return HtlcSecondLevelAnchorInput{ - inputKit: inputKit{ - outpoint: signedTx.TxIn[0].PreviousOutPoint, - witnessType: TaprootHtlcAcceptedLocalSuccess, - signDesc: signDetails.SignDesc, - heightHint: heightHint, + input := MakeBaseInput( + &signedTx.TxIn[0].PreviousOutPoint, + TaprootHtlcAcceptedLocalSuccess, + &signDetails.SignDesc, heightHint, nil, opts..., + ) + input.blockToMaturity = 1 - // CSV delay is always 1 for these inputs. - blockToMaturity: 1, - }, + return HtlcSecondLevelAnchorInput{ + inputKit: input.inputKit, SignedTx: signedTx, createWitness: createWitness, } diff --git a/input/mocks.go b/input/mocks.go index 2f38400d8..695525955 100644 --- a/input/mocks.go +++ b/input/mocks.go @@ -8,8 +8,10 @@ import ( "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lntypes" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/mock" ) @@ -127,6 +129,17 @@ func (m *MockInput) UnconfParent() *TxInfo { return info.(*TxInfo) } +func (m *MockInput) ResolutionBlob() fn.Option[tlv.Blob] { + args := m.Called() + + info := args.Get(0) + if info == nil { + return fn.None[tlv.Blob]() + } + + return info.(fn.Option[tlv.Blob]) +} + // MockWitnessType implements the `WitnessType` interface and is used by other // packages for mock testing. type MockWitnessType struct { diff --git a/lnwallet/aux_resolutions.go b/lnwallet/aux_resolutions.go new file mode 100644 index 000000000..eaf32f0a4 --- /dev/null +++ b/lnwallet/aux_resolutions.go @@ -0,0 +1,83 @@ +package lnwallet + +import ( + "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +// CloseType is an enum that represents the type of close that we are trying to +// resolve. +type CloseType uint8 + +const ( + // LocalForceClose represents a local force close. + LocalForceClose CloseType = iota + + // RemoteForceClose represents a remote force close. + RemoteForceClose + + // BreachClose represents a breach by the remote party. + Breach +) + +// ResolutionReq is used to ask an outside sub-system for additional +// information needed to resolve a contract. +type ResolutionReq struct { + // ChanPoint is the channel point of the channel that we are trying to + // resolve. + ChanPoint wire.OutPoint + + // ShortChanID is the short channel ID of the channel that we are + // trying to resolve. + ShortChanID lnwire.ShortChannelID + + // Initiator is a bool if we're the initiator of the channel. + Initiator bool + + // CommitBlob is an optional commit blob for the channel. + CommitBlob fn.Option[tlv.Blob] + + // FundingBlob is an optional funding blob for the channel. + FundingBlob fn.Option[tlv.Blob] + + // Type is the type of the witness that we are trying to resolve. + Type input.WitnessType + + // CloseType is the type of close that we are trying to resolve. + CloseType CloseType + + // CommitTx is the force close commitment transaction. + CommitTx *wire.MsgTx + + // CommitFee is the fee that was paid for the commitment transaction. + CommitFee btcutil.Amount + + // ContractPoint is the outpoint of the contract we're trying to + // resolve. + ContractPoint wire.OutPoint + + // SignDesc is the sign descriptor for the contract. + SignDesc input.SignDescriptor + + // KeyRing is the key ring for the channel. + KeyRing *CommitmentKeyRing + + // CsvDelay is the CSV delay for the local output for this commitment. + CsvDelay uint32 + + // CltvDelay is the CLTV delay for the outpoint. + CltvDelay fn.Option[uint32] +} + +// AuxContractResolver is an interface that is used to resolve contracts that +// may need additional outside information to resolve correctly. +type AuxContractResolver interface { + // ResolveContract is called to resolve a contract that needs + // additional information to resolve properly. If no extra information + // is required, a nil Result error is returned. + ResolveContract(ResolutionReq) fn.Result[tlv.Blob] +} diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 34ea53990..5f066eb23 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -1353,6 +1353,9 @@ type LightningChannel struct { // custom channel variants. auxSigner fn.Option[AuxSigner] + // auxResolver... + auxResolver fn.Option[AuxContractResolver] + // Capacity is the total capacity of this channel. Capacity btcutil.Amount @@ -1416,8 +1419,9 @@ type channelOpts struct { localNonce *musig2.Nonces remoteNonce *musig2.Nonces - leafStore fn.Option[AuxLeafStore] - auxSigner fn.Option[AuxSigner] + leafStore fn.Option[AuxLeafStore] + auxSigner fn.Option[AuxSigner] + auxResolver fn.Option[AuxContractResolver] skipNonceInit bool } @@ -1463,6 +1467,14 @@ func WithAuxSigner(signer AuxSigner) ChannelOpt { } } +// WithAuxResolver is used to specify a custom aux contract resolver for the +// channel. +func WithAuxResolver(resolver AuxContractResolver) ChannelOpt { + return func(o *channelOpts) { + o.auxResolver = fn.Some[AuxContractResolver](resolver) + } +} + // defaultChannelOpts returns the set of default options for a new channel. func defaultChannelOpts() *channelOpts { return &channelOpts{} @@ -1507,6 +1519,7 @@ func NewLightningChannel(signer input.Signer, Signer: signer, leafStore: opts.leafStore, auxSigner: opts.auxSigner, + auxResolver: opts.auxResolver, sigPool: sigPool, currentHeight: localCommit.CommitHeight, remoteCommitChain: newCommitmentChain(), @@ -2522,6 +2535,11 @@ type BreachRetribution struct { // breaching commitment transaction. This allows downstream clients to // have access to the public keys used in the scripts. KeyRing *CommitmentKeyRing + + // ResolutionBlob is a blob used for aux channels that permits a + // spender of the output to properly resolve it in the case of a force + // close. + ResolutionBlob fn.Option[tlv.Blob] } // NewBreachRetribution creates a new fully populated BreachRetribution for the @@ -2533,7 +2551,8 @@ type BreachRetribution struct { // the required fields then ErrRevLogDataMissing will be returned. func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, breachHeight uint32, spendTx *wire.MsgTx, - leafStore fn.Option[AuxLeafStore]) (*BreachRetribution, error) { + leafStore fn.Option[AuxLeafStore], + auxResolver fn.Option[AuxContractResolver]) (*BreachRetribution, error) { //nolint:lll // Query the on-disk revocation log for the snapshot which was recorded // at this particular state num. Based on whether a legacy revocation @@ -2691,6 +2710,33 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, err } } + + // At this point, we'll check to see if we need any extra + // resolution data for this output. + resolveBlob := fn.MapOptionZ( + auxResolver, + func(a AuxContractResolver) fn.Result[tlv.Blob] { + return a.ResolveContract(ResolutionReq{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + Initiator: chanState.IsInitiator, + CommitBlob: chanState.RemoteCommitment.CustomBlob, //nolint:lll + FundingBlob: chanState.CustomBlob, + Type: input.TaprootRemoteCommitSpend, //nolint:lll + CloseType: Breach, + CommitTx: spendTx, + SignDesc: *br.LocalOutputSignDesc, + KeyRing: keyRing, + CsvDelay: theirDelay, + CommitFee: chanState.RemoteCommitment.CommitFee, //nolint:lll + }) + }, + ) + if err := resolveBlob.Err(); err != nil { + return nil, fmt.Errorf("unable to aux resolve: %w", err) + } + + br.ResolutionBlob = resolveBlob.Option() } // Similarly, if their balance exceeds the remote party's dust limit, @@ -2738,6 +2784,33 @@ func NewBreachRetribution(chanState *channeldb.OpenChannel, stateNum uint64, return nil, err } } + + // At this point, we'll check to see if we need any extra + // resolution data for this output. + resolveBlob := fn.MapOptionZ( + auxResolver, + func(a AuxContractResolver) fn.Result[tlv.Blob] { + return a.ResolveContract(ResolutionReq{ + ChanPoint: chanState.FundingOutpoint, + ShortChanID: chanState.ShortChanID(), + Initiator: chanState.IsInitiator, + CommitBlob: chanState.RemoteCommitment.CustomBlob, //nolint:lll + FundingBlob: chanState.CustomBlob, + Type: input.TaprootCommitmentRevoke, //nolint:lll + CloseType: Breach, + CommitTx: spendTx, + SignDesc: *br.RemoteOutputSignDesc, + KeyRing: keyRing, + CsvDelay: theirDelay, + CommitFee: chanState.RemoteCommitment.CommitFee, //nolint:lll + }) + }, + ) + if err := resolveBlob.Err(); err != nil { + return nil, fmt.Errorf("unable to aux resolve: %w", err) + } + + br.ResolutionBlob = resolveBlob.Option() } // Finally, with all the necessary data constructed, we can pad the @@ -7003,6 +7076,11 @@ type CommitOutputResolution struct { // that pay to the local party within the broadcast commitment // transaction. MaturityDelay uint32 + + // ResolutionBlob is a blob used for aux channels that permits a + // spender of the output to properly resolve it in the case of a force + // close. + ResolutionBlob fn.Option[tlv.Blob] } // UnilateralCloseSummary describes the details of a detected unilateral @@ -7060,7 +7138,8 @@ type UnilateralCloseSummary struct { func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, commitSpend *chainntnfs.SpendDetail, remoteCommit channeldb.ChannelCommitment, commitPoint *btcec.PublicKey, - leafStore fn.Option[AuxLeafStore]) (*UnilateralCloseSummary, error) { + leafStore fn.Option[AuxLeafStore], + auxResolver fn.Option[AuxContractResolver]) (*UnilateralCloseSummary, error) { //nolint:lll // First, we'll generate the commitment point and the revocation point // so we can re-construct the HTLC state and also our payment key. @@ -7181,6 +7260,34 @@ func NewUnilateralCloseSummary(chanState *channeldb.OpenChannel, return nil, err } } + + // At this point, we'll check to see if we need any extra + // resolution data for this output. + resolveBlob := fn.MapOptionZ( + auxResolver, + func(a AuxContractResolver) fn.Result[tlv.Blob] { + return a.ResolveContract(ResolutionReq{ + ChanPoint: chanState.FundingOutpoint, //nolint:lll + ShortChanID: chanState.ShortChanID(), + Initiator: chanState.IsInitiator, + CommitBlob: chanState.RemoteCommitment.CustomBlob, //nolint:lll + FundingBlob: chanState.CustomBlob, + Type: input.TaprootRemoteCommitSpend, //nolint:lll + CloseType: RemoteForceClose, + CommitTx: commitTxBroadcast, + ContractPoint: *selfPoint, + SignDesc: commitResolution.SelfOutputSignDesc, //nolint:lll + KeyRing: keyRing, + CsvDelay: maturityDelay, + CommitFee: chanState.RemoteCommitment.CommitFee, //nolint:lll + }) + }, + ) + if err := resolveBlob.Err(); err != nil { + return nil, fmt.Errorf("unable to aux resolve: %w", err) + } + + commitResolution.ResolutionBlob = resolveBlob.Option() } closeSummary := channeldb.ChannelCloseSummary{ @@ -8035,7 +8142,7 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { localCommitment := lc.channelState.LocalCommitment summary, err := NewLocalForceCloseSummary( lc.channelState, lc.Signer, commitTx, - localCommitment.CommitHeight, lc.leafStore, + localCommitment.CommitHeight, lc.leafStore, lc.auxResolver, ) if err != nil { return nil, fmt.Errorf("unable to gen force close "+ @@ -8053,7 +8160,8 @@ func (lc *LightningChannel) ForceClose() (*LocalForceCloseSummary, error) { // transaction corresponding to localCommit. func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, signer input.Signer, commitTx *wire.MsgTx, stateNum uint64, - leafStore fn.Option[AuxLeafStore]) (*LocalForceCloseSummary, error) { + leafStore fn.Option[AuxLeafStore], + auxResolver fn.Option[AuxContractResolver]) (*LocalForceCloseSummary, error) { //nolint:lll // Re-derive the original pkScript for to-self output within the // commitment transaction. We'll need this to find the corresponding @@ -8074,16 +8182,29 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, &chanState.LocalChanCfg, &chanState.RemoteChanCfg, ) - // TODO(roasbeef): fetch aux leave + localCommit := chanState.LocalCommitment + + // If we have a custom blob, then we'll attempt to fetch the aux leaves + // for this state. + auxLeaves, err := AuxLeavesFromCommit( + chanState, localCommit, leafStore, *keyRing, + ) + if err != nil { + return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) + } var leaseExpiry uint32 if chanState.ChanType.HasLeaseExpiration() { leaseExpiry = chanState.ThawHeight } + + localAuxLeaf := fn.MapOption(func(l CommitAuxLeaves) input.AuxTapLeaf { + return l.LocalAuxLeaf + })(auxLeaves) toLocalScript, err := CommitScriptToSelf( chanState.ChanType, chanState.IsInitiator, keyRing.ToLocalKey, keyRing.RevocationKey, csvTimeout, leaseExpiry, - input.NoneTapLeaf(), + fn.FlattenOption(localAuxLeaf), ) if err != nil { return nil, err @@ -8162,6 +8283,34 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, return nil, err } } + + // At this point, we'll check to see if we need any extra + // resolution data for this output. + resolveBlob := fn.MapOptionZ( + auxResolver, + func(a AuxContractResolver) fn.Result[tlv.Blob] { + return a.ResolveContract(ResolutionReq{ + ChanPoint: chanState.FundingOutpoint, //nolint:lll + ShortChanID: chanState.ShortChanID(), + Initiator: chanState.IsInitiator, + CommitBlob: chanState.LocalCommitment.CustomBlob, //nolint:lll + FundingBlob: chanState.CustomBlob, + Type: input.TaprootLocalCommitSpend, //nolint:lll + CloseType: LocalForceClose, + CommitTx: commitTx, + ContractPoint: commitResolution.SelfOutPoint, //nolint:lll + SignDesc: commitResolution.SelfOutputSignDesc, //nolint:lll + KeyRing: keyRing, + CsvDelay: csvTimeout, + CommitFee: chanState.LocalCommitment.CommitFee, //nolint:lll + }) + }, + ) + if err := resolveBlob.Err(); err != nil { + return nil, fmt.Errorf("unable to aux resolve: %w", err) + } + + commitResolution.ResolutionBlob = resolveBlob.Option() } // Once the delay output has been found (if it exists), then we'll also @@ -8169,15 +8318,6 @@ func NewLocalForceCloseSummary(chanState *channeldb.OpenChannel, // outgoing HTLC's that we'll need to claim as well. If this is after // recovery there is not much we can do with HTLCs, so we'll always // use what we have in our latest state when extracting resolutions. - localCommit := chanState.LocalCommitment - - auxLeaves, err := AuxLeavesFromCommit( - chanState, localCommit, leafStore, *keyRing, - ) - if err != nil { - return nil, fmt.Errorf("unable to fetch aux leaves: %w", err) - } - htlcResolutions, err := extractHtlcResolutions( chainfee.SatPerKWeight(localCommit.FeePerKw), true, signer, localCommit.Htlcs, keyRing, &chanState.LocalChanCfg, diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index fd353a0fa..9b93ef6ac 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -5696,7 +5696,7 @@ func TestChannelUnilateralCloseHtlcResolution(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5846,7 +5846,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceChannel.channelState.RemoteCommitment, aliceChannel.channelState.RemoteCurrentRevocation, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err, "unable to create alice close summary") @@ -5864,7 +5864,7 @@ func TestChannelUnilateralClosePendingCommit(t *testing.T) { spendDetail, aliceRemoteChainTip.Commitment, aliceChannel.channelState.RemoteNextRevocation, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err, "unable to create alice close summary") @@ -6745,7 +6745,7 @@ func TestNewBreachRetributionSkipsDustHtlcs(t *testing.T) { breachTx := aliceChannel.channelState.RemoteCommitment.CommitTx breachRet, err := NewBreachRetribution( aliceChannel.channelState, revokedStateNum, 100, breachTx, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err, "unable to create breach retribution") @@ -10291,7 +10291,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error as there are no past delta state saved as revocation logs yet. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10299,7 +10299,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.ErrorIs(t, err, channeldb.ErrNoPastDeltas) @@ -10345,7 +10345,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // successfully. br, err := NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, breachTx, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err) @@ -10357,7 +10357,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // since the necessary info should now be found in the revocation log. br, err = NewBreachRetribution( aliceChannel.channelState, stateNum, breachHeight, nil, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.NoError(t, err) assertRetribution(br, 1, 0) @@ -10366,7 +10366,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // error. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, breachTx, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) @@ -10374,7 +10374,7 @@ func testNewBreachRetribution(t *testing.T, chanType channeldb.ChannelType) { // provided. _, err = NewBreachRetribution( aliceChannel.channelState, stateNum+1, breachHeight, nil, - fn.None[AuxLeafStore](), + fn.None[AuxLeafStore](), fn.None[AuxContractResolver](), ) require.ErrorIs(t, err, channeldb.ErrLogEntryNotFound) } diff --git a/lnwallet/interface.go b/lnwallet/interface.go index a48e92560..8553b7650 100644 --- a/lnwallet/interface.go +++ b/lnwallet/interface.go @@ -11,6 +11,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil/hdkeychain" "github.com/btcsuite/btcd/btcutil/psbt" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" @@ -18,8 +19,10 @@ import ( base "github.com/btcsuite/btcwallet/wallet" "github.com/btcsuite/btcwallet/wallet/txauthor" "github.com/btcsuite/btcwallet/wtxmgr" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet/chainfee" + "github.com/lightningnetwork/lnd/lnwire" ) const ( @@ -592,6 +595,67 @@ type MessageSigner interface { doubleHash bool) (*ecdsa.Signature, error) } +// AddrWithKey wraps a normal addr, but also includes the internal key for the +// delivery addr if known. +type AddrWithKey struct { + lnwire.DeliveryAddress + + InternalKey fn.Option[keychain.KeyDescriptor] + + // TODO(roasbeef): consolidate w/ instance in chan closer +} + +// InternalKeyForAddr returns the internal key associated with a taproot +// address. +func InternalKeyForAddr(wallet WalletController, netParams *chaincfg.Params, + deliveryScript []byte) (fn.Option[keychain.KeyDescriptor], error) { + + none := fn.None[keychain.KeyDescriptor]() + + pkScript, err := txscript.ParsePkScript(deliveryScript) + if err != nil { + return none, err + } + addr, err := pkScript.Address(netParams) + if err != nil { + return none, err + } + + walletAddr, err := wallet.AddressInfo(addr) + if err != nil { + return none, err + } + + // No wallet addr. No error, but we'll return an nil error value here, + // as callers can use the .Option() method to get an option value. + if walletAddr == nil { + return none, nil + } + + // If it's not a taproot address, we don't require to know the internal + // key in the first place. So we don't return an error here, but also no + // internal key. + if walletAddr.AddrType() != waddrmgr.TaprootPubKey { + return none, nil + } + + pubKeyAddr, ok := walletAddr.(waddrmgr.ManagedPubKeyAddress) + if !ok { + return none, fmt.Errorf("expected pubkey addr, got %T", + pubKeyAddr) + } + + _, derivationPath, _ := pubKeyAddr.DerivationInfo() + + return fn.Some[keychain.KeyDescriptor](keychain.KeyDescriptor{ + KeyLocator: keychain.KeyLocator{ + Family: keychain.KeyFamily(derivationPath.Account), + Index: derivationPath.Index, + }, + PubKey: pubKeyAddr.PubKey(), + }), nil +} + // WalletDriver represents a "driver" for a particular concrete // WalletController implementation. A driver is identified by a globally unique // string identifier along with a 'New()' method which is responsible for diff --git a/peer/brontide.go b/peer/brontide.go index d243ac2dc..311ff0cf3 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -18,7 +18,6 @@ import ( "github.com/btcsuite/btcd/txscript" "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog" - "github.com/btcsuite/btcwallet/waddrmgr" "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/buffer" "github.com/lightningnetwork/lnd/build" @@ -36,6 +35,7 @@ import ( "github.com/lightningnetwork/lnd/htlcswitch/hop" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/invoices" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnpeer" "github.com/lightningnetwork/lnd/lnutils" "github.com/lightningnetwork/lnd/lnwallet" @@ -395,6 +395,10 @@ type Config struct { // leaves for certain custom channel types. AuxSigner fn.Option[lnwallet.AuxSigner] + // AuxResolver is an optional interface that can be used to modify the + // way contracts are resolved. + AuxResolver fn.Option[lnwallet.AuxContractResolver] + // PongBuf is a slice we'll reuse instead of allocating memory on the // heap. Since only reads will occur and no writes, there is no need // for any synchronization primitives. As a result, it's safe to share @@ -883,70 +887,32 @@ func (p *Brontide) QuitSignal() <-chan struct{} { return p.quit } -// internalKeyForAddr returns the internal key associated with a taproot -// address. -func internalKeyForAddr(wallet *lnwallet.LightningWallet, - deliveryScript []byte) (fn.Option[btcec.PublicKey], error) { - - none := fn.None[btcec.PublicKey]() - - pkScript, err := txscript.ParsePkScript(deliveryScript) - if err != nil { - return none, err - } - addr, err := pkScript.Address(&wallet.Cfg.NetParams) - if err != nil { - return none, err - } - - walletAddr, err := wallet.AddressInfo(addr) - if err != nil { - return none, err - } - - // If the address isn't known to the wallet, we can't determine the - // internal key. - if walletAddr == nil { - return none, nil - } - - // If it's not a taproot address, we don't require to know the internal - // key in the first place. So we don't return an error here, but also no - // internal key. - if walletAddr.AddrType() != waddrmgr.TaprootPubKey { - return none, nil - } - - pubKeyAddr, ok := walletAddr.(waddrmgr.ManagedPubKeyAddress) - if !ok { - return none, fmt.Errorf("expected pubkey addr, got %T", - pubKeyAddr) - } - - return fn.Some(*pubKeyAddr.PubKey()), nil -} - // addrWithInternalKey takes a delivery script, then attempts to supplement it // with information related to the internal key for the addr, but only if it's // a taproot addr. func (p *Brontide) addrWithInternalKey( - deliveryScript []byte) fn.Result[chancloser.DeliveryAddrWithKey] { + deliveryScript []byte) (*chancloser.DeliveryAddrWithKey, error) { - // TODO(roasbeef): not compatible with external shutdown addr? // Currently, custom channels cannot be created with external upfront // shutdown addresses, so this shouldn't be an issue. We only require // the internal key for taproot addresses to be able to provide a non // inclusion proof of any scripts. - - internalKey, err := internalKeyForAddr(p.cfg.Wallet, deliveryScript) + internalKeyDesc, err := lnwallet.InternalKeyForAddr( + p.cfg.Wallet, &p.cfg.Wallet.Cfg.NetParams, + deliveryScript, + ) if err != nil { - return fn.Err[chancloser.DeliveryAddrWithKey](err) + return nil, fmt.Errorf("unable to fetch internal key: %w", err) } - return fn.Ok(chancloser.DeliveryAddrWithKey{ + return &chancloser.DeliveryAddrWithKey{ DeliveryAddress: deliveryScript, - InternalKey: internalKey, - }) + InternalKey: fn.MapOption( + func(desc keychain.KeyDescriptor) btcec.PublicKey { + return *desc.PubKey + }, + )(internalKeyDesc), + }, nil } // loadActiveChannels creates indexes within the peer for tracking all active @@ -1027,6 +993,10 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( p.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) { chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s)) }) + p.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) { //nolint:lll + chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) + }) + lnChan, err := lnwallet.NewLightningChannel( p.cfg.Signer, dbChan, p.cfg.SigPool, chanOpts..., ) @@ -1191,7 +1161,7 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( addr, err := p.addrWithInternalKey( info.DeliveryScript.Val, - ).Unpack() + ) if err != nil { shutdownInfoErr = fmt.Errorf("unable to make "+ "delivery addr: %w", err) @@ -2885,7 +2855,7 @@ func (p *Brontide) fetchActiveChanCloser(chanID lnwire.ChannelID) ( return nil, fmt.Errorf("unable to estimate fee") } - addr, err := p.addrWithInternalKey(deliveryScript).Unpack() + addr, err := p.addrWithInternalKey(deliveryScript) if err != nil { return nil, fmt.Errorf("unable to parse addr: %w", err) } @@ -3131,7 +3101,7 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( channeldb.ChanStatusLocalCloseInitiator, ) - addr, err := p.addrWithInternalKey(deliveryScript).Unpack() + addr, err := p.addrWithInternalKey(deliveryScript) if err != nil { return nil, fmt.Errorf("unable to parse addr: %w", err) } @@ -3163,7 +3133,7 @@ func (p *Brontide) restartCoopClose(lnChan *lnwallet.LightningChannel) ( // createChanCloser constructs a ChanCloser from the passed parameters and is // used to de-duplicate code. func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, - deliveryScript chancloser.DeliveryAddrWithKey, + deliveryScript *chancloser.DeliveryAddrWithKey, fee chainfee.SatPerKWeight, req *htlcswitch.ChanClose, locallyInitiated bool) (*chancloser.ChanCloser, error) { @@ -3198,7 +3168,7 @@ func (p *Brontide) createChanCloser(channel *lnwallet.LightningChannel, ChainParams: &p.cfg.Wallet.Cfg.NetParams, Quit: p.quit, }, - deliveryScript, + *deliveryScript, fee, uint32(startingHeight), req, @@ -3257,7 +3227,7 @@ func (p *Brontide) handleLocalCloseReq(req *htlcswitch.ChanClose) { return } } - addr, err := p.addrWithInternalKey(deliveryScript).Unpack() + addr, err := p.addrWithInternalKey(deliveryScript) if err != nil { err = fmt.Errorf("unable to parse addr for channel "+ "%v: %w", req.ChanPoint, err) @@ -4199,6 +4169,9 @@ func (p *Brontide) addActiveChannel(c *lnpeer.NewChannel) error { p.cfg.AuxSigner.WhenSome(func(s lnwallet.AuxSigner) { chanOpts = append(chanOpts, lnwallet.WithAuxSigner(s)) }) + p.cfg.AuxResolver.WhenSome(func(s lnwallet.AuxContractResolver) { + chanOpts = append(chanOpts, lnwallet.WithAuxResolver(s)) + }) // If not already active, we'll add this channel to the set of active // channels, so we can look it up later easily according to its channel diff --git a/server.go b/server.go index 24d98525d..2dbc85d63 100644 --- a/server.go +++ b/server.go @@ -18,6 +18,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2/ecdsa" "github.com/btcsuite/btcd/btcutil" + "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/connmgr" "github.com/btcsuite/btcd/txscript" @@ -514,12 +515,14 @@ func newServer(cfg *Config, listenAddrs []net.Addr, var serializedPubKey [33]byte copy(serializedPubKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) + netParams := cfg.ActiveNetParams.Params + // Initialize the sphinx router. replayLog := htlcswitch.NewDecayedLog( dbs.DecayedLogDB, cc.ChainNotifier, ) sphinxRouter := sphinx.NewRouter( - nodeKeyECDH, cfg.ActiveNetParams.Params, replayLog, + nodeKeyECDH, netParams, replayLog, ) writeBufferPool := pool.NewWriteBuffer( @@ -1091,15 +1094,18 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ) s.txPublisher = sweep.NewTxPublisher(sweep.TxPublisherConfig{ - Signer: cc.Wallet.Cfg.Signer, - Wallet: cc.Wallet, - Estimator: cc.FeeEstimator, - Notifier: cc.ChainNotifier, + Signer: cc.Wallet.Cfg.Signer, + Wallet: cc.Wallet, + Estimator: cc.FeeEstimator, + Notifier: cc.ChainNotifier, + AuxSweeper: s.implCfg.AuxSweeper, }) s.sweeper = sweep.New(&sweep.UtxoSweeperConfig{ - FeeEstimator: cc.FeeEstimator, - GenSweepScript: newSweepPkScriptGen(cc.Wallet), + FeeEstimator: cc.FeeEstimator, + GenSweepScript: newSweepPkScriptGen( + cc.Wallet, s.cfg.ActiveNetParams.Params, + ), Signer: cc.Wallet.Cfg.Signer, Wallet: newSweeperWallet(cc.Wallet), Mempool: cc.MempoolNotifier, @@ -1110,6 +1116,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, Aggregator: aggregator, Publisher: s.txPublisher, NoDeadlineConfTarget: cfg.Sweeper.NoDeadlineConfTarget, + AuxSweeper: s.implCfg.AuxSweeper, }) s.utxoNursery = contractcourt.NewUtxoNursery(&contractcourt.NurseryConfig{ @@ -1142,10 +1149,19 @@ func newServer(cfg *Config, listenAddrs []net.Addr, s.breachArbitrator = contractcourt.NewBreachArbitrator( &contractcourt.BreachConfig{ - CloseLink: closeLink, - DB: s.chanStateDB, - Estimator: s.cc.FeeEstimator, - GenSweepScript: newSweepPkScriptGen(cc.Wallet), + CloseLink: closeLink, + DB: s.chanStateDB, + Estimator: s.cc.FeeEstimator, + GenSweepScript: func() ([]byte, error) { + addr, err := newSweepPkScriptGen( + cc.Wallet, netParams, + )().Unpack() + if err != nil { + return nil, err + } + + return addr.DeliveryAddress, nil + }, Notifier: cc.ChainNotifier, PublishTransaction: cc.Wallet.PublishTransaction, ContractBreaches: contractBreaches, @@ -1161,8 +1177,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, ChainHash: *s.cfg.ActiveNetParams.GenesisHash, IncomingBroadcastDelta: lncfg.DefaultIncomingBroadcastDelta, OutgoingBroadcastDelta: lncfg.DefaultOutgoingBroadcastDelta, - NewSweepAddr: newSweepPkScriptGen(cc.Wallet), - PublishTx: cc.Wallet.PublishTransaction, + NewSweepAddr: func() ([]byte, error) { + addr, err := newSweepPkScriptGen( + cc.Wallet, netParams, + )().Unpack() + if err != nil { + return nil, err + } + + return addr.DeliveryAddress, nil + }, + PublishTx: cc.Wallet.PublishTransaction, DeliverResolutionMsg: func(msgs ...contractcourt.ResolutionMsg) error { for _, msg := range msgs { err := s.htlcSwitch.ProcessContractResolution(msg) @@ -1269,6 +1294,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, }, AuxLeafStore: implCfg.AuxLeafStore, AuxSigner: implCfg.AuxSigner, + AuxResolver: implCfg.AuxContractResolver, }, dbs.ChanStateDB) // Select the configuration and funding parameters for Bitcoin. @@ -1517,6 +1543,8 @@ func newServer(cfg *Config, listenAddrs []net.Addr, AliasManager: s.aliasMgr, IsSweeperOutpoint: s.sweeper.IsSweeperOutpoint, AuxFundingController: implCfg.AuxFundingController, + AuxSigner: implCfg.AuxSigner, + AuxResolver: implCfg.AuxContractResolver, }) if err != nil { return nil, err @@ -1605,6 +1633,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, br, err := lnwallet.NewBreachRetribution( channel, commitHeight, 0, nil, implCfg.AuxLeafStore, + implCfg.AuxContractResolver, ) if err != nil { return nil, 0, err @@ -1638,8 +1667,17 @@ func newServer(cfg *Config, listenAddrs []net.Addr, return s.channelNotifier. SubscribeChannelEvents() }, - Signer: cc.Wallet.Cfg.Signer, - NewAddress: newSweepPkScriptGen(cc.Wallet), + Signer: cc.Wallet.Cfg.Signer, + NewAddress: func() ([]byte, error) { + addr, err := newSweepPkScriptGen( + cc.Wallet, netParams, + )().Unpack() + if err != nil { + return nil, err + } + + return addr.DeliveryAddress, nil + }, SecretKeyRing: s.cc.KeyRing, Dial: cfg.net.Dial, AuthDial: authDial, @@ -3943,6 +3981,7 @@ func (s *server) peerConnected(conn net.Conn, connReq *connmgr.ConnReq, AuxSigner: s.implCfg.AuxSigner, MsgRouter: s.implCfg.MsgRouter, AuxChanCloser: s.implCfg.AuxChanCloser, + AuxResolver: s.implCfg.AuxContractResolver, } copy(pCfg.PubKeyBytes[:], peerAddr.IdentityKey.SerializeCompressed()) @@ -4763,18 +4802,34 @@ func (s *server) SendCustomMessage(peerPub [33]byte, msgType lnwire.MessageType, // Specifically, the script generated is a version 0, pay-to-witness-pubkey-hash // (p2wkh) output. func newSweepPkScriptGen( - wallet lnwallet.WalletController) func() ([]byte, error) { + wallet lnwallet.WalletController, + netParams *chaincfg.Params) func() fn.Result[lnwallet.AddrWithKey] { - return func() ([]byte, error) { + return func() fn.Result[lnwallet.AddrWithKey] { sweepAddr, err := wallet.NewAddress( lnwallet.TaprootPubkey, false, lnwallet.DefaultAccountName, ) if err != nil { - return nil, err + return fn.Err[lnwallet.AddrWithKey](err) } - return txscript.PayToAddrScript(sweepAddr) + addr, err := txscript.PayToAddrScript(sweepAddr) + if err != nil { + return fn.Err[lnwallet.AddrWithKey](err) + } + + internalKeyDesc, err := lnwallet.InternalKeyForAddr( + wallet, netParams, addr, + ) + if err != nil { + return fn.Err[lnwallet.AddrWithKey](err) + } + + return fn.Ok(lnwallet.AddrWithKey{ + DeliveryAddress: addr, + InternalKey: internalKeyDesc, + }) } } diff --git a/sweep/fee_bumper.go b/sweep/fee_bumper.go index b4d429894..c5c875ed2 100644 --- a/sweep/fee_bumper.go +++ b/sweep/fee_bumper.go @@ -111,7 +111,7 @@ type BumpRequest struct { DeadlineHeight int32 // DeliveryAddress is the script to send the change output to. - DeliveryAddress []byte + DeliveryAddress lnwallet.AddrWithKey // MaxFeeRate is the maximum fee rate that can be used for fee bumping. MaxFeeRate chainfee.SatPerKWeight @@ -119,6 +119,10 @@ type BumpRequest struct { // StartingFeeRate is an optional parameter that can be used to specify // the initial fee rate to use for the fee function. StartingFeeRate fn.Option[chainfee.SatPerKWeight] + + // ExtraTxOut tracks if this bump request has an optional set of extra + // outputs to add to the transaction. + ExtraTxOut fn.Option[SweepOutput] } // MaxFeeRateAllowed returns the maximum fee rate allowed for the given @@ -128,7 +132,11 @@ type BumpRequest struct { func (r *BumpRequest) MaxFeeRateAllowed() (chainfee.SatPerKWeight, error) { // Get the size of the sweep tx, which will be used to calculate the // budget fee rate. - size, err := calcSweepTxWeight(r.Inputs, r.DeliveryAddress) + // + // TODO(roasbeef): also wants the extra change output? + size, err := calcSweepTxWeight( + r.Inputs, r.DeliveryAddress.DeliveryAddress, + ) if err != nil { return 0, err } @@ -170,7 +178,7 @@ func calcSweepTxWeight(inputs []input.Input, // TODO(yy): we should refactor the weight estimator to not require a // fee rate and max fee rate and make it a pure tx weight calculator. _, estimator, err := getWeightEstimate( - inputs, nil, feeRate, 0, outputPkScript, + inputs, nil, feeRate, 0, [][]byte{outputPkScript}, ) if err != nil { return 0, err @@ -249,6 +257,10 @@ type TxPublisherConfig struct { // Notifier is used to monitor the confirmation status of the tx. Notifier chainntnfs.ChainNotifier + + // AuxSweeper is an optional interface that can be used to modify the + // way sweep transaction are generated. + AuxSweeper fn.Option[AuxSweeper] } // TxPublisher is an implementation of the Bumper interface. It utilizes the @@ -401,16 +413,18 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, for { // Create a new tx with the given fee rate and check its // mempool acceptance. - tx, fee, err := t.createAndCheckTx(req, f) + sweepCtx, err := t.createAndCheckTx(req, f) switch { case err == nil: // The tx is valid, return the request ID. - requestID := t.storeRecord(tx, req, f, fee) + requestID := t.storeRecord( + sweepCtx.tx, req, f, sweepCtx.fee, + ) log.Infof("Created tx %v for %v inputs: feerate=%v, "+ - "fee=%v, inputs=%v", tx.TxHash(), - len(req.Inputs), f.FeeRate(), fee, + "fee=%v, inputs=%v", sweepCtx.tx.TxHash(), + len(req.Inputs), f.FeeRate(), sweepCtx.fee, inputTypeSummary(req.Inputs)) return requestID, nil @@ -421,8 +435,8 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // We should at least start with a feerate above the // mempool min feerate, so if we get this error, it // means something is wrong earlier in the pipeline. - log.Errorf("Current fee=%v, feerate=%v, %v", fee, - f.FeeRate(), err) + log.Errorf("Current fee=%v, feerate=%v, %v", + sweepCtx.fee, f.FeeRate(), err) fallthrough @@ -434,8 +448,8 @@ func (t *TxPublisher) createRBFCompliantTx(req *BumpRequest, // increased or maxed out. for !increased { log.Debugf("Increasing fee for next round, "+ - "current fee=%v, feerate=%v", fee, - f.FeeRate()) + "current fee=%v, feerate=%v", + sweepCtx.fee, f.FeeRate()) // If the fee function tells us that we have // used up the budget, we will return an error @@ -484,30 +498,34 @@ func (t *TxPublisher) storeRecord(tx *wire.MsgTx, req *BumpRequest, // script, and the fee rate. In addition, it validates the tx's mempool // acceptance before returning a tx that can be published directly, along with // its fee. -func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) ( - *wire.MsgTx, btcutil.Amount, error) { +func (t *TxPublisher) createAndCheckTx(req *BumpRequest, + f FeeFunction) (*sweepTxCtx, error) { // Create the sweep tx with max fee rate of 0 as the fee function // guarantees the fee rate used here won't exceed the max fee rate. - tx, fee, err := t.createSweepTx( + sweepCtx, err := t.createSweepTx( req.Inputs, req.DeliveryAddress, f.FeeRate(), ) if err != nil { - return nil, fee, fmt.Errorf("create sweep tx: %w", err) + return sweepCtx, fmt.Errorf("create sweep tx: %w", err) } // Sanity check the budget still covers the fee. - if fee > req.Budget { - return nil, fee, fmt.Errorf("%w: budget=%v, fee=%v", - ErrNotEnoughBudget, req.Budget, fee) + if sweepCtx.fee > req.Budget { + return sweepCtx, fmt.Errorf("%w: budget=%v, fee=%v", + ErrNotEnoughBudget, req.Budget, sweepCtx.fee) } + // If we had an extra txOut, then we'll update the result to include + // it. + req.ExtraTxOut = sweepCtx.extraTxOut + // Validate the tx's mempool acceptance. - err = t.cfg.Wallet.CheckMempoolAcceptance(tx) + err = t.cfg.Wallet.CheckMempoolAcceptance(sweepCtx.tx) // Exit early if the tx is valid. if err == nil { - return tx, fee, nil + return sweepCtx, nil } // Print an error log if the chain backend doesn't support the mempool @@ -515,18 +533,18 @@ func (t *TxPublisher) createAndCheckTx(req *BumpRequest, f FeeFunction) ( if errors.Is(err, rpcclient.ErrBackendVersion) { log.Errorf("TestMempoolAccept not supported by backend, " + "consider upgrading it to a newer version") - return tx, fee, nil + return sweepCtx, nil } // We are running on a backend that doesn't implement the RPC // testmempoolaccept, eg, neutrino, so we'll skip the check. if errors.Is(err, chain.ErrUnimplemented) { log.Debug("Skipped testmempoolaccept due to not implemented") - return tx, fee, nil + return sweepCtx, nil } - return nil, fee, fmt.Errorf("tx=%v failed mempool check: %w", - tx.TxHash(), err) + return sweepCtx, fmt.Errorf("tx=%v failed mempool check: %w", + sweepCtx.tx.TxHash(), err) } // broadcast takes a monitored tx and publishes it to the network. Prior to the @@ -547,6 +565,17 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { log.Debugf("Publishing sweep tx %v, num_inputs=%v, height=%v", txid, len(tx.TxIn), t.currentHeight.Load()) + // Before we go to broadcast, we'll notify the aux sweeper, if it's + // present of this new broadcast attempt. + err := fn.MapOptionZ(t.cfg.AuxSweeper, func(aux AuxSweeper) error { + return aux.NotifyBroadcast( + record.req, tx, record.fee, + ) + }) + if err != nil { + return nil, fmt.Errorf("unable to notify aux sweeper: %w", err) + } + // Set the event, and change it to TxFailed if the wallet fails to // publish it. event := TxPublished @@ -554,7 +583,7 @@ func (t *TxPublisher) broadcast(requestID uint64) (*BumpResult, error) { // Publish the sweeping tx with customized label. If the publish fails, // this error will be saved in the `BumpResult` and it will be removed // from being monitored. - err := t.cfg.Wallet.PublishTransaction( + err = t.cfg.Wallet.PublishTransaction( tx, labels.MakeLabel(labels.LabelTypeSweepTransaction, nil), ) if err != nil { @@ -922,7 +951,7 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // NOTE: The fee function is expected to have increased its returned // fee rate after calling the SkipFeeBump method. So we can use it // directly here. - tx, fee, err := t.createAndCheckTx(r.req, r.feeFunction) + sweepCtx, err := t.createAndCheckTx(r.req, r.feeFunction) // If the error is fee related, we will return no error and let the fee // bumper retry it at next block. @@ -969,17 +998,17 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, // The tx has been created without any errors, we now register a new // record by overwriting the same requestID. t.records.Store(requestID, &monitorRecord{ - tx: tx, + tx: sweepCtx.tx, req: r.req, feeFunction: r.feeFunction, - fee: fee, + fee: sweepCtx.fee, }) // Attempt to broadcast this new tx. result, err := t.broadcast(requestID) if err != nil { log.Infof("Failed to broadcast replacement tx %v: %v", - tx.TxHash(), err) + sweepCtx.tx.TxHash(), err) return fn.None[BumpResult]() } @@ -1005,7 +1034,8 @@ func (t *TxPublisher) createAndPublishTx(requestID uint64, return fn.Some(*result) } - log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), tx.TxHash()) + log.Infof("Replaced tx=%v with new tx=%v", oldTx.TxHash(), + sweepCtx.tx.TxHash()) // Otherwise, it's a successful RBF, set the event and return. result.Event = TxReplaced @@ -1118,17 +1148,28 @@ func calcCurrentConfTarget(currentHeight, deadline int32) uint32 { return confTarget } +// sweepTxCtx houses a sweep transaction with additional context. +type sweepTxCtx struct { + tx *wire.MsgTx + + fee btcutil.Amount + + extraTxOut fn.Option[SweepOutput] +} + // createSweepTx creates a sweeping tx based on the given inputs, change // address and fee rate. -func (t *TxPublisher) createSweepTx(inputs []input.Input, changePkScript []byte, - feeRate chainfee.SatPerKWeight) (*wire.MsgTx, btcutil.Amount, error) { +func (t *TxPublisher) createSweepTx(inputs []input.Input, + changePkScript lnwallet.AddrWithKey, + feeRate chainfee.SatPerKWeight) (*sweepTxCtx, error) { // Validate and calculate the fee and change amount. - txFee, changeAmtOpt, locktimeOpt, err := prepareSweepTx( + txFee, changeOutputsOpt, locktimeOpt, err := prepareSweepTx( inputs, changePkScript, feeRate, t.currentHeight.Load(), + t.cfg.AuxSweeper, ) if err != nil { - return nil, 0, err + return nil, err } var ( @@ -1171,12 +1212,12 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, changePkScript []byte, }) } - // If there's a change amount, add it to the transaction. - changeAmtOpt.WhenSome(func(changeAmt btcutil.Amount) { - sweepTx.AddTxOut(&wire.TxOut{ - PkScript: changePkScript, - Value: int64(changeAmt), - }) + // If we have change outputs to add, then add it the sweep transaction + // here. + changeOutputsOpt.WhenSome(func(changeOuts []SweepOutput) { + for i := range changeOuts { + sweepTx.AddTxOut(&changeOuts[i].TxOut) + } }) // We'll default to using the current block height as locktime, if none @@ -1185,7 +1226,7 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, changePkScript []byte, prevInputFetcher, err := input.MultiPrevOutFetcher(inputs) if err != nil { - return nil, 0, fmt.Errorf("error creating prev input fetcher "+ + return nil, fmt.Errorf("error creating prev input fetcher "+ "for hash cache: %v", err) } hashCache := txscript.NewTxSigHashes(sweepTx, prevInputFetcher) @@ -1213,35 +1254,71 @@ func (t *TxPublisher) createSweepTx(inputs []input.Input, changePkScript []byte, for idx, inp := range idxs { if err := addInputScript(idx, inp); err != nil { - return nil, 0, err + return nil, err } } log.Debugf("Created sweep tx %v for inputs:\n%v", sweepTx.TxHash(), inputTypeSummary(inputs)) - return sweepTx, txFee, nil + // Try to locate the extra change output, though there might be None. + extraTxOut := fn.MapOption(func(sweepOuts []SweepOutput) fn.Option[SweepOutput] { //nolint:lll + for _, sweepOut := range sweepOuts { + if sweepOut.IsExtra { + log.Infof("Sweep produced extra_sweep_out=%v", + spew.Sdump(sweepOut)) + + return fn.Some(sweepOut) + } + } + + return fn.None[SweepOutput]() + })(changeOutputsOpt) + + return &sweepTxCtx{ + tx: sweepTx, + fee: txFee, + extraTxOut: fn.FlattenOption(extraTxOut), + }, nil } -// prepareSweepTx returns the tx fee, an optional change amount and an optional -// locktime after a series of validations: +// prepareSweepTx returns the tx fee, a set of optional change outputs and an +// optional locktime after a series of validations: // 1. check the locktime has been reached. // 2. check the locktimes are the same. // 3. check the inputs cover the outputs. // // NOTE: if the change amount is below dust, it will be added to the tx fee. -func prepareSweepTx(inputs []input.Input, changePkScript []byte, - feeRate chainfee.SatPerKWeight, currentHeight int32) ( - btcutil.Amount, fn.Option[btcutil.Amount], fn.Option[int32], error) { +func prepareSweepTx(inputs []input.Input, changePkScript lnwallet.AddrWithKey, + feeRate chainfee.SatPerKWeight, currentHeight int32, + auxSweeper fn.Option[AuxSweeper]) ( + btcutil.Amount, fn.Option[[]SweepOutput], fn.Option[int32], error) { - noChange := fn.None[btcutil.Amount]() + noChange := fn.None[[]SweepOutput]() noLocktime := fn.None[int32]() + // Given the set of inputs we have, if we have an aux sweeper, then + // we'll attempt to see if we have any other change outputs we'll need + // to add to the sweep transaction. + changePkScripts := [][]byte{changePkScript.DeliveryAddress} + extraChangeOut := fn.MapOptionZ( + auxSweeper, + func(aux AuxSweeper) fn.Result[SweepOutput] { + return aux.DeriveSweepAddr(inputs, changePkScript) + }, + ) + if err := extraChangeOut.Err(); err != nil { + return 0, noChange, noLocktime, err + } + extraChangeOut.WhenResult(func(o SweepOutput) { + changePkScripts = append(changePkScripts, o.PkScript) + }) + // Creating a weight estimator with nil outputs and zero max fee rate. // We don't allow adding customized outputs in the sweeping tx, and the // fee rate is already being managed before we get here. inputs, estimator, err := getWeightEstimate( - inputs, nil, feeRate, 0, changePkScript, + inputs, nil, feeRate, 0, changePkScripts, ) if err != nil { return 0, noChange, noLocktime, err @@ -1259,6 +1336,12 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, requiredOutput btcutil.Amount ) + // If we have an extra change output, then we'll add it as a required + // output amt. + extraChangeOut.WhenResult(func(o SweepOutput) { + requiredOutput += btcutil.Amount(o.Value) + }) + // Go through each input and check if the required lock times have // reached and are the same. for _, o := range inputs { @@ -1305,14 +1388,23 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, // The value remaining after the required output and fees is the // change output. changeAmt := totalInput - requiredOutput - txFee - changeAmtOpt := fn.Some(changeAmt) + + changeOuts := make([]SweepOutput, 0, 2) + + extraChangeOut.WhenResult(func(o SweepOutput) { + changeOuts = append(changeOuts, o) + }) // We'll calculate the dust limit for the given changePkScript since it // is variable. - changeFloor := lnwallet.DustLimitForSize(len(changePkScript)) + changeFloor := lnwallet.DustLimitForSize( + len(changePkScript.DeliveryAddress), + ) - // If the change amount is dust, we'll move it into the fees. - if changeAmt < changeFloor { + switch { + // If the change amount is dust, we'll move it into the fees, and + // ignore it. + case changeAmt < changeFloor: log.Infof("Change amt %v below dustlimit %v, not adding "+ "change output", changeAmt, changeFloor) @@ -1327,8 +1419,16 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, // The dust amount is added to the fee. txFee += changeAmt - // Set the change amount to none. - changeAmtOpt = fn.None[btcutil.Amount]() + // Otherwise, we'll actually recognize it as a change output. + default: + changeOuts = append(changeOuts, SweepOutput{ + TxOut: wire.TxOut{ + Value: int64(changeAmt), + PkScript: changePkScript.DeliveryAddress, + }, + IsExtra: false, + InternalKey: changePkScript.InternalKey, + }) } // Optionally set the locktime. @@ -1337,6 +1437,11 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, locktimeOpt = noLocktime } + var changeOutsOpt fn.Option[[]SweepOutput] + if len(changeOuts) > 0 { + changeOutsOpt = fn.Some(changeOuts) + } + log.Debugf("Creating sweep tx for %v inputs (%s) using %v, "+ "tx_weight=%v, tx_fee=%v, locktime=%v, parents_count=%v, "+ "parents_fee=%v, parents_weight=%v, current_height=%v", @@ -1344,5 +1449,5 @@ func prepareSweepTx(inputs []input.Input, changePkScript []byte, estimator.weight(), txFee, locktimeOpt, len(estimator.parents), estimator.parentsFee, estimator.parentsWeight, currentHeight) - return txFee, changeAmtOpt, locktimeOpt, nil + return txFee, changeOutsOpt, locktimeOpt, nil } diff --git a/sweep/fee_bumper_test.go b/sweep/fee_bumper_test.go index 63a828654..a7b278702 100644 --- a/sweep/fee_bumper_test.go +++ b/sweep/fee_bumper_test.go @@ -21,12 +21,14 @@ import ( var ( // Create a taproot change script. - changePkScript = []byte{ - 0x51, 0x20, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + changePkScript = lnwallet.AddrWithKey{ + DeliveryAddress: []byte{ + 0x51, 0x20, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + }, } testInputCount atomic.Uint64 @@ -117,7 +119,9 @@ func TestCalcSweepTxWeight(t *testing.T) { require.Zero(t, weight) // Use a correct change script to test the success case. - weight, err = calcSweepTxWeight([]input.Input{&inp}, changePkScript) + weight, err = calcSweepTxWeight( + []input.Input{&inp}, changePkScript.DeliveryAddress, + ) require.NoError(t, err) // BaseTxSize 8 bytes @@ -137,7 +141,9 @@ func TestBumpRequestMaxFeeRateAllowed(t *testing.T) { inp := createTestInput(100, input.WitnessKeyHash) // The weight is 487. - weight, err := calcSweepTxWeight([]input.Input{&inp}, changePkScript) + weight, err := calcSweepTxWeight( + []input.Input{&inp}, changePkScript.DeliveryAddress, + ) require.NoError(t, err) // Define a test budget and calculates its fee rate. @@ -154,7 +160,9 @@ func TestBumpRequestMaxFeeRateAllowed(t *testing.T) { // Use a wrong change script to test the error case. name: "error calc weight", req: &BumpRequest{ - DeliveryAddress: []byte{1}, + DeliveryAddress: lnwallet.AddrWithKey{ + DeliveryAddress: []byte{1}, + }, }, expectedMaxFeeRate: 0, expectedErr: true, @@ -451,7 +459,7 @@ func TestCreateAndCheckTx(t *testing.T) { t.Run(tc.name, func(t *testing.T) { // Call the method under test. - _, _, err := tp.createAndCheckTx(tc.req, m.feeFunc) + _, err := tp.createAndCheckTx(tc.req, m.feeFunc) // Check the result is as expected. require.ErrorIs(t, err, tc.expectedErr) diff --git a/sweep/interface.go b/sweep/interface.go index 4b02f143c..41120613b 100644 --- a/sweep/interface.go +++ b/sweep/interface.go @@ -1,8 +1,12 @@ package sweep import ( + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/fn" + "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/lnwallet" ) @@ -57,3 +61,31 @@ type Wallet interface { // service. BackEnd() string } + +// SweepOutput is an output used to sweep funds from a channel output. +type SweepOutput struct { //nolint:revive + wire.TxOut + + // IsExtra indicates whether this output is an extra output that was + // added by a party other than the sweeper. + IsExtra bool + + // InternalKey is the taproot internal key of the extra output. This is + // None, if this isn't a taproot output. + InternalKey fn.Option[keychain.KeyDescriptor] +} + +// AuxSweeper is used to enable a 3rd party to further shape the sweeping +// transaction by adding a set of extra outputs to the sweeping transaction. +type AuxSweeper interface { + // DeriveSweepAddr takes a set of inputs, and the change address we'd + // use to sweep them, and maybe results an extra sweep output that we + // should add to the sweeping transaction. + DeriveSweepAddr(inputs []input.Input, + change lnwallet.AddrWithKey) fn.Result[SweepOutput] + + // NotifyBroadcast is used to notify external callers of the broadcast + // of a sweep transaction, generated by the passed BumpRequest. + NotifyBroadcast(req *BumpRequest, tx *wire.MsgTx, + totalFees btcutil.Amount) error +} diff --git a/sweep/sweeper.go b/sweep/sweeper.go index 39a03228d..4fdb9d359 100644 --- a/sweep/sweeper.go +++ b/sweep/sweeper.go @@ -297,7 +297,7 @@ type UtxoSweeper struct { // to sweep. inputs InputsMap - currentOutputScript []byte + currentOutputScript fn.Option[lnwallet.AddrWithKey] relayFeeRate chainfee.SatPerKWeight @@ -317,7 +317,7 @@ type UtxoSweeper struct { type UtxoSweeperConfig struct { // GenSweepScript generates a P2WKH script belonging to the wallet where // funds can be swept. - GenSweepScript func() ([]byte, error) + GenSweepScript func() fn.Result[lnwallet.AddrWithKey] // FeeEstimator is used when crafting sweep transactions to estimate // the necessary fee relative to the expected size of the sweep @@ -361,6 +361,10 @@ type UtxoSweeperConfig struct { // NoDeadlineConfTarget is the conf target to use when sweeping // non-time-sensitive outputs. NoDeadlineConfTarget uint32 + + // AuxSweeper is an optional interface that can be used to modify the + // way sweep transaction are generated. + AuxSweeper fn.Option[AuxSweeper] } // Result is the struct that is pushed through the result channel. Callers can @@ -795,12 +799,19 @@ func (s *UtxoSweeper) signalResult(pi *SweeperInput, result Result) { // the tx. The output address is only marked as used if the publish succeeds. func (s *UtxoSweeper) sweep(set InputSet) error { // Generate an output script if there isn't an unused script available. - if s.currentOutputScript == nil { - pkScript, err := s.cfg.GenSweepScript() + if s.currentOutputScript.IsNone() { + addr, err := s.cfg.GenSweepScript().Unpack() if err != nil { return fmt.Errorf("gen sweep script: %w", err) } - s.currentOutputScript = pkScript + s.currentOutputScript = fn.Some(addr) + } + + sweepAddr, err := s.currentOutputScript.UnwrapOrErr( + fmt.Errorf("none sweep script"), + ) + if err != nil { + return err } // Create a fee bump request and ask the publisher to broadcast it. The @@ -810,7 +821,7 @@ func (s *UtxoSweeper) sweep(set InputSet) error { Inputs: set.Inputs(), Budget: set.Budget(), DeadlineHeight: set.DeadlineHeight(), - DeliveryAddress: s.currentOutputScript, + DeliveryAddress: sweepAddr, MaxFeeRate: s.cfg.MaxFeeRate.FeePerKWeight(), StartingFeeRate: set.StartingFeeRate(), // TODO(yy): pass the strategy here. @@ -1704,10 +1715,10 @@ func (s *UtxoSweeper) handleBumpEventTxPublished(r *BumpResult) error { log.Debugf("Published sweep tx %v, num_inputs=%v, height=%v", tx.TxHash(), len(tx.TxIn), s.currentHeight) - // If there's no error, remove the output script. Otherwise - // keep it so that it can be reused for the next transaction - // and causes no address inflation. - s.currentOutputScript = nil + // If there's no error, remove the output script. Otherwise keep it so + // that it can be reused for the next transaction and causes no address + // inflation. + s.currentOutputScript = fn.None[lnwallet.AddrWithKey]() return nil } diff --git a/sweep/sweeper_test.go b/sweep/sweeper_test.go index c8d9fc510..eca481e02 100644 --- a/sweep/sweeper_test.go +++ b/sweep/sweeper_test.go @@ -12,6 +12,7 @@ import ( "github.com/lightningnetwork/lnd/chainntnfs" "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" + "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwallet/chainfee" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" @@ -667,8 +668,10 @@ func TestSweepPendingInputs(t *testing.T) { Wallet: wallet, Aggregator: aggregator, Publisher: publisher, - GenSweepScript: func() ([]byte, error) { - return testPubKey.SerializeCompressed(), nil + GenSweepScript: func() fn.Result[lnwallet.AddrWithKey] { + return fn.Ok(lnwallet.AddrWithKey{ + DeliveryAddress: testPubKey.SerializeCompressed(), //nolint:lll + }) }, NoDeadlineConfTarget: uint32(DefaultDeadlineDelta), }) diff --git a/sweep/tx_input_set.go b/sweep/tx_input_set.go index 31f20b7db..71db3a70e 100644 --- a/sweep/tx_input_set.go +++ b/sweep/tx_input_set.go @@ -220,6 +220,26 @@ func (b *BudgetInputSet) NeedWalletInput() bool { budgetBorrowable btcutil.Amount ) + // If any of the outputs in the set have a resolution blob, then this + // means we'll end up needing an extra change output. We'll tack this + // on now as an extra portion of the budget. + extraRequiredTxOut := fn.Any(func(i *SweeperInput) bool { + // If there's a required txout, then we don't count this as + // it'll be a second level HTLC. + if i.RequiredTxOut() != nil { + return false + } + + // Otherwise, we need one if we have a resolution blob. + return i.ResolutionBlob().IsSome() + }, b.inputs) + + if extraRequiredTxOut { + // TODO(roasbeef): aux sweeper ext to ask for extra output + // params and value? + budgetNeeded += 1_000 + } + for _, inp := range b.inputs { // If this input has a required output, we can assume it's a // second-level htlc txns input. Although this input must have diff --git a/sweep/tx_input_set_test.go b/sweep/tx_input_set_test.go index b6a87b378..ea3bda6b0 100644 --- a/sweep/tx_input_set_test.go +++ b/sweep/tx_input_set_test.go @@ -11,6 +11,7 @@ import ( "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/input" "github.com/lightningnetwork/lnd/lnwallet" + "github.com/lightningnetwork/lnd/tlv" "github.com/stretchr/testify/require" ) @@ -125,6 +126,7 @@ func TestNeedWalletInput(t *testing.T) { // Create a mock input that doesn't have required outputs. mockInput := &input.MockInput{} mockInput.On("RequiredTxOut").Return(nil) + mockInput.On("ResolutionBlob").Return(fn.None[tlv.Blob]()) defer mockInput.AssertExpectations(t) // Create a mock input that has required outputs. diff --git a/sweep/txgenerator.go b/sweep/txgenerator.go index 30e11023e..43fc802ba 100644 --- a/sweep/txgenerator.go +++ b/sweep/txgenerator.go @@ -38,7 +38,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, signer input.Signer) (*wire.MsgTx, btcutil.Amount, error) { inputs, estimator, err := getWeightEstimate( - inputs, outputs, feeRate, maxFeeRate, changePkScript, + inputs, outputs, feeRate, maxFeeRate, [][]byte{changePkScript}, ) if err != nil { return nil, 0, err @@ -221,7 +221,7 @@ func createSweepTx(inputs []input.Input, outputs []*wire.TxOut, // Additionally, it returns counts for the number of csv and cltv inputs. func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, feeRate, maxFeeRate chainfee.SatPerKWeight, - outputPkScript []byte) ([]input.Input, *weightEstimator, error) { + outputPkScripts [][]byte) ([]input.Input, *weightEstimator, error) { // We initialize a weight estimator so we can accurately asses the // amount of fees we need to pay for this sweep transaction. @@ -237,31 +237,33 @@ func getWeightEstimate(inputs []input.Input, outputs []*wire.TxOut, // If there is any leftover change after paying to the given outputs // and required outputs, it will go to a single segwit p2wkh or p2tr - // address. This will be our change address, so ensure it contributes to - // our weight estimate. Note that if we have other outputs, we might end - // up creating a sweep tx without a change output. It is okay to add the - // change output to the weight estimate regardless, since the estimated - // fee will just be subtracted from this already dust output, and - // trimmed. - switch { - case txscript.IsPayToTaproot(outputPkScript): - weightEstimate.addP2TROutput() + // address. This will be our change address, so ensure it contributes + // to our weight estimate. Note that if we have other outputs, we might + // end up creating a sweep tx without a change output. It is okay to + // add the change output to the weight estimate regardless, since the + // estimated fee will just be subtracted from this already dust output, + // and trimmed. + for _, outputPkScript := range outputPkScripts { + switch { + case txscript.IsPayToTaproot(outputPkScript): + weightEstimate.addP2TROutput() - case txscript.IsPayToWitnessScriptHash(outputPkScript): - weightEstimate.addP2WSHOutput() + case txscript.IsPayToWitnessScriptHash(outputPkScript): + weightEstimate.addP2WSHOutput() - case txscript.IsPayToWitnessPubKeyHash(outputPkScript): - weightEstimate.addP2WKHOutput() + case txscript.IsPayToWitnessPubKeyHash(outputPkScript): + weightEstimate.addP2WKHOutput() - case txscript.IsPayToPubKeyHash(outputPkScript): - weightEstimate.estimator.AddP2PKHOutput() + case txscript.IsPayToPubKeyHash(outputPkScript): + weightEstimate.estimator.AddP2PKHOutput() - case txscript.IsPayToScriptHash(outputPkScript): - weightEstimate.estimator.AddP2SHOutput() + case txscript.IsPayToScriptHash(outputPkScript): + weightEstimate.estimator.AddP2SHOutput() - default: - // Unknown script type. - return nil, nil, errors.New("unknown script type") + default: + // Unknown script type. + return nil, nil, errors.New("unknown script type") + } } // For each output, use its witness type to determine the estimate diff --git a/sweep/txgenerator_test.go b/sweep/txgenerator_test.go index 48dcacd49..71477bd6e 100644 --- a/sweep/txgenerator_test.go +++ b/sweep/txgenerator_test.go @@ -51,7 +51,7 @@ func TestWeightEstimate(t *testing.T) { } _, estimator, err := getWeightEstimate( - inputs, nil, 0, 0, changePkScript, + inputs, nil, 0, 0, [][]byte{changePkScript}, ) require.NoError(t, err) @@ -153,7 +153,7 @@ func testUnknownScriptInner(t *testing.T, pkscript []byte, expectFail bool) { )) } - _, _, err := getWeightEstimate(inputs, nil, 0, 0, pkscript) + _, _, err := getWeightEstimate(inputs, nil, 0, 0, [][]byte{pkscript}) if expectFail { require.Error(t, err) } else {