diff --git a/channeldb/channel.go b/channeldb/channel.go index 75eb4cde9..1a07ed91d 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -1551,7 +1551,7 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { // If this is a taproot channel, then we'll need to generate our next // verification nonce to send to the remote party. They'll use this to // sign the next update to our commitment transaction. - var nextTaprootNonce *lnwire.Musig2Nonce + var nextTaprootNonce lnwire.OptMusig2NonceTLV if c.ChanType.IsTaproot() { taprootRevProducer, err := DeriveMusig2Shachain( c.RevocationProducer, @@ -1569,7 +1569,7 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) { "nonce: %w", err) } - nextTaprootNonce = (*lnwire.Musig2Nonce)(&nextNonce.PubNonce) + nextTaprootNonce = lnwire.SomeMusig2Nonce(nextNonce.PubNonce) } return &lnwire.ChannelReestablish{ diff --git a/funding/manager.go b/funding/manager.go index de537b1f2..78cf19559 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -48,6 +48,14 @@ var ( // // NOTE: for itest, this value is changed to 10ms. checkPeerChannelReadyInterval = 1 * time.Second + + // errNoLocalNonce is returned when a local nonce is not found in the + // expected TLV. + errNoLocalNonce = fmt.Errorf("local nonce not found") + + // errNoPartialSig is returned when a partial sig is not found in the + // expected TLV. + errNoPartialSig = fmt.Errorf("partial sig not found") ) // WriteOutpoint writes an outpoint to an io.Writer. This is not the same as @@ -1801,17 +1809,17 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer, } if resCtx.reservation.IsTaproot() { - if msg.LocalNonce == nil { - err := fmt.Errorf("local nonce not set for taproot " + - "chan") - log.Error(err) - f.failFundingFlow( - resCtx.peer, cid, err, - ) + localNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoLocalNonce) + if err != nil { + log.Error(errNoLocalNonce) + + f.failFundingFlow(resCtx.peer, cid, errNoLocalNonce) + + return } remoteContribution.LocalNonce = &musig2.Nonces{ - PubNonce: *msg.LocalNonce, + PubNonce: localNonce, } } @@ -1827,13 +1835,6 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer, log.Debugf("Remote party accepted commitment constraints: %v", spew.Sdump(remoteContribution.ChannelConfig.ChannelConstraints)) - var localNonce *lnwire.Musig2Nonce - if commitType.IsTaproot() { - localNonce = (*lnwire.Musig2Nonce)( - &ourContribution.LocalNonce.PubNonce, - ) - } - // With the initiator's contribution recorded, respond with our // contribution in the next message of the workflow. fundingAccept := lnwire.AcceptChannel{ @@ -1854,7 +1855,12 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer, UpfrontShutdownScript: ourContribution.UpfrontShutdown, ChannelType: chanType, LeaseExpiry: msg.LeaseExpiry, - LocalNonce: localNonce, + } + + if commitType.IsTaproot() { + fundingAccept.LocalNonce = lnwire.SomeMusig2Nonce( + ourContribution.LocalNonce.PubNonce, + ) } if err := peer.SendMessage(true, &fundingAccept); err != nil { @@ -2044,15 +2050,17 @@ func (f *Manager) funderProcessAcceptChannel(peer lnpeer.Peer, } if resCtx.reservation.IsTaproot() { - if msg.LocalNonce == nil { - err := fmt.Errorf("local nonce not set for taproot " + - "chan") - log.Error(err) - f.failFundingFlow(resCtx.peer, cid, err) + localNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoLocalNonce) + if err != nil { + log.Error(errNoLocalNonce) + + f.failFundingFlow(resCtx.peer, cid, errNoLocalNonce) + + return } remoteContribution.LocalNonce = &musig2.Nonces{ - PubNonce: *msg.LocalNonce, + PubNonce: localNonce, } } @@ -2263,7 +2271,9 @@ func (f *Manager) continueFundingAccept(resCtx *reservationWithCtx, return } - fundingCreated.PartialSig = partialSig.ToWireSig() + fundingCreated.PartialSig = lnwire.MaybePartialSigWithNonce( + partialSig.ToWireSig(), + ) } else { fundingCreated.CommitSig, err = lnwire.NewSigFromSignature(sig) if err != nil { @@ -2317,14 +2327,15 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer, // our internal input.Signature type. var commitSig input.Signature if resCtx.reservation.IsTaproot() { - if msg.PartialSig == nil { - log.Errorf("partial sig not included: %v", err) + partialSig, err := msg.PartialSig.UnwrapOrErrV(errNoPartialSig) + if err != nil { f.failFundingFlow(peer, cid, err) + return } commitSig = new(lnwallet.MusigPartialSig).FromWireSig( - msg.PartialSig, + &partialSig, ) } else { commitSig, err = msg.CommitSig.ToSignature() @@ -2408,7 +2419,9 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer, return } - fundingSigned.PartialSig = partialSig.ToWireSig() + fundingSigned.PartialSig = lnwire.MaybePartialSigWithNonce( + partialSig.ToWireSig(), + ) } else { fundingSigned.CommitSig, err = lnwire.NewSigFromSignature(sig) if err != nil { @@ -2565,14 +2578,15 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer, // our internal input.Signature type. var commitSig input.Signature if resCtx.reservation.IsTaproot() { - if msg.PartialSig == nil { - log.Errorf("partial sig not included: %v", err) + partialSig, err := msg.PartialSig.UnwrapOrErrV(errNoPartialSig) + if err != nil { f.failFundingFlow(peer, cid, err) + return } commitSig = new(lnwallet.MusigPartialSig).FromWireSig( - msg.PartialSig, + &partialSig, ) } else { commitSig, err = msg.CommitSig.ToSignature() @@ -3153,8 +3167,8 @@ func (f *Manager) sendChannelReady(completeChan *channeldb.OpenChannel, } f.nonceMtx.Unlock() - channelReadyMsg.NextLocalNonce = (*lnwire.Musig2Nonce)( - &localNonce.PubNonce, + channelReadyMsg.NextLocalNonce = lnwire.SomeMusig2Nonce( + localNonce.PubNonce, ) } @@ -3824,11 +3838,9 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen channelReadyMsg.AliasScid = &alias if firstVerNonce != nil { - wireNonce := (*lnwire.Musig2Nonce)( - &firstVerNonce.PubNonce, + channelReadyMsg.NextLocalNonce = lnwire.SomeMusig2Nonce( //nolint:lll + firstVerNonce.PubNonce, ) - - channelReadyMsg.NextLocalNonce = wireNonce } err = peer.SendMessage(true, channelReadyMsg) @@ -3873,8 +3885,13 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen log.Infof("ChanID(%v): applying local+remote musig2 nonces", chanID) - if msg.NextLocalNonce == nil { - log.Errorf("remote nonces are nil") + remoteNonce, err := msg.NextLocalNonce.UnwrapOrErrV( + errNoLocalNonce, + ) + if err != nil { + cid := newChanIdentifier(msg.ChanID) + f.failFundingFlow(peer, cid, err) + return } @@ -3882,7 +3899,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen chanOpts, lnwallet.WithLocalMusigNonces(localNonce), lnwallet.WithRemoteMusigNonces(&musig2.Nonces{ - PubNonce: *msg.NextLocalNonce, + PubNonce: remoteNonce, }), ) } @@ -4714,13 +4731,6 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) { log.Infof("Starting funding workflow with %v for pending_id(%x), "+ "committype=%v", msg.Peer.Address(), chanID, commitType) - var localNonce *lnwire.Musig2Nonce - if commitType.IsTaproot() { - localNonce = (*lnwire.Musig2Nonce)( - &ourContribution.LocalNonce.PubNonce, - ) - } - fundingOpen := lnwire.OpenChannel{ ChainHash: *f.cfg.Wallet.Cfg.NetParams.GenesisHash, PendingChannelID: chanID, @@ -4743,8 +4753,14 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) { UpfrontShutdownScript: shutdown, ChannelType: chanType, LeaseExpiry: leaseExpiry, - LocalNonce: localNonce, } + + if commitType.IsTaproot() { + fundingOpen.LocalNonce = lnwire.SomeMusig2Nonce( + ourContribution.LocalNonce.PubNonce, + ) + } + if err := msg.Peer.SendMessage(true, &fundingOpen); err != nil { e := fmt.Errorf("unable to send funding request message: %v", err) diff --git a/lnwallet/chancloser/chancloser.go b/lnwallet/chancloser/chancloser.go index 98efa7095..b77e175e6 100644 --- a/lnwallet/chancloser/chancloser.go +++ b/lnwallet/chancloser/chancloser.go @@ -49,6 +49,10 @@ var ( // ErrInvalidShutdownScript is returned when we receive an address from // a peer that isn't either a p2wsh or p2tr address. ErrInvalidShutdownScript = fmt.Errorf("invalid shutdown script") + + // errNoShutdownNonce is returned when a shutdown message is received + // w/o a nonce for a taproot channel. + errNoShutdownNonce = fmt.Errorf("shutdown nonce not populated") ) // closeState represents all the possible states the channel closer state @@ -337,8 +341,8 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) { return nil, err } - shutdown.ShutdownNonce = (*lnwire.ShutdownNonce)( - &firstClosingNonce.PubNonce, + shutdown.ShutdownNonce = lnwire.SomeShutdownNonce( + firstClosingNonce.PubNonce, ) chancloserLog.Infof("Initiating shutdown w/ nonce: %v", @@ -548,13 +552,15 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) ( // remote nonces so we can properly create a new musig // session for signing. if c.cfg.Channel.ChanType().IsTaproot() { - if msg.ShutdownNonce == nil { - return noShutdown, fmt.Errorf("shutdown " + - "nonce not populated") + shutdownNonce, err := msg.ShutdownNonce.UnwrapOrErrV( + errNoShutdownNonce, + ) + if err != nil { + return noShutdown, err } c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{ - PubNonce: *msg.ShutdownNonce, + PubNonce: shutdownNonce, }) } @@ -594,13 +600,15 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) ( // local+remote nonces so we can properly create a new musig // session for signing. if c.cfg.Channel.ChanType().IsTaproot() { - if msg.ShutdownNonce == nil { - return noShutdown, fmt.Errorf("shutdown " + - "nonce not populated") + shutdownNonce, err := msg.ShutdownNonce.UnwrapOrErrV( + errNoShutdownNonce, + ) + if err != nil { + return noShutdown, err } c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{ - PubNonce: *msg.ShutdownNonce, + PubNonce: shutdownNonce, }) } @@ -683,10 +691,10 @@ func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned], } // ReceiveClosingSigned is a method that should be called whenever we receive a -// ClosingSigned message from the wire. It may or may not return a ClosingSigned -// of our own to send back to the remote. -func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) ( - fn.Option[lnwire.ClosingSigned], error) { +// ClosingSigned message from the wire. It may or may not return a +// ClosingSigned of our own to send back to the remote. +func (c *ChanCloser) ReceiveClosingSigned( //nolint:funlen + msg lnwire.ClosingSigned) (fn.Option[lnwire.ClosingSigned], error) { noClosing := fn.None[lnwire.ClosingSigned]() @@ -702,7 +710,7 @@ func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) ( // If this is a taproot channel, then it MUST have a partial // signature set at this point. isTaproot := c.cfg.Channel.ChanType().IsTaproot() - if isTaproot && msg.PartialSig == nil { + if isTaproot && msg.PartialSig.IsNone() { return noClosing, fmt.Errorf("partial sig not set " + "for taproot chan") @@ -807,12 +815,23 @@ func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) ( ) matchingSig := c.priorFeeOffers[remoteProposedFee] if c.cfg.Channel.ChanType().IsTaproot() { + localWireSig, err := matchingSig.PartialSig.UnwrapOrErrV( //nolint:lll + fmt.Errorf("none local sig"), + ) + if err != nil { + return noClosing, err + } + remoteWireSig, err := msg.PartialSig.UnwrapOrErrV( + fmt.Errorf("none remote sig"), + ) + if err != nil { + return noClosing, err + } + muSession := c.cfg.MusigSession - localSig, remoteSig, closeOpts, err = - muSession.CombineClosingOpts( - *matchingSig.PartialSig, - *msg.PartialSig, - ) + localSig, remoteSig, closeOpts, err = muSession.CombineClosingOpts( //nolint:lll + localWireSig, remoteWireSig, + ) if err != nil { return noClosing, err } @@ -952,7 +971,9 @@ func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) ( // over a partial signature which'll be combined once our offer is // accepted. if partialSig != nil { - closeSignedMsg.PartialSig = &partialSig.PartialSig + closeSignedMsg.PartialSig = lnwire.SomePartialSig( + partialSig.PartialSig, + ) } // We'll also save this close signed, in the case that the remote party diff --git a/lnwallet/chancloser/chancloser_test.go b/lnwallet/chancloser/chancloser_test.go index 53c0fb6ba..8111ef657 100644 --- a/lnwallet/chancloser/chancloser_test.go +++ b/lnwallet/chancloser/chancloser_test.go @@ -541,11 +541,9 @@ func TestTaprootFastClose(t *testing.T) { require.True(t, oShutdown.IsSome()) require.True(t, oClosingSigned.IsNone()) - bobShutdown := oShutdown.UnsafeFromSome() - // Alice should process the shutdown message, and create a closing // signed of her own. - oShutdown, err = aliceCloser.ReceiveShutdown(bobShutdown) + oShutdown, err = aliceCloser.ReceiveShutdown(oShutdown.UnwrapOrFail(t)) require.NoError(t, err) oClosingSigned, err = aliceCloser.BeginNegotiation() require.NoError(t, err) @@ -554,7 +552,7 @@ func TestTaprootFastClose(t *testing.T) { require.True(t, oShutdown.IsNone()) require.True(t, oClosingSigned.IsSome()) - aliceClosingSigned := oClosingSigned.UnsafeFromSome() + aliceClosingSigned := oClosingSigned.UnwrapOrFail(t) // Next, Bob will process the closing signed message, and send back a // new one that should match exactly the offer Alice sent. @@ -564,7 +562,7 @@ func TestTaprootFastClose(t *testing.T) { require.NotNil(t, tx) require.True(t, oClosingSigned.IsSome()) - bobClosingSigned := oClosingSigned.UnsafeFromSome() + bobClosingSigned := oClosingSigned.UnwrapOrFail(t) // At this point, Bob has accepted the offer, so he can broadcast the // closing transaction, and considers the channel closed. @@ -597,7 +595,7 @@ func TestTaprootFastClose(t *testing.T) { require.NotNil(t, tx) require.True(t, oClosingSigned.IsSome()) - aliceClosingSigned = oClosingSigned.UnsafeFromSome() + aliceClosingSigned = oClosingSigned.UnwrapOrFail(t) // Alice should now also broadcast her closing transaction. _, err = lnutils.RecvOrTimeout(broadcastSignal, time.Second*1) diff --git a/lnwallet/channel.go b/lnwallet/channel.go index 3871ec780..5e6af76c8 100644 --- a/lnwallet/channel.go +++ b/lnwallet/channel.go @@ -127,6 +127,13 @@ var ( // we've likely lost data ourselves. ErrForceCloseLocalDataLoss = errors.New("cannot force close " + "channel with local data loss") + + // errNoNonce is returned when a nonce is required, but none is found. + errNoNonce = errors.New("no nonce found") + + // errNoPartialSig is returned when a partial signature is required, + // but none is found. + errNoPartialSig = errors.New("no partial signature found") ) // ErrCommitSyncLocalDataLoss is returned in the case that we receive a valid @@ -4128,7 +4135,7 @@ type CommitSigs struct { // PartialSig is the musig2 partial signature for taproot commitment // transactions. - PartialSig *lnwire.PartialSigWithNonce + PartialSig lnwire.OptPartialSigWithNonceTLV } // NewCommitState wraps the various signatures needed to properly @@ -4341,7 +4348,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) { CommitSigs: &CommitSigs{ CommitSig: sig, HtlcSigs: htlcSigs, - PartialSig: partialSig, + PartialSig: lnwire.MaybePartialSigWithNonce(partialSig), }, PendingHTLCs: commitDiff.Commitment.Htlcs, }, nil @@ -4416,19 +4423,24 @@ func (lc *LightningChannel) ProcessChanSyncMsg( // bail out, otherwise we'll init our local session then continue as // normal. switch { - case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce == nil: + case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsNone(): return nil, nil, nil, fmt.Errorf("remote verification nonce " + "not sent") - case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce != nil: + case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsSome(): if lc.opts.skipNonceInit { // Don't call InitRemoteMusigNonces if we have already // done so. break } - err := lc.InitRemoteMusigNonces(&musig2.Nonces{ - PubNonce: *msg.LocalNonce, + nextNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoNonce) + if err != nil { + return nil, nil, nil, err + } + + err = lc.InitRemoteMusigNonces(&musig2.Nonces{ + PubNonce: nextNonce, }) if err != nil { return nil, nil, nil, fmt.Errorf("unable to init "+ @@ -5227,6 +5239,13 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { if lc.channelState.ChanType.IsTaproot() { localSession := lc.musigSessions.LocalSession + partialSig, err := commitSigs.PartialSig.UnwrapOrErrV( + errNoPartialSig, + ) + if err != nil { + return err + } + // As we want to ensure we never write nonces to disk, we'll // use the shachain state to generate a nonce for our next // local state. Similar to generateRevocation, we do height + 2 @@ -5236,7 +5255,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { nextHeight+1, lc.taprootNonceProducer, ) nextVerificationNonce, err := localSession.VerifyCommitSig( - localCommitTx, commitSigs.PartialSig, localCtrNonce, + localCommitTx, &partialSig, localCtrNonce, ) if err != nil { close(cancelChan) @@ -5352,9 +5371,16 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error { // serialize the ECDSA sig. For taproot channels, we'll serialize the // partial sig that includes the nonce that was used for signing. if lc.channelState.ChanType.IsTaproot() { + partialSig, err := commitSigs.PartialSig.UnwrapOrErrV( + errNoPartialSig, + ) + if err != nil { + return err + } + var sigBytes [lnwire.PartialSigWithNonceLen]byte b := bytes.NewBuffer(sigBytes[0:0]) - if err := commitSigs.PartialSig.Encode(b); err != nil { + if err := partialSig.Encode(b); err != nil { return err } @@ -5783,19 +5809,21 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) ( // Now that we have a new verification nonce from them, we can refresh // our remote musig2 session which allows us to create another state. if lc.channelState.ChanType.IsTaproot() { - if revMsg.LocalNonce == nil { - return nil, nil, nil, nil, fmt.Errorf("next " + - "revocation nonce not set") + localNonce, err := revMsg.LocalNonce.UnwrapOrErrV(errNoNonce) + if err != nil { + return nil, nil, nil, nil, err } - newRemoteSession, err := lc.musigSessions.RemoteSession.Refresh( + + session, err := lc.musigSessions.RemoteSession.Refresh( &musig2.Nonces{ - PubNonce: *revMsg.LocalNonce, + PubNonce: localNonce, }, ) if err != nil { return nil, nil, nil, nil, err } - lc.musigSessions.RemoteSession = newRemoteSession + + lc.musigSessions.RemoteSession = session } // At this point, the revocation has been accepted, and we've rotated @@ -8506,8 +8534,8 @@ func (lc *LightningChannel) generateRevocation(height uint64) (*lnwire.RevokeAnd if err != nil { return nil, err } - revocationMsg.LocalNonce = (*lnwire.Musig2Nonce)( - &nextVerificationNonce.PubNonce, + revocationMsg.LocalNonce = lnwire.SomeMusig2Nonce( + nextVerificationNonce.PubNonce, ) } diff --git a/lnwallet/channel_test.go b/lnwallet/channel_test.go index 6d3f928b6..6717b9209 100644 --- a/lnwallet/channel_test.go +++ b/lnwallet/channel_test.go @@ -2875,10 +2875,10 @@ func assertNoChanSyncNeeded(t *testing.T, aliceChannel *LightningChannel, // nonces. if aliceChannel.channelState.ChanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceChanSyncMsg.LocalNonce, + PubNonce: aliceChanSyncMsg.LocalNonce.UnwrapOrFailV(t), } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobChanSyncMsg.LocalNonce, + PubNonce: bobChanSyncMsg.LocalNonce.UnwrapOrFailV(t), } } @@ -3486,10 +3486,10 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) { // nonces. if chanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceSyncMsg.LocalNonce, + PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFail(t).Val, } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobSyncMsg.LocalNonce, + PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFail(t).Val, } } @@ -3548,10 +3548,10 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) { // nonces. if chanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceSyncMsg.LocalNonce, + PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t), } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobSyncMsg.LocalNonce, + PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t), } } @@ -3671,10 +3671,10 @@ func testChanSyncOweRevocationAndCommit(t *testing.T, // nonces. if chanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceSyncMsg.LocalNonce, + PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t), } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobSyncMsg.LocalNonce, + PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t), } } @@ -3751,10 +3751,10 @@ func testChanSyncOweRevocationAndCommit(t *testing.T, // nonces. if chanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceSyncMsg.LocalNonce, + PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t), } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobSyncMsg.LocalNonce, + PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t), } } @@ -3888,10 +3888,10 @@ func testChanSyncOweRevocationAndCommitForceTransition(t *testing.T, // nonces. if chanType.IsTaproot() { aliceChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *aliceSyncMsg.LocalNonce, + PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t), } bobChannel.pendingVerificationNonce = &musig2.Nonces{ - PubNonce: *bobSyncMsg.LocalNonce, + PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t), } } diff --git a/lnwire/accept_channel.go b/lnwire/accept_channel.go index 66dda815c..e45bcf9ab 100644 --- a/lnwire/accept_channel.go +++ b/lnwire/accept_channel.go @@ -110,7 +110,7 @@ type AcceptChannel struct { // verify the very first commitment transaction signature. // This will only be populated if the simple taproot channels type was // negotiated. - LocalNonce *Musig2Nonce + LocalNonce OptMusig2NonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -141,9 +141,9 @@ func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error { if a.LeaseExpiry != nil { recordProducers = append(recordProducers, a.LeaseExpiry) } - if a.LocalNonce != nil { - recordProducers = append(recordProducers, a.LocalNonce) - } + a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { + recordProducers = append(recordProducers, &localNonce) + }) err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) if err != nil { return err @@ -248,7 +248,7 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { var ( chanType ChannelType leaseExpiry LeaseExpiry - localNonce Musig2Nonce + localNonce = a.LocalNonce.Zero() ) typeMap, err := tlvRecords.ExtractRecords( &a.UpfrontShutdownScript, &chanType, &leaseExpiry, @@ -265,8 +265,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { a.LeaseExpiry = &leaseExpiry } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - a.LocalNonce = &localNonce + if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil { + a.LocalNonce = tlv.SomeRecordT(localNonce) } a.ExtraData = tlvRecords diff --git a/lnwire/channel_ready.go b/lnwire/channel_ready.go index 07872f800..bdcb95ce8 100644 --- a/lnwire/channel_ready.go +++ b/lnwire/channel_ready.go @@ -31,7 +31,7 @@ type ChannelReady struct { // This will only be populated if the simple taproot channels type was // negotiated. This is the local nonce that will be used by the sender // to accept a new commitment state transition. - NextLocalNonce *Musig2Nonce + NextLocalNonce OptMusig2NonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -77,7 +77,7 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { // the AliasScidRecordType. var ( aliasScid ShortChannelID - localNonce Musig2Nonce + localNonce = c.NextLocalNonce.Zero() ) typeMap, err := tlvRecords.ExtractRecords( &aliasScid, &localNonce, @@ -91,8 +91,8 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error { if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { c.AliasScid = &aliasScid } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - c.NextLocalNonce = &localNonce + if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil { + c.NextLocalNonce = tlv.SomeRecordT(localNonce) } if len(tlvRecords) != 0 { @@ -121,9 +121,9 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error { if c.AliasScid != nil { recordProducers = append(recordProducers, c.AliasScid) } - if c.NextLocalNonce != nil { - recordProducers = append(recordProducers, c.NextLocalNonce) - } + c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { + recordProducers = append(recordProducers, &localNonce) + }) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/channel_reestablish.go b/lnwire/channel_reestablish.go index b4a5258c8..e52327949 100644 --- a/lnwire/channel_reestablish.go +++ b/lnwire/channel_reestablish.go @@ -82,8 +82,7 @@ type ChannelReestablish struct { // This will only be populated if the simple taproot channels type was // negotiated. // - // TODO(roasbeef): rename to verification nonce - LocalNonce *Musig2Nonce + LocalNonce OptMusig2NonceTLV // DynHeight is an optional field that stores the dynamic commitment // negotiation height that is incremented upon successful completion of @@ -138,9 +137,9 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error { } recordProducers := make([]tlv.RecordProducer, 0, 1) - if a.LocalNonce != nil { - recordProducers = append(recordProducers, a.LocalNonce) - } + a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { + recordProducers = append(recordProducers, &localNonce) + }) a.DynHeight.WhenSome(func(h DynHeight) { recordProducers = append(recordProducers, &h) }) @@ -203,8 +202,10 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { return err } - var localNonce Musig2Nonce - var dynHeight DynHeight + var ( + dynHeight DynHeight + localNonce = a.LocalNonce.Zero() + ) typeMap, err := tlvRecords.ExtractRecords( &localNonce, &dynHeight, ) @@ -212,8 +213,8 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error { return err } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - a.LocalNonce = &localNonce + if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil { + a.LocalNonce = tlv.SomeRecordT(localNonce) } if val, ok := typeMap[CRDynHeight]; ok && val == nil { a.DynHeight = fn.Some(dynHeight) diff --git a/lnwire/closing_complete.go b/lnwire/closing_complete.go index c3cd0cc4d..d33abf672 100644 --- a/lnwire/closing_complete.go +++ b/lnwire/closing_complete.go @@ -50,9 +50,9 @@ type ClosingComplete struct { // decodeClosingSigs decodes the closing sig TLV records in the passed // ExtraOpaqueData. func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error { - sig1 := tlv.ZeroRecordT[tlv.TlvType1, Sig]() - sig2 := tlv.ZeroRecordT[tlv.TlvType2, Sig]() - sig3 := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + sig1 := c.CloserNoClosee.Zero() + sig2 := c.NoCloserClosee.Zero() + sig3 := c.CloserAndClosee.Zero() typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3) if err != nil { diff --git a/lnwire/closing_signed.go b/lnwire/closing_signed.go index 3e3651964..08b5bb6a7 100644 --- a/lnwire/closing_signed.go +++ b/lnwire/closing_signed.go @@ -36,7 +36,7 @@ type ClosingSigned struct { // // NOTE: This field is only populated if a musig2 taproot channel is // being signed for. In this case, the above Sig type MUST be blank. - PartialSig *PartialSig + PartialSig OptPartialSigTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -76,17 +76,15 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { return err } - var ( - partialSig PartialSig - ) + partialSig := c.PartialSig.Zero() typeMap, err := tlvRecords.ExtractRecords(&partialSig) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[PartialSigRecordType]; ok && val == nil { - c.PartialSig = &partialSig + if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { + c.PartialSig = tlv.SomeRecordT(partialSig) } if len(tlvRecords) != 0 { @@ -102,9 +100,9 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if c.PartialSig != nil { - recordProducers = append(recordProducers, c.PartialSig) - } + c.PartialSig.WhenSome(func(sig PartialSigTLV) { + recordProducers = append(recordProducers, &sig) + }) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/commit_sig.go b/lnwire/commit_sig.go index d25d36a8a..7deb64ae1 100644 --- a/lnwire/commit_sig.go +++ b/lnwire/commit_sig.go @@ -43,7 +43,7 @@ type CommitSig struct { // // NOTE: This field is only populated if a musig2 taproot channel is // being signed for. In this case, the above Sig type MUST be blank. - PartialSig *PartialSigWithNonce + PartialSig OptPartialSigWithNonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -81,17 +81,15 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { return err } - var ( - partialSig PartialSigWithNonce - ) + partialSig := c.PartialSig.Zero() typeMap, err := tlvRecords.ExtractRecords(&partialSig) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { - c.PartialSig = &partialSig + if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil { + c.PartialSig = tlv.SomeRecordT(partialSig) } if len(tlvRecords) != 0 { @@ -107,9 +105,9 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if c.PartialSig != nil { - recordProducers = append(recordProducers, c.PartialSig) - } + c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { + recordProducers = append(recordProducers, &sig) + }) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/funding_created.go b/lnwire/funding_created.go index f8128ff76..86aa0bb40 100644 --- a/lnwire/funding_created.go +++ b/lnwire/funding_created.go @@ -32,7 +32,7 @@ type FundingCreated struct { // // NOTE: This field is only populated if a musig2 taproot channel is // being signed for. In this case, the above Sig type MUST be blank. - PartialSig *PartialSigWithNonce + PartialSig OptPartialSigWithNonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -51,9 +51,9 @@ var _ Message = (*FundingCreated)(nil) // This is part of the lnwire.Message interface. func (f *FundingCreated) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if f.PartialSig != nil { - recordProducers = append(recordProducers, f.PartialSig) - } + f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { + recordProducers = append(recordProducers, &sig) + }) err := EncodeMessageExtraData(&f.ExtraData, recordProducers...) if err != nil { return err @@ -92,17 +92,15 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error { return err } - var ( - partialSig PartialSigWithNonce - ) + partialSig := f.PartialSig.Zero() typeMap, err := tlvRecords.ExtractRecords(&partialSig) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { - f.PartialSig = &partialSig + if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil { + f.PartialSig = tlv.SomeRecordT(partialSig) } if len(tlvRecords) != 0 { diff --git a/lnwire/funding_signed.go b/lnwire/funding_signed.go index c7fb03d15..2dd62e177 100644 --- a/lnwire/funding_signed.go +++ b/lnwire/funding_signed.go @@ -24,7 +24,7 @@ type FundingSigned struct { // // NOTE: This field is only populated if a musig2 taproot channel is // being signed for. In this case, the above Sig type MUST be blank. - PartialSig *PartialSigWithNonce + PartialSig OptPartialSigWithNonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -43,9 +43,9 @@ var _ Message = (*FundingSigned)(nil) // This is part of the lnwire.Message interface. func (f *FundingSigned) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if f.PartialSig != nil { - recordProducers = append(recordProducers, f.PartialSig) - } + f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) { + recordProducers = append(recordProducers, &sig) + }) err := EncodeMessageExtraData(&f.ExtraData, recordProducers...) if err != nil { return err @@ -78,17 +78,15 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error { return err } - var ( - partialSig PartialSigWithNonce - ) + partialSig := f.PartialSig.Zero() typeMap, err := tlvRecords.ExtractRecords(&partialSig) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { - f.PartialSig = &partialSig + if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil { + f.PartialSig = tlv.SomeRecordT(partialSig) } if len(tlvRecords) != 0 { diff --git a/lnwire/lnwire_test.go b/lnwire/lnwire_test.go index 1246f92e7..3a1d02c18 100644 --- a/lnwire/lnwire_test.go +++ b/lnwire/lnwire_test.go @@ -44,11 +44,19 @@ var ( const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" -func randLocalNonce(r *rand.Rand) *Musig2Nonce { +func randLocalNonce(r *rand.Rand) Musig2Nonce { var nonce Musig2Nonce _, _ = io.ReadFull(r, nonce[:]) - return &nonce + return nonce +} + +func someLocalNonce[T tlv.TlvType]( + r *rand.Rand) tlv.OptionalRecordT[T, Musig2Nonce] { + + return tlv.SomeRecordT(tlv.NewRecordT[T, Musig2Nonce]( + randLocalNonce(r), + )) } func randPartialSig(r *rand.Rand) (*PartialSig, error) { @@ -65,6 +73,19 @@ func randPartialSig(r *rand.Rand) (*PartialSig, error) { }, nil } +func somePartialSig(t *testing.T, + r *rand.Rand) tlv.OptionalRecordT[PartialSigType, PartialSig] { + + sig, err := randPartialSig(r) + if err != nil { + t.Fatal(err) + } + + return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig]( + *sig, + )) +} + func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { var sigBytes [32]byte if _, err := r.Read(sigBytes[:]); err != nil { @@ -76,10 +97,25 @@ func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { return &PartialSigWithNonce{ PartialSig: NewPartialSig(s), - Nonce: *randLocalNonce(r), + Nonce: randLocalNonce(r), }, nil } +func somePartialSigWithNonce(t *testing.T, + r *rand.Rand) OptPartialSigWithNonceTLV { + + sig, err := randPartialSigWithNonce(r) + if err != nil { + t.Fatal(err) + } + + return tlv.SomeRecordT( + tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce]( + *sig, + ), + ) +} + func randAlias(r *rand.Rand) NodeAlias { var a NodeAlias for i := range a { @@ -480,7 +516,8 @@ func TestLightningWireProtocol(t *testing.T) { req.LeaseExpiry = new(LeaseExpiry) *req.LeaseExpiry = LeaseExpiry(1337) - req.LocalNonce = randLocalNonce(r) + //nolint:lll + req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) } else { req.UpfrontShutdownScript = []byte{} } @@ -554,7 +591,8 @@ func TestLightningWireProtocol(t *testing.T) { req.LeaseExpiry = new(LeaseExpiry) *req.LeaseExpiry = LeaseExpiry(1337) - req.LocalNonce = randLocalNonce(r) + //nolint:lll + req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) } else { req.UpfrontShutdownScript = []byte{} } @@ -591,12 +629,7 @@ func TestLightningWireProtocol(t *testing.T) { // 1/2 chance to attach a partial sig. if r.Intn(2) == 0 { - req.PartialSig, err = randPartialSigWithNonce(r) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } + req.PartialSig = somePartialSigWithNonce(t, r) } v[0] = reflect.ValueOf(req) @@ -621,12 +654,7 @@ func TestLightningWireProtocol(t *testing.T) { // 1/2 chance to attach a partial sig. if r.Intn(2) == 0 { - req.PartialSig, err = randPartialSigWithNonce(r) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } + req.PartialSig = somePartialSigWithNonce(t, r) } v[0] = reflect.ValueOf(req) @@ -649,7 +677,9 @@ func TestLightningWireProtocol(t *testing.T) { if r.Int31()%2 == 0 { scid := NewShortChanIDFromInt(uint64(r.Int63())) req.AliasScid = &scid - req.NextLocalNonce = randLocalNonce(r) + + //nolint:lll + req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r) } v[0] = reflect.ValueOf(*req) @@ -676,9 +706,8 @@ func TestLightningWireProtocol(t *testing.T) { } if r.Int31()%2 == 0 { - req.ShutdownNonce = (*ShutdownNonce)( - randLocalNonce(r), - ) + //nolint:lll + req.ShutdownNonce = someLocalNonce[ShutdownNonceType](r) } v[0] = reflect.ValueOf(req) @@ -701,12 +730,7 @@ func TestLightningWireProtocol(t *testing.T) { } if r.Int31()%2 == 0 { - req.PartialSig, err = randPartialSig(r) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } + req.PartialSig = somePartialSig(t, r) } v[0] = reflect.ValueOf(req) @@ -854,12 +878,7 @@ func TestLightningWireProtocol(t *testing.T) { // 50/50 chance to attach a partial sig. if r.Int31()%2 == 0 { - req.PartialSig, err = randPartialSigWithNonce(r) - if err != nil { - t.Fatalf("unable to generate sig: %v", - err) - return - } + req.PartialSig = somePartialSigWithNonce(t, r) } v[0] = reflect.ValueOf(*req) @@ -883,7 +902,8 @@ func TestLightningWireProtocol(t *testing.T) { // 50/50 chance to attach a local nonce. if r.Int31()%2 == 0 { - req.LocalNonce = randLocalNonce(r) + //nolint:lll + req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) } v[0] = reflect.ValueOf(*req) @@ -1107,7 +1127,8 @@ func TestLightningWireProtocol(t *testing.T) { return } - req.LocalNonce = randLocalNonce(r) + //nolint:lll + req.LocalNonce = someLocalNonce[NonceRecordTypeT](r) } v[0] = reflect.ValueOf(req) @@ -1232,7 +1253,7 @@ func TestLightningWireProtocol(t *testing.T) { } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() + sig := req.CloserNoClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", @@ -1243,7 +1264,7 @@ func TestLightningWireProtocol(t *testing.T) { req.CloserNoClosee = tlv.SomeRecordT(sig) } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() + sig := req.NoCloserClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", @@ -1254,7 +1275,7 @@ func TestLightningWireProtocol(t *testing.T) { req.NoCloserClosee = tlv.SomeRecordT(sig) } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + sig := req.CloserAndClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", @@ -1281,7 +1302,7 @@ func TestLightningWireProtocol(t *testing.T) { } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() + sig := req.CloserNoClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", @@ -1292,7 +1313,7 @@ func TestLightningWireProtocol(t *testing.T) { req.CloserNoClosee = tlv.SomeRecordT(sig) } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() + sig := req.NoCloserClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", @@ -1303,7 +1324,7 @@ func TestLightningWireProtocol(t *testing.T) { req.NoCloserClosee = tlv.SomeRecordT(sig) } if r.Intn(2) == 0 { - sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() + sig := req.CloserAndClosee.Zero() _, err := r.Read(sig.Val.bytes[:]) if err != nil { t.Fatalf("unable to generate sig: %v", diff --git a/lnwire/musig2.go b/lnwire/musig2.go index 6602ee694..cfc753f82 100644 --- a/lnwire/musig2.go +++ b/lnwire/musig2.go @@ -7,21 +7,33 @@ import ( "github.com/lightningnetwork/lnd/tlv" ) -const ( - // NonceRecordType is the TLV type used to encode a local musig2 nonce. - NonceRecordType tlv.Type = 4 -) +// NonceRecordTypeT is the TLV type used to encode a local musig2 nonce. +type NonceRecordTypeT = tlv.TlvType4 -// Musig2Nonce represents a musig2 public nonce, which is the concatenation of -// two EC points serialized in compressed format. -type Musig2Nonce [musig2.PubNonceSize]byte +// nonceRecordType is the TLV (integer) type used to encode a local musig2 +// nonce. +var nonceRecordType tlv.Type = (NonceRecordTypeT)(nil).TypeVal() + +type ( + // Musig2Nonce represents a musig2 public nonce, which is the + // concatenation of two EC points serialized in compressed format. + Musig2Nonce [musig2.PubNonceSize]byte + + // Musig2NonceTLV is a TLV type that can be used to encode/decode a + // musig2 nonce. This is an optional TLV. + Musig2NonceTLV = tlv.RecordT[NonceRecordTypeT, Musig2Nonce] + + // OptMusig2NonceTLV is a TLV type that can be used to encode/decode a + // musig2 nonce. + OptMusig2NonceTLV = tlv.OptionalRecordT[NonceRecordTypeT, Musig2Nonce] +) // Record returns a TLV record that can be used to encode/decode the musig2 // nonce from a given TLV stream. func (m *Musig2Nonce) Record() tlv.Record { return tlv.MakeStaticRecord( - NonceRecordType, m, musig2.PubNonceSize, nonceTypeEncoder, - nonceTypeDecoder, + nonceRecordType, m, musig2.PubNonceSize, + nonceTypeEncoder, nonceTypeDecoder, ) } @@ -48,3 +60,10 @@ func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte, val, "lnwire.Musig2Nonce", l, musig2.PubNonceSize, ) } + +// SomeMusig2Nonce is a helper function that creates a musig2 nonce TLV. +func SomeMusig2Nonce(nonce Musig2Nonce) OptMusig2NonceTLV { + return tlv.SomeRecordT( + tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce), + ) +} diff --git a/lnwire/open_channel.go b/lnwire/open_channel.go index 9cb4bc41a..9694290f7 100644 --- a/lnwire/open_channel.go +++ b/lnwire/open_channel.go @@ -146,7 +146,7 @@ type OpenChannel struct { // verify the very first commitment transaction signature. This will // only be populated if the simple taproot channels type was // negotiated. - LocalNonce *Musig2Nonce + LocalNonce OptMusig2NonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -175,9 +175,9 @@ func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error { if o.LeaseExpiry != nil { recordProducers = append(recordProducers, o.LeaseExpiry) } - if o.LocalNonce != nil { - recordProducers = append(recordProducers, o.LocalNonce) - } + o.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { + recordProducers = append(recordProducers, &localNonce) + }) err := EncodeMessageExtraData(&o.ExtraData, recordProducers...) if err != nil { return err @@ -302,7 +302,7 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { var ( chanType ChannelType leaseExpiry LeaseExpiry - localNonce Musig2Nonce + localNonce = o.LocalNonce.Zero() ) typeMap, err := tlvRecords.ExtractRecords( &o.UpfrontShutdownScript, &chanType, &leaseExpiry, @@ -319,8 +319,8 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { o.LeaseExpiry = &leaseExpiry } - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - o.LocalNonce = &localNonce + if val, ok := typeMap[o.LocalNonce.TlvType()]; ok && val == nil { + o.LocalNonce = tlv.SomeRecordT(localNonce) } o.ExtraData = tlvRecords diff --git a/lnwire/partial_sig.go b/lnwire/partial_sig.go index e460270f1..1751ae5ce 100644 --- a/lnwire/partial_sig.go +++ b/lnwire/partial_sig.go @@ -11,11 +11,20 @@ import ( const ( // PartialSigLen is the length of a musig2 partial signature. PartialSigLen = 32 +) - // PartialSigRecordType is the type of the tlv record for a musig2 +type ( + // PartialSigType is the type of the tlv record for a musig2 // partial signature. This is an _even_ type, which means it's required // if included. - PartialSigRecordType tlv.Type = 6 + PartialSigType = tlv.TlvType6 + + // PartialSigTLV is a tlv record for a musig2 partial signature. + PartialSigTLV = tlv.RecordT[PartialSigType, PartialSig] + + // OptPartialSigTLV is a tlv record for a musig2 partial signature. + // This is an optional record type. + OptPartialSigTLV = tlv.OptionalRecordT[PartialSigType, PartialSig] ) // PartialSig is the base partial sig type. This only encodes the 32-byte @@ -36,7 +45,7 @@ func NewPartialSig(sig btcec.ModNScalar) PartialSig { // Record returns the tlv record for the partial sig. func (p *PartialSig) Record() tlv.Record { return tlv.MakeStaticRecord( - PartialSigRecordType, p, PartialSigLen, + (PartialSigType)(nil).TypeVal(), p, PartialSigLen, partialSigTypeEncoder, partialSigTypeDecoder, ) } @@ -88,16 +97,35 @@ func (p *PartialSig) Decode(r io.Reader) error { return partialSigTypeDecoder(r, p, nil, PartialSigLen) } +// SomePartialSig is a helper function that returns an otional PartialSig. +func SomePartialSig(sig PartialSig) OptPartialSigTLV { + return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig](sig)) +} + const ( // PartialSigWithNonceLen is the length of a serialized // PartialSigWithNonce. The sig is encoded as the 32 byte S value // followed by the 66 nonce value. PartialSigWithNonceLen = 98 +) - // PartialSigWithNonceRecordType is the type of the tlv record for a - // musig2 partial signature with nonce. This is an _even_ type, which - // means it's required if included. - PartialSigWithNonceRecordType tlv.Type = 2 +type ( + // PartialSigWithNonceType is the type of the tlv record for a musig2 + // partial signature with nonce. This is an _even_ type, which means + // it's required if included. + PartialSigWithNonceType = tlv.TlvType2 + + // PartialSigWithNonceTLV is a tlv record for a musig2 partial + // signature. + PartialSigWithNonceTLV = tlv.RecordT[ + PartialSigWithNonceType, PartialSigWithNonce, + ] + + // OptPartialSigWithNonceTLV is a tlv record for a musig2 partial + // signature. This is an optional record type. + OptPartialSigWithNonceTLV = tlv.OptionalRecordT[ + PartialSigWithNonceType, PartialSigWithNonce, + ] ) // PartialSigWithNonce is a partial signature with the nonce that was used to @@ -129,8 +157,9 @@ func NewPartialSigWithNonce(nonce [musig2.PubNonceSize]byte, // Record returns the tlv record for the partial sig with nonce. func (p *PartialSigWithNonce) Record() tlv.Record { return tlv.MakeStaticRecord( - PartialSigWithNonceRecordType, p, PartialSigWithNonceLen, - partialSigWithNonceTypeEncoder, partialSigWithNonceTypeDecoder, + (PartialSigWithNonceType)(nil).TypeVal(), p, + PartialSigWithNonceLen, partialSigWithNonceTypeEncoder, + partialSigWithNonceTypeDecoder, ) } @@ -199,3 +228,20 @@ func (p *PartialSigWithNonce) Decode(r io.Reader) error { r, p, nil, PartialSigWithNonceLen, ) } + +// MaybePartialSigWithNonce is a helper function that returns an optional +// PartialSigWithNonceTLV. +func MaybePartialSigWithNonce(sig *PartialSigWithNonce, +) OptPartialSigWithNonceTLV { + + if sig == nil { + var none OptPartialSigWithNonceTLV + return none + } + + return tlv.SomeRecordT( + tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce]( + *sig, + ), + ) +} diff --git a/lnwire/revoke_and_ack.go b/lnwire/revoke_and_ack.go index 6b6b80167..9dca1631a 100644 --- a/lnwire/revoke_and_ack.go +++ b/lnwire/revoke_and_ack.go @@ -36,7 +36,7 @@ type RevokeAndAck struct { // LocalNonce is the next _local_ nonce for the sending party. This // allows the receiving party to propose a new commitment using their // remote nonce and the sender's local nonce. - LocalNonce *Musig2Nonce + LocalNonce OptMusig2NonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -74,15 +74,15 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { return err } - var musigNonce Musig2Nonce - typeMap, err := tlvRecords.ExtractRecords(&musigNonce) + localNonce := c.LocalNonce.Zero() + typeMap, err := tlvRecords.ExtractRecords(&localNonce) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[NonceRecordType]; ok && val == nil { - c.LocalNonce = &musigNonce + if val, ok := typeMap[c.LocalNonce.TlvType()]; ok && val == nil { + c.LocalNonce = tlv.SomeRecordT(localNonce) } if len(tlvRecords) != 0 { @@ -98,9 +98,9 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if c.LocalNonce != nil { - recordProducers = append(recordProducers, c.LocalNonce) - } + c.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) { + recordProducers = append(recordProducers, &localNonce) + }) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) if err != nil { return err diff --git a/lnwire/shutdown.go b/lnwire/shutdown.go index 5b59b47ab..c5455651b 100644 --- a/lnwire/shutdown.go +++ b/lnwire/shutdown.go @@ -4,52 +4,21 @@ import ( "bytes" "io" - "github.com/btcsuite/btcd/btcec/v2/schnorr/musig2" "github.com/lightningnetwork/lnd/tlv" ) -const ( - // ShutdownNonceRecordType is the type of the shutdown nonce TLV record. - ShutdownNonceRecordType = 8 +type ( + // ShutdownNonceType is the type of the shutdown nonce TLV record. + ShutdownNonceType = tlv.TlvType8 + + // ShutdownNonceTLV is the TLV record that contains the shutdown nonce. + ShutdownNonceTLV = tlv.OptionalRecordT[ShutdownNonceType, Musig2Nonce] ) -// ShutdownNonce is the type of the nonce we send during the shutdown flow. -// Unlike the other nonces, this nonce is symmetric w.r.t the message being -// signed (there's only one message for shutdown: the co-op close txn). -type ShutdownNonce Musig2Nonce - -// Record returns a TLV record that can be used to encode/decode the musig2 -// nonce from a given TLV stream. -func (s *ShutdownNonce) Record() tlv.Record { - return tlv.MakeStaticRecord( - ShutdownNonceRecordType, s, musig2.PubNonceSize, - shutdownNonceTypeEncoder, shutdownNonceTypeDecoder, - ) -} - -// shutdownNonceTypeEncoder is a custom TLV encoder for the Musig2Nonce type. -func shutdownNonceTypeEncoder(w io.Writer, val interface{}, - _ *[8]byte) error { - - if v, ok := val.(*ShutdownNonce); ok { - _, err := w.Write(v[:]) - return err - } - - return tlv.NewTypeForEncodingErr(val, "lnwire.Musig2Nonce") -} - -// shutdownNonceTypeDecoder is a custom TLV decoder for the Musig2Nonce record. -func shutdownNonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte, - l uint64) error { - - if v, ok := val.(*ShutdownNonce); ok { - _, err := io.ReadFull(r, v[:]) - return err - } - - return tlv.NewTypeForDecodingErr( - val, "lnwire.ShutdownNonce", l, musig2.PubNonceSize, +// SomeShutdownNonce returns a ShutdownNonceTLV with the given nonce. +func SomeShutdownNonce(nonce Musig2Nonce) ShutdownNonceTLV { + return tlv.SomeRecordT( + tlv.NewRecordT[ShutdownNonceType, Musig2Nonce](nonce), ) } @@ -67,7 +36,7 @@ type Shutdown struct { // ShutdownNonce is the nonce the sender will use to sign the first // co-op sign offer. - ShutdownNonce *ShutdownNonce + ShutdownNonce ShutdownNonceTLV // ExtraData is the set of data that was appended to this message to // fill out the full maximum transport message size. These fields can @@ -102,15 +71,15 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { return err } - var musigNonce ShutdownNonce + musigNonce := s.ShutdownNonce.Zero() typeMap, err := tlvRecords.ExtractRecords(&musigNonce) if err != nil { return err } // Set the corresponding TLV types if they were included in the stream. - if val, ok := typeMap[ShutdownNonceRecordType]; ok && val == nil { - s.ShutdownNonce = &musigNonce + if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil { + s.ShutdownNonce = tlv.SomeRecordT(musigNonce) } if len(tlvRecords) != 0 { @@ -126,9 +95,11 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error { // This is part of the lnwire.Message interface. func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { recordProducers := make([]tlv.RecordProducer, 0, 1) - if s.ShutdownNonce != nil { - recordProducers = append(recordProducers, s.ShutdownNonce) - } + s.ShutdownNonce.WhenSome( + func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) { + recordProducers = append(recordProducers, &nonce) + }, + ) err := EncodeMessageExtraData(&s.ExtraData, recordProducers...) if err != nil { return err