multi: upgrade new taproot TLVs to use tlv.OptionalRecordT

In this commit, we update new Taproot related TLVs (nonces, partial sig,
sig with nonce, etc). Along the way we were able to get rid of some
boiler plate, but most importantly, we're able to better protect against
API misuse (using a nonce that isn't initialized, etc) with the new
options API. In some areas this introduces a bit of extra boiler plate,
and where applicable I used some new helper functions to help cut down
on the noise.

Note to reviewers: this is done as a single commit, as changing the API
breaks all callers, so if we want things to compile it needs to be in a
wumbo commit.
This commit is contained in:
Olaoluwa Osuntokun 2024-02-23 18:04:51 -08:00
parent 6bd556d38c
commit 7feb8b21e1
No known key found for this signature in database
GPG Key ID: 3BBD59E99B280306
20 changed files with 401 additions and 288 deletions

View File

@ -1551,7 +1551,7 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) {
// If this is a taproot channel, then we'll need to generate our next // If this is a taproot channel, then we'll need to generate our next
// verification nonce to send to the remote party. They'll use this to // verification nonce to send to the remote party. They'll use this to
// sign the next update to our commitment transaction. // sign the next update to our commitment transaction.
var nextTaprootNonce *lnwire.Musig2Nonce var nextTaprootNonce lnwire.OptMusig2NonceTLV
if c.ChanType.IsTaproot() { if c.ChanType.IsTaproot() {
taprootRevProducer, err := DeriveMusig2Shachain( taprootRevProducer, err := DeriveMusig2Shachain(
c.RevocationProducer, c.RevocationProducer,
@ -1569,7 +1569,7 @@ func (c *OpenChannel) ChanSyncMsg() (*lnwire.ChannelReestablish, error) {
"nonce: %w", err) "nonce: %w", err)
} }
nextTaprootNonce = (*lnwire.Musig2Nonce)(&nextNonce.PubNonce) nextTaprootNonce = lnwire.SomeMusig2Nonce(nextNonce.PubNonce)
} }
return &lnwire.ChannelReestablish{ return &lnwire.ChannelReestablish{

View File

@ -48,6 +48,14 @@ var (
// //
// NOTE: for itest, this value is changed to 10ms. // NOTE: for itest, this value is changed to 10ms.
checkPeerChannelReadyInterval = 1 * time.Second checkPeerChannelReadyInterval = 1 * time.Second
// errNoLocalNonce is returned when a local nonce is not found in the
// expected TLV.
errNoLocalNonce = fmt.Errorf("local nonce not found")
// errNoPartialSig is returned when a partial sig is not found in the
// expected TLV.
errNoPartialSig = fmt.Errorf("partial sig not found")
) )
// WriteOutpoint writes an outpoint to an io.Writer. This is not the same as // WriteOutpoint writes an outpoint to an io.Writer. This is not the same as
@ -1801,17 +1809,17 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
} }
if resCtx.reservation.IsTaproot() { if resCtx.reservation.IsTaproot() {
if msg.LocalNonce == nil { localNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoLocalNonce)
err := fmt.Errorf("local nonce not set for taproot " + if err != nil {
"chan") log.Error(errNoLocalNonce)
log.Error(err)
f.failFundingFlow( f.failFundingFlow(resCtx.peer, cid, errNoLocalNonce)
resCtx.peer, cid, err,
) return
} }
remoteContribution.LocalNonce = &musig2.Nonces{ remoteContribution.LocalNonce = &musig2.Nonces{
PubNonce: *msg.LocalNonce, PubNonce: localNonce,
} }
} }
@ -1827,13 +1835,6 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
log.Debugf("Remote party accepted commitment constraints: %v", log.Debugf("Remote party accepted commitment constraints: %v",
spew.Sdump(remoteContribution.ChannelConfig.ChannelConstraints)) spew.Sdump(remoteContribution.ChannelConfig.ChannelConstraints))
var localNonce *lnwire.Musig2Nonce
if commitType.IsTaproot() {
localNonce = (*lnwire.Musig2Nonce)(
&ourContribution.LocalNonce.PubNonce,
)
}
// With the initiator's contribution recorded, respond with our // With the initiator's contribution recorded, respond with our
// contribution in the next message of the workflow. // contribution in the next message of the workflow.
fundingAccept := lnwire.AcceptChannel{ fundingAccept := lnwire.AcceptChannel{
@ -1854,7 +1855,12 @@ func (f *Manager) fundeeProcessOpenChannel(peer lnpeer.Peer,
UpfrontShutdownScript: ourContribution.UpfrontShutdown, UpfrontShutdownScript: ourContribution.UpfrontShutdown,
ChannelType: chanType, ChannelType: chanType,
LeaseExpiry: msg.LeaseExpiry, LeaseExpiry: msg.LeaseExpiry,
LocalNonce: localNonce, }
if commitType.IsTaproot() {
fundingAccept.LocalNonce = lnwire.SomeMusig2Nonce(
ourContribution.LocalNonce.PubNonce,
)
} }
if err := peer.SendMessage(true, &fundingAccept); err != nil { if err := peer.SendMessage(true, &fundingAccept); err != nil {
@ -2044,15 +2050,17 @@ func (f *Manager) funderProcessAcceptChannel(peer lnpeer.Peer,
} }
if resCtx.reservation.IsTaproot() { if resCtx.reservation.IsTaproot() {
if msg.LocalNonce == nil { localNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoLocalNonce)
err := fmt.Errorf("local nonce not set for taproot " + if err != nil {
"chan") log.Error(errNoLocalNonce)
log.Error(err)
f.failFundingFlow(resCtx.peer, cid, err) f.failFundingFlow(resCtx.peer, cid, errNoLocalNonce)
return
} }
remoteContribution.LocalNonce = &musig2.Nonces{ remoteContribution.LocalNonce = &musig2.Nonces{
PubNonce: *msg.LocalNonce, PubNonce: localNonce,
} }
} }
@ -2263,7 +2271,9 @@ func (f *Manager) continueFundingAccept(resCtx *reservationWithCtx,
return return
} }
fundingCreated.PartialSig = partialSig.ToWireSig() fundingCreated.PartialSig = lnwire.MaybePartialSigWithNonce(
partialSig.ToWireSig(),
)
} else { } else {
fundingCreated.CommitSig, err = lnwire.NewSigFromSignature(sig) fundingCreated.CommitSig, err = lnwire.NewSigFromSignature(sig)
if err != nil { if err != nil {
@ -2317,14 +2327,15 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer,
// our internal input.Signature type. // our internal input.Signature type.
var commitSig input.Signature var commitSig input.Signature
if resCtx.reservation.IsTaproot() { if resCtx.reservation.IsTaproot() {
if msg.PartialSig == nil { partialSig, err := msg.PartialSig.UnwrapOrErrV(errNoPartialSig)
log.Errorf("partial sig not included: %v", err) if err != nil {
f.failFundingFlow(peer, cid, err) f.failFundingFlow(peer, cid, err)
return return
} }
commitSig = new(lnwallet.MusigPartialSig).FromWireSig( commitSig = new(lnwallet.MusigPartialSig).FromWireSig(
msg.PartialSig, &partialSig,
) )
} else { } else {
commitSig, err = msg.CommitSig.ToSignature() commitSig, err = msg.CommitSig.ToSignature()
@ -2408,7 +2419,9 @@ func (f *Manager) fundeeProcessFundingCreated(peer lnpeer.Peer,
return return
} }
fundingSigned.PartialSig = partialSig.ToWireSig() fundingSigned.PartialSig = lnwire.MaybePartialSigWithNonce(
partialSig.ToWireSig(),
)
} else { } else {
fundingSigned.CommitSig, err = lnwire.NewSigFromSignature(sig) fundingSigned.CommitSig, err = lnwire.NewSigFromSignature(sig)
if err != nil { if err != nil {
@ -2565,14 +2578,15 @@ func (f *Manager) funderProcessFundingSigned(peer lnpeer.Peer,
// our internal input.Signature type. // our internal input.Signature type.
var commitSig input.Signature var commitSig input.Signature
if resCtx.reservation.IsTaproot() { if resCtx.reservation.IsTaproot() {
if msg.PartialSig == nil { partialSig, err := msg.PartialSig.UnwrapOrErrV(errNoPartialSig)
log.Errorf("partial sig not included: %v", err) if err != nil {
f.failFundingFlow(peer, cid, err) f.failFundingFlow(peer, cid, err)
return return
} }
commitSig = new(lnwallet.MusigPartialSig).FromWireSig( commitSig = new(lnwallet.MusigPartialSig).FromWireSig(
msg.PartialSig, &partialSig,
) )
} else { } else {
commitSig, err = msg.CommitSig.ToSignature() commitSig, err = msg.CommitSig.ToSignature()
@ -3153,8 +3167,8 @@ func (f *Manager) sendChannelReady(completeChan *channeldb.OpenChannel,
} }
f.nonceMtx.Unlock() f.nonceMtx.Unlock()
channelReadyMsg.NextLocalNonce = (*lnwire.Musig2Nonce)( channelReadyMsg.NextLocalNonce = lnwire.SomeMusig2Nonce(
&localNonce.PubNonce, localNonce.PubNonce,
) )
} }
@ -3824,11 +3838,9 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
channelReadyMsg.AliasScid = &alias channelReadyMsg.AliasScid = &alias
if firstVerNonce != nil { if firstVerNonce != nil {
wireNonce := (*lnwire.Musig2Nonce)( channelReadyMsg.NextLocalNonce = lnwire.SomeMusig2Nonce( //nolint:lll
&firstVerNonce.PubNonce, firstVerNonce.PubNonce,
) )
channelReadyMsg.NextLocalNonce = wireNonce
} }
err = peer.SendMessage(true, channelReadyMsg) err = peer.SendMessage(true, channelReadyMsg)
@ -3873,8 +3885,13 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
log.Infof("ChanID(%v): applying local+remote musig2 nonces", log.Infof("ChanID(%v): applying local+remote musig2 nonces",
chanID) chanID)
if msg.NextLocalNonce == nil { remoteNonce, err := msg.NextLocalNonce.UnwrapOrErrV(
log.Errorf("remote nonces are nil") errNoLocalNonce,
)
if err != nil {
cid := newChanIdentifier(msg.ChanID)
f.failFundingFlow(peer, cid, err)
return return
} }
@ -3882,7 +3899,7 @@ func (f *Manager) handleChannelReady(peer lnpeer.Peer, //nolint:funlen
chanOpts, chanOpts,
lnwallet.WithLocalMusigNonces(localNonce), lnwallet.WithLocalMusigNonces(localNonce),
lnwallet.WithRemoteMusigNonces(&musig2.Nonces{ lnwallet.WithRemoteMusigNonces(&musig2.Nonces{
PubNonce: *msg.NextLocalNonce, PubNonce: remoteNonce,
}), }),
) )
} }
@ -4714,13 +4731,6 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
log.Infof("Starting funding workflow with %v for pending_id(%x), "+ log.Infof("Starting funding workflow with %v for pending_id(%x), "+
"committype=%v", msg.Peer.Address(), chanID, commitType) "committype=%v", msg.Peer.Address(), chanID, commitType)
var localNonce *lnwire.Musig2Nonce
if commitType.IsTaproot() {
localNonce = (*lnwire.Musig2Nonce)(
&ourContribution.LocalNonce.PubNonce,
)
}
fundingOpen := lnwire.OpenChannel{ fundingOpen := lnwire.OpenChannel{
ChainHash: *f.cfg.Wallet.Cfg.NetParams.GenesisHash, ChainHash: *f.cfg.Wallet.Cfg.NetParams.GenesisHash,
PendingChannelID: chanID, PendingChannelID: chanID,
@ -4743,8 +4753,14 @@ func (f *Manager) handleInitFundingMsg(msg *InitFundingMsg) {
UpfrontShutdownScript: shutdown, UpfrontShutdownScript: shutdown,
ChannelType: chanType, ChannelType: chanType,
LeaseExpiry: leaseExpiry, LeaseExpiry: leaseExpiry,
LocalNonce: localNonce,
} }
if commitType.IsTaproot() {
fundingOpen.LocalNonce = lnwire.SomeMusig2Nonce(
ourContribution.LocalNonce.PubNonce,
)
}
if err := msg.Peer.SendMessage(true, &fundingOpen); err != nil { if err := msg.Peer.SendMessage(true, &fundingOpen); err != nil {
e := fmt.Errorf("unable to send funding request message: %v", e := fmt.Errorf("unable to send funding request message: %v",
err) err)

View File

@ -49,6 +49,10 @@ var (
// ErrInvalidShutdownScript is returned when we receive an address from // ErrInvalidShutdownScript is returned when we receive an address from
// a peer that isn't either a p2wsh or p2tr address. // a peer that isn't either a p2wsh or p2tr address.
ErrInvalidShutdownScript = fmt.Errorf("invalid shutdown script") ErrInvalidShutdownScript = fmt.Errorf("invalid shutdown script")
// errNoShutdownNonce is returned when a shutdown message is received
// w/o a nonce for a taproot channel.
errNoShutdownNonce = fmt.Errorf("shutdown nonce not populated")
) )
// closeState represents all the possible states the channel closer state // closeState represents all the possible states the channel closer state
@ -337,8 +341,8 @@ func (c *ChanCloser) initChanShutdown() (*lnwire.Shutdown, error) {
return nil, err return nil, err
} }
shutdown.ShutdownNonce = (*lnwire.ShutdownNonce)( shutdown.ShutdownNonce = lnwire.SomeShutdownNonce(
&firstClosingNonce.PubNonce, firstClosingNonce.PubNonce,
) )
chancloserLog.Infof("Initiating shutdown w/ nonce: %v", chancloserLog.Infof("Initiating shutdown w/ nonce: %v",
@ -548,13 +552,15 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) (
// remote nonces so we can properly create a new musig // remote nonces so we can properly create a new musig
// session for signing. // session for signing.
if c.cfg.Channel.ChanType().IsTaproot() { if c.cfg.Channel.ChanType().IsTaproot() {
if msg.ShutdownNonce == nil { shutdownNonce, err := msg.ShutdownNonce.UnwrapOrErrV(
return noShutdown, fmt.Errorf("shutdown " + errNoShutdownNonce,
"nonce not populated") )
if err != nil {
return noShutdown, err
} }
c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{ c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{
PubNonce: *msg.ShutdownNonce, PubNonce: shutdownNonce,
}) })
} }
@ -594,13 +600,15 @@ func (c *ChanCloser) ReceiveShutdown(msg lnwire.Shutdown) (
// local+remote nonces so we can properly create a new musig // local+remote nonces so we can properly create a new musig
// session for signing. // session for signing.
if c.cfg.Channel.ChanType().IsTaproot() { if c.cfg.Channel.ChanType().IsTaproot() {
if msg.ShutdownNonce == nil { shutdownNonce, err := msg.ShutdownNonce.UnwrapOrErrV(
return noShutdown, fmt.Errorf("shutdown " + errNoShutdownNonce,
"nonce not populated") )
if err != nil {
return noShutdown, err
} }
c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{ c.cfg.MusigSession.InitRemoteNonce(&musig2.Nonces{
PubNonce: *msg.ShutdownNonce, PubNonce: shutdownNonce,
}) })
} }
@ -683,10 +691,10 @@ func (c *ChanCloser) BeginNegotiation() (fn.Option[lnwire.ClosingSigned],
} }
// ReceiveClosingSigned is a method that should be called whenever we receive a // ReceiveClosingSigned is a method that should be called whenever we receive a
// ClosingSigned message from the wire. It may or may not return a ClosingSigned // ClosingSigned message from the wire. It may or may not return a
// of our own to send back to the remote. // ClosingSigned of our own to send back to the remote.
func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) ( func (c *ChanCloser) ReceiveClosingSigned( //nolint:funlen
fn.Option[lnwire.ClosingSigned], error) { msg lnwire.ClosingSigned) (fn.Option[lnwire.ClosingSigned], error) {
noClosing := fn.None[lnwire.ClosingSigned]() noClosing := fn.None[lnwire.ClosingSigned]()
@ -702,7 +710,7 @@ func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) (
// If this is a taproot channel, then it MUST have a partial // If this is a taproot channel, then it MUST have a partial
// signature set at this point. // signature set at this point.
isTaproot := c.cfg.Channel.ChanType().IsTaproot() isTaproot := c.cfg.Channel.ChanType().IsTaproot()
if isTaproot && msg.PartialSig == nil { if isTaproot && msg.PartialSig.IsNone() {
return noClosing, return noClosing,
fmt.Errorf("partial sig not set " + fmt.Errorf("partial sig not set " +
"for taproot chan") "for taproot chan")
@ -807,12 +815,23 @@ func (c *ChanCloser) ReceiveClosingSigned(msg lnwire.ClosingSigned) (
) )
matchingSig := c.priorFeeOffers[remoteProposedFee] matchingSig := c.priorFeeOffers[remoteProposedFee]
if c.cfg.Channel.ChanType().IsTaproot() { if c.cfg.Channel.ChanType().IsTaproot() {
localWireSig, err := matchingSig.PartialSig.UnwrapOrErrV( //nolint:lll
fmt.Errorf("none local sig"),
)
if err != nil {
return noClosing, err
}
remoteWireSig, err := msg.PartialSig.UnwrapOrErrV(
fmt.Errorf("none remote sig"),
)
if err != nil {
return noClosing, err
}
muSession := c.cfg.MusigSession muSession := c.cfg.MusigSession
localSig, remoteSig, closeOpts, err = localSig, remoteSig, closeOpts, err = muSession.CombineClosingOpts( //nolint:lll
muSession.CombineClosingOpts( localWireSig, remoteWireSig,
*matchingSig.PartialSig, )
*msg.PartialSig,
)
if err != nil { if err != nil {
return noClosing, err return noClosing, err
} }
@ -952,7 +971,9 @@ func (c *ChanCloser) proposeCloseSigned(fee btcutil.Amount) (
// over a partial signature which'll be combined once our offer is // over a partial signature which'll be combined once our offer is
// accepted. // accepted.
if partialSig != nil { if partialSig != nil {
closeSignedMsg.PartialSig = &partialSig.PartialSig closeSignedMsg.PartialSig = lnwire.SomePartialSig(
partialSig.PartialSig,
)
} }
// We'll also save this close signed, in the case that the remote party // We'll also save this close signed, in the case that the remote party

View File

@ -541,11 +541,9 @@ func TestTaprootFastClose(t *testing.T) {
require.True(t, oShutdown.IsSome()) require.True(t, oShutdown.IsSome())
require.True(t, oClosingSigned.IsNone()) require.True(t, oClosingSigned.IsNone())
bobShutdown := oShutdown.UnsafeFromSome()
// Alice should process the shutdown message, and create a closing // Alice should process the shutdown message, and create a closing
// signed of her own. // signed of her own.
oShutdown, err = aliceCloser.ReceiveShutdown(bobShutdown) oShutdown, err = aliceCloser.ReceiveShutdown(oShutdown.UnwrapOrFail(t))
require.NoError(t, err) require.NoError(t, err)
oClosingSigned, err = aliceCloser.BeginNegotiation() oClosingSigned, err = aliceCloser.BeginNegotiation()
require.NoError(t, err) require.NoError(t, err)
@ -554,7 +552,7 @@ func TestTaprootFastClose(t *testing.T) {
require.True(t, oShutdown.IsNone()) require.True(t, oShutdown.IsNone())
require.True(t, oClosingSigned.IsSome()) require.True(t, oClosingSigned.IsSome())
aliceClosingSigned := oClosingSigned.UnsafeFromSome() aliceClosingSigned := oClosingSigned.UnwrapOrFail(t)
// Next, Bob will process the closing signed message, and send back a // Next, Bob will process the closing signed message, and send back a
// new one that should match exactly the offer Alice sent. // new one that should match exactly the offer Alice sent.
@ -564,7 +562,7 @@ func TestTaprootFastClose(t *testing.T) {
require.NotNil(t, tx) require.NotNil(t, tx)
require.True(t, oClosingSigned.IsSome()) require.True(t, oClosingSigned.IsSome())
bobClosingSigned := oClosingSigned.UnsafeFromSome() bobClosingSigned := oClosingSigned.UnwrapOrFail(t)
// At this point, Bob has accepted the offer, so he can broadcast the // At this point, Bob has accepted the offer, so he can broadcast the
// closing transaction, and considers the channel closed. // closing transaction, and considers the channel closed.
@ -597,7 +595,7 @@ func TestTaprootFastClose(t *testing.T) {
require.NotNil(t, tx) require.NotNil(t, tx)
require.True(t, oClosingSigned.IsSome()) require.True(t, oClosingSigned.IsSome())
aliceClosingSigned = oClosingSigned.UnsafeFromSome() aliceClosingSigned = oClosingSigned.UnwrapOrFail(t)
// Alice should now also broadcast her closing transaction. // Alice should now also broadcast her closing transaction.
_, err = lnutils.RecvOrTimeout(broadcastSignal, time.Second*1) _, err = lnutils.RecvOrTimeout(broadcastSignal, time.Second*1)

View File

@ -127,6 +127,13 @@ var (
// we've likely lost data ourselves. // we've likely lost data ourselves.
ErrForceCloseLocalDataLoss = errors.New("cannot force close " + ErrForceCloseLocalDataLoss = errors.New("cannot force close " +
"channel with local data loss") "channel with local data loss")
// errNoNonce is returned when a nonce is required, but none is found.
errNoNonce = errors.New("no nonce found")
// errNoPartialSig is returned when a partial signature is required,
// but none is found.
errNoPartialSig = errors.New("no partial signature found")
) )
// ErrCommitSyncLocalDataLoss is returned in the case that we receive a valid // ErrCommitSyncLocalDataLoss is returned in the case that we receive a valid
@ -4128,7 +4135,7 @@ type CommitSigs struct {
// PartialSig is the musig2 partial signature for taproot commitment // PartialSig is the musig2 partial signature for taproot commitment
// transactions. // transactions.
PartialSig *lnwire.PartialSigWithNonce PartialSig lnwire.OptPartialSigWithNonceTLV
} }
// NewCommitState wraps the various signatures needed to properly // NewCommitState wraps the various signatures needed to properly
@ -4341,7 +4348,7 @@ func (lc *LightningChannel) SignNextCommitment() (*NewCommitState, error) {
CommitSigs: &CommitSigs{ CommitSigs: &CommitSigs{
CommitSig: sig, CommitSig: sig,
HtlcSigs: htlcSigs, HtlcSigs: htlcSigs,
PartialSig: partialSig, PartialSig: lnwire.MaybePartialSigWithNonce(partialSig),
}, },
PendingHTLCs: commitDiff.Commitment.Htlcs, PendingHTLCs: commitDiff.Commitment.Htlcs,
}, nil }, nil
@ -4416,19 +4423,24 @@ func (lc *LightningChannel) ProcessChanSyncMsg(
// bail out, otherwise we'll init our local session then continue as // bail out, otherwise we'll init our local session then continue as
// normal. // normal.
switch { switch {
case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce == nil: case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsNone():
return nil, nil, nil, fmt.Errorf("remote verification nonce " + return nil, nil, nil, fmt.Errorf("remote verification nonce " +
"not sent") "not sent")
case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce != nil: case lc.channelState.ChanType.IsTaproot() && msg.LocalNonce.IsSome():
if lc.opts.skipNonceInit { if lc.opts.skipNonceInit {
// Don't call InitRemoteMusigNonces if we have already // Don't call InitRemoteMusigNonces if we have already
// done so. // done so.
break break
} }
err := lc.InitRemoteMusigNonces(&musig2.Nonces{ nextNonce, err := msg.LocalNonce.UnwrapOrErrV(errNoNonce)
PubNonce: *msg.LocalNonce, if err != nil {
return nil, nil, nil, err
}
err = lc.InitRemoteMusigNonces(&musig2.Nonces{
PubNonce: nextNonce,
}) })
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("unable to init "+ return nil, nil, nil, fmt.Errorf("unable to init "+
@ -5227,6 +5239,13 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error {
if lc.channelState.ChanType.IsTaproot() { if lc.channelState.ChanType.IsTaproot() {
localSession := lc.musigSessions.LocalSession localSession := lc.musigSessions.LocalSession
partialSig, err := commitSigs.PartialSig.UnwrapOrErrV(
errNoPartialSig,
)
if err != nil {
return err
}
// As we want to ensure we never write nonces to disk, we'll // As we want to ensure we never write nonces to disk, we'll
// use the shachain state to generate a nonce for our next // use the shachain state to generate a nonce for our next
// local state. Similar to generateRevocation, we do height + 2 // local state. Similar to generateRevocation, we do height + 2
@ -5236,7 +5255,7 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error {
nextHeight+1, lc.taprootNonceProducer, nextHeight+1, lc.taprootNonceProducer,
) )
nextVerificationNonce, err := localSession.VerifyCommitSig( nextVerificationNonce, err := localSession.VerifyCommitSig(
localCommitTx, commitSigs.PartialSig, localCtrNonce, localCommitTx, &partialSig, localCtrNonce,
) )
if err != nil { if err != nil {
close(cancelChan) close(cancelChan)
@ -5352,9 +5371,16 @@ func (lc *LightningChannel) ReceiveNewCommitment(commitSigs *CommitSigs) error {
// serialize the ECDSA sig. For taproot channels, we'll serialize the // serialize the ECDSA sig. For taproot channels, we'll serialize the
// partial sig that includes the nonce that was used for signing. // partial sig that includes the nonce that was used for signing.
if lc.channelState.ChanType.IsTaproot() { if lc.channelState.ChanType.IsTaproot() {
partialSig, err := commitSigs.PartialSig.UnwrapOrErrV(
errNoPartialSig,
)
if err != nil {
return err
}
var sigBytes [lnwire.PartialSigWithNonceLen]byte var sigBytes [lnwire.PartialSigWithNonceLen]byte
b := bytes.NewBuffer(sigBytes[0:0]) b := bytes.NewBuffer(sigBytes[0:0])
if err := commitSigs.PartialSig.Encode(b); err != nil { if err := partialSig.Encode(b); err != nil {
return err return err
} }
@ -5783,19 +5809,21 @@ func (lc *LightningChannel) ReceiveRevocation(revMsg *lnwire.RevokeAndAck) (
// Now that we have a new verification nonce from them, we can refresh // Now that we have a new verification nonce from them, we can refresh
// our remote musig2 session which allows us to create another state. // our remote musig2 session which allows us to create another state.
if lc.channelState.ChanType.IsTaproot() { if lc.channelState.ChanType.IsTaproot() {
if revMsg.LocalNonce == nil { localNonce, err := revMsg.LocalNonce.UnwrapOrErrV(errNoNonce)
return nil, nil, nil, nil, fmt.Errorf("next " + if err != nil {
"revocation nonce not set") return nil, nil, nil, nil, err
} }
newRemoteSession, err := lc.musigSessions.RemoteSession.Refresh(
session, err := lc.musigSessions.RemoteSession.Refresh(
&musig2.Nonces{ &musig2.Nonces{
PubNonce: *revMsg.LocalNonce, PubNonce: localNonce,
}, },
) )
if err != nil { if err != nil {
return nil, nil, nil, nil, err return nil, nil, nil, nil, err
} }
lc.musigSessions.RemoteSession = newRemoteSession
lc.musigSessions.RemoteSession = session
} }
// At this point, the revocation has been accepted, and we've rotated // At this point, the revocation has been accepted, and we've rotated
@ -8506,8 +8534,8 @@ func (lc *LightningChannel) generateRevocation(height uint64) (*lnwire.RevokeAnd
if err != nil { if err != nil {
return nil, err return nil, err
} }
revocationMsg.LocalNonce = (*lnwire.Musig2Nonce)( revocationMsg.LocalNonce = lnwire.SomeMusig2Nonce(
&nextVerificationNonce.PubNonce, nextVerificationNonce.PubNonce,
) )
} }

View File

@ -2875,10 +2875,10 @@ func assertNoChanSyncNeeded(t *testing.T, aliceChannel *LightningChannel,
// nonces. // nonces.
if aliceChannel.channelState.ChanType.IsTaproot() { if aliceChannel.channelState.ChanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceChanSyncMsg.LocalNonce, PubNonce: aliceChanSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobChanSyncMsg.LocalNonce, PubNonce: bobChanSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
} }
@ -3486,10 +3486,10 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) {
// nonces. // nonces.
if chanType.IsTaproot() { if chanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceSyncMsg.LocalNonce, PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFail(t).Val,
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobSyncMsg.LocalNonce, PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFail(t).Val,
} }
} }
@ -3548,10 +3548,10 @@ func testChanSyncOweRevocation(t *testing.T, chanType channeldb.ChannelType) {
// nonces. // nonces.
if chanType.IsTaproot() { if chanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceSyncMsg.LocalNonce, PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobSyncMsg.LocalNonce, PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
} }
@ -3671,10 +3671,10 @@ func testChanSyncOweRevocationAndCommit(t *testing.T,
// nonces. // nonces.
if chanType.IsTaproot() { if chanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceSyncMsg.LocalNonce, PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobSyncMsg.LocalNonce, PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
} }
@ -3751,10 +3751,10 @@ func testChanSyncOweRevocationAndCommit(t *testing.T,
// nonces. // nonces.
if chanType.IsTaproot() { if chanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceSyncMsg.LocalNonce, PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobSyncMsg.LocalNonce, PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
} }
@ -3888,10 +3888,10 @@ func testChanSyncOweRevocationAndCommitForceTransition(t *testing.T,
// nonces. // nonces.
if chanType.IsTaproot() { if chanType.IsTaproot() {
aliceChannel.pendingVerificationNonce = &musig2.Nonces{ aliceChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *aliceSyncMsg.LocalNonce, PubNonce: aliceSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
bobChannel.pendingVerificationNonce = &musig2.Nonces{ bobChannel.pendingVerificationNonce = &musig2.Nonces{
PubNonce: *bobSyncMsg.LocalNonce, PubNonce: bobSyncMsg.LocalNonce.UnwrapOrFailV(t),
} }
} }

View File

@ -110,7 +110,7 @@ type AcceptChannel struct {
// verify the very first commitment transaction signature. // verify the very first commitment transaction signature.
// This will only be populated if the simple taproot channels type was // This will only be populated if the simple taproot channels type was
// negotiated. // negotiated.
LocalNonce *Musig2Nonce LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -141,9 +141,9 @@ func (a *AcceptChannel) Encode(w *bytes.Buffer, pver uint32) error {
if a.LeaseExpiry != nil { if a.LeaseExpiry != nil {
recordProducers = append(recordProducers, a.LeaseExpiry) recordProducers = append(recordProducers, a.LeaseExpiry)
} }
if a.LocalNonce != nil { a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, a.LocalNonce) recordProducers = append(recordProducers, &localNonce)
} })
err := EncodeMessageExtraData(&a.ExtraData, recordProducers...) err := EncodeMessageExtraData(&a.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err
@ -248,7 +248,7 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error {
var ( var (
chanType ChannelType chanType ChannelType
leaseExpiry LeaseExpiry leaseExpiry LeaseExpiry
localNonce Musig2Nonce localNonce = a.LocalNonce.Zero()
) )
typeMap, err := tlvRecords.ExtractRecords( typeMap, err := tlvRecords.ExtractRecords(
&a.UpfrontShutdownScript, &chanType, &leaseExpiry, &a.UpfrontShutdownScript, &chanType, &leaseExpiry,
@ -265,8 +265,8 @@ func (a *AcceptChannel) Decode(r io.Reader, pver uint32) error {
if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil {
a.LeaseExpiry = &leaseExpiry a.LeaseExpiry = &leaseExpiry
} }
if val, ok := typeMap[NonceRecordType]; ok && val == nil { if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil {
a.LocalNonce = &localNonce a.LocalNonce = tlv.SomeRecordT(localNonce)
} }
a.ExtraData = tlvRecords a.ExtraData = tlvRecords

View File

@ -31,7 +31,7 @@ type ChannelReady struct {
// This will only be populated if the simple taproot channels type was // This will only be populated if the simple taproot channels type was
// negotiated. This is the local nonce that will be used by the sender // negotiated. This is the local nonce that will be used by the sender
// to accept a new commitment state transition. // to accept a new commitment state transition.
NextLocalNonce *Musig2Nonce NextLocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -77,7 +77,7 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error {
// the AliasScidRecordType. // the AliasScidRecordType.
var ( var (
aliasScid ShortChannelID aliasScid ShortChannelID
localNonce Musig2Nonce localNonce = c.NextLocalNonce.Zero()
) )
typeMap, err := tlvRecords.ExtractRecords( typeMap, err := tlvRecords.ExtractRecords(
&aliasScid, &localNonce, &aliasScid, &localNonce,
@ -91,8 +91,8 @@ func (c *ChannelReady) Decode(r io.Reader, _ uint32) error {
if val, ok := typeMap[AliasScidRecordType]; ok && val == nil { if val, ok := typeMap[AliasScidRecordType]; ok && val == nil {
c.AliasScid = &aliasScid c.AliasScid = &aliasScid
} }
if val, ok := typeMap[NonceRecordType]; ok && val == nil { if val, ok := typeMap[c.NextLocalNonce.TlvType()]; ok && val == nil {
c.NextLocalNonce = &localNonce c.NextLocalNonce = tlv.SomeRecordT(localNonce)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {
@ -121,9 +121,9 @@ func (c *ChannelReady) Encode(w *bytes.Buffer, _ uint32) error {
if c.AliasScid != nil { if c.AliasScid != nil {
recordProducers = append(recordProducers, c.AliasScid) recordProducers = append(recordProducers, c.AliasScid)
} }
if c.NextLocalNonce != nil { c.NextLocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, c.NextLocalNonce) recordProducers = append(recordProducers, &localNonce)
} })
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err

View File

@ -82,8 +82,7 @@ type ChannelReestablish struct {
// This will only be populated if the simple taproot channels type was // This will only be populated if the simple taproot channels type was
// negotiated. // negotiated.
// //
// TODO(roasbeef): rename to verification nonce LocalNonce OptMusig2NonceTLV
LocalNonce *Musig2Nonce
// DynHeight is an optional field that stores the dynamic commitment // DynHeight is an optional field that stores the dynamic commitment
// negotiation height that is incremented upon successful completion of // negotiation height that is incremented upon successful completion of
@ -138,9 +137,9 @@ func (a *ChannelReestablish) Encode(w *bytes.Buffer, pver uint32) error {
} }
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if a.LocalNonce != nil { a.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, a.LocalNonce) recordProducers = append(recordProducers, &localNonce)
} })
a.DynHeight.WhenSome(func(h DynHeight) { a.DynHeight.WhenSome(func(h DynHeight) {
recordProducers = append(recordProducers, &h) recordProducers = append(recordProducers, &h)
}) })
@ -203,8 +202,10 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var localNonce Musig2Nonce var (
var dynHeight DynHeight dynHeight DynHeight
localNonce = a.LocalNonce.Zero()
)
typeMap, err := tlvRecords.ExtractRecords( typeMap, err := tlvRecords.ExtractRecords(
&localNonce, &dynHeight, &localNonce, &dynHeight,
) )
@ -212,8 +213,8 @@ func (a *ChannelReestablish) Decode(r io.Reader, pver uint32) error {
return err return err
} }
if val, ok := typeMap[NonceRecordType]; ok && val == nil { if val, ok := typeMap[a.LocalNonce.TlvType()]; ok && val == nil {
a.LocalNonce = &localNonce a.LocalNonce = tlv.SomeRecordT(localNonce)
} }
if val, ok := typeMap[CRDynHeight]; ok && val == nil { if val, ok := typeMap[CRDynHeight]; ok && val == nil {
a.DynHeight = fn.Some(dynHeight) a.DynHeight = fn.Some(dynHeight)

View File

@ -50,9 +50,9 @@ type ClosingComplete struct {
// decodeClosingSigs decodes the closing sig TLV records in the passed // decodeClosingSigs decodes the closing sig TLV records in the passed
// ExtraOpaqueData. // ExtraOpaqueData.
func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error { func decodeClosingSigs(c *ClosingSigs, tlvRecords ExtraOpaqueData) error {
sig1 := tlv.ZeroRecordT[tlv.TlvType1, Sig]() sig1 := c.CloserNoClosee.Zero()
sig2 := tlv.ZeroRecordT[tlv.TlvType2, Sig]() sig2 := c.NoCloserClosee.Zero()
sig3 := tlv.ZeroRecordT[tlv.TlvType3, Sig]() sig3 := c.CloserAndClosee.Zero()
typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3) typeMap, err := tlvRecords.ExtractRecords(&sig1, &sig2, &sig3)
if err != nil { if err != nil {

View File

@ -36,7 +36,7 @@ type ClosingSigned struct {
// //
// NOTE: This field is only populated if a musig2 taproot channel is // NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank. // being signed for. In this case, the above Sig type MUST be blank.
PartialSig *PartialSig PartialSig OptPartialSigTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -76,17 +76,15 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var ( partialSig := c.PartialSig.Zero()
partialSig PartialSig
)
typeMap, err := tlvRecords.ExtractRecords(&partialSig) typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[PartialSigRecordType]; ok && val == nil { if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil {
c.PartialSig = &partialSig c.PartialSig = tlv.SomeRecordT(partialSig)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {
@ -102,9 +100,9 @@ func (c *ClosingSigned) Decode(r io.Reader, pver uint32) error {
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error { func (c *ClosingSigned) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if c.PartialSig != nil { c.PartialSig.WhenSome(func(sig PartialSigTLV) {
recordProducers = append(recordProducers, c.PartialSig) recordProducers = append(recordProducers, &sig)
} })
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err

View File

@ -43,7 +43,7 @@ type CommitSig struct {
// //
// NOTE: This field is only populated if a musig2 taproot channel is // NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank. // being signed for. In this case, the above Sig type MUST be blank.
PartialSig *PartialSigWithNonce PartialSig OptPartialSigWithNonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -81,17 +81,15 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var ( partialSig := c.PartialSig.Zero()
partialSig PartialSigWithNonce
)
typeMap, err := tlvRecords.ExtractRecords(&partialSig) typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { if val, ok := typeMap[c.PartialSig.TlvType()]; ok && val == nil {
c.PartialSig = &partialSig c.PartialSig = tlv.SomeRecordT(partialSig)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {
@ -107,9 +105,9 @@ func (c *CommitSig) Decode(r io.Reader, pver uint32) error {
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error { func (c *CommitSig) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if c.PartialSig != nil { c.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, c.PartialSig) recordProducers = append(recordProducers, &sig)
} })
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err

View File

@ -32,7 +32,7 @@ type FundingCreated struct {
// //
// NOTE: This field is only populated if a musig2 taproot channel is // NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank. // being signed for. In this case, the above Sig type MUST be blank.
PartialSig *PartialSigWithNonce PartialSig OptPartialSigWithNonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -51,9 +51,9 @@ var _ Message = (*FundingCreated)(nil)
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (f *FundingCreated) Encode(w *bytes.Buffer, pver uint32) error { func (f *FundingCreated) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if f.PartialSig != nil { f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, f.PartialSig) recordProducers = append(recordProducers, &sig)
} })
err := EncodeMessageExtraData(&f.ExtraData, recordProducers...) err := EncodeMessageExtraData(&f.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err
@ -92,17 +92,15 @@ func (f *FundingCreated) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var ( partialSig := f.PartialSig.Zero()
partialSig PartialSigWithNonce
)
typeMap, err := tlvRecords.ExtractRecords(&partialSig) typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil {
f.PartialSig = &partialSig f.PartialSig = tlv.SomeRecordT(partialSig)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {

View File

@ -24,7 +24,7 @@ type FundingSigned struct {
// //
// NOTE: This field is only populated if a musig2 taproot channel is // NOTE: This field is only populated if a musig2 taproot channel is
// being signed for. In this case, the above Sig type MUST be blank. // being signed for. In this case, the above Sig type MUST be blank.
PartialSig *PartialSigWithNonce PartialSig OptPartialSigWithNonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -43,9 +43,9 @@ var _ Message = (*FundingSigned)(nil)
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (f *FundingSigned) Encode(w *bytes.Buffer, pver uint32) error { func (f *FundingSigned) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if f.PartialSig != nil { f.PartialSig.WhenSome(func(sig PartialSigWithNonceTLV) {
recordProducers = append(recordProducers, f.PartialSig) recordProducers = append(recordProducers, &sig)
} })
err := EncodeMessageExtraData(&f.ExtraData, recordProducers...) err := EncodeMessageExtraData(&f.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err
@ -78,17 +78,15 @@ func (f *FundingSigned) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var ( partialSig := f.PartialSig.Zero()
partialSig PartialSigWithNonce
)
typeMap, err := tlvRecords.ExtractRecords(&partialSig) typeMap, err := tlvRecords.ExtractRecords(&partialSig)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[PartialSigWithNonceRecordType]; ok && val == nil { if val, ok := typeMap[f.PartialSig.TlvType()]; ok && val == nil {
f.PartialSig = &partialSig f.PartialSig = tlv.SomeRecordT(partialSig)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {

View File

@ -44,11 +44,19 @@ var (
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func randLocalNonce(r *rand.Rand) *Musig2Nonce { func randLocalNonce(r *rand.Rand) Musig2Nonce {
var nonce Musig2Nonce var nonce Musig2Nonce
_, _ = io.ReadFull(r, nonce[:]) _, _ = io.ReadFull(r, nonce[:])
return &nonce return nonce
}
func someLocalNonce[T tlv.TlvType](
r *rand.Rand) tlv.OptionalRecordT[T, Musig2Nonce] {
return tlv.SomeRecordT(tlv.NewRecordT[T, Musig2Nonce](
randLocalNonce(r),
))
} }
func randPartialSig(r *rand.Rand) (*PartialSig, error) { func randPartialSig(r *rand.Rand) (*PartialSig, error) {
@ -65,6 +73,19 @@ func randPartialSig(r *rand.Rand) (*PartialSig, error) {
}, nil }, nil
} }
func somePartialSig(t *testing.T,
r *rand.Rand) tlv.OptionalRecordT[PartialSigType, PartialSig] {
sig, err := randPartialSig(r)
if err != nil {
t.Fatal(err)
}
return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig](
*sig,
))
}
func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) { func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) {
var sigBytes [32]byte var sigBytes [32]byte
if _, err := r.Read(sigBytes[:]); err != nil { if _, err := r.Read(sigBytes[:]); err != nil {
@ -76,10 +97,25 @@ func randPartialSigWithNonce(r *rand.Rand) (*PartialSigWithNonce, error) {
return &PartialSigWithNonce{ return &PartialSigWithNonce{
PartialSig: NewPartialSig(s), PartialSig: NewPartialSig(s),
Nonce: *randLocalNonce(r), Nonce: randLocalNonce(r),
}, nil }, nil
} }
func somePartialSigWithNonce(t *testing.T,
r *rand.Rand) OptPartialSigWithNonceTLV {
sig, err := randPartialSigWithNonce(r)
if err != nil {
t.Fatal(err)
}
return tlv.SomeRecordT(
tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce](
*sig,
),
)
}
func randAlias(r *rand.Rand) NodeAlias { func randAlias(r *rand.Rand) NodeAlias {
var a NodeAlias var a NodeAlias
for i := range a { for i := range a {
@ -480,7 +516,8 @@ func TestLightningWireProtocol(t *testing.T) {
req.LeaseExpiry = new(LeaseExpiry) req.LeaseExpiry = new(LeaseExpiry)
*req.LeaseExpiry = LeaseExpiry(1337) *req.LeaseExpiry = LeaseExpiry(1337)
req.LocalNonce = randLocalNonce(r) //nolint:lll
req.LocalNonce = someLocalNonce[NonceRecordTypeT](r)
} else { } else {
req.UpfrontShutdownScript = []byte{} req.UpfrontShutdownScript = []byte{}
} }
@ -554,7 +591,8 @@ func TestLightningWireProtocol(t *testing.T) {
req.LeaseExpiry = new(LeaseExpiry) req.LeaseExpiry = new(LeaseExpiry)
*req.LeaseExpiry = LeaseExpiry(1337) *req.LeaseExpiry = LeaseExpiry(1337)
req.LocalNonce = randLocalNonce(r) //nolint:lll
req.LocalNonce = someLocalNonce[NonceRecordTypeT](r)
} else { } else {
req.UpfrontShutdownScript = []byte{} req.UpfrontShutdownScript = []byte{}
} }
@ -591,12 +629,7 @@ func TestLightningWireProtocol(t *testing.T) {
// 1/2 chance to attach a partial sig. // 1/2 chance to attach a partial sig.
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
req.PartialSig, err = randPartialSigWithNonce(r) req.PartialSig = somePartialSigWithNonce(t, r)
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
} }
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
@ -621,12 +654,7 @@ func TestLightningWireProtocol(t *testing.T) {
// 1/2 chance to attach a partial sig. // 1/2 chance to attach a partial sig.
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
req.PartialSig, err = randPartialSigWithNonce(r) req.PartialSig = somePartialSigWithNonce(t, r)
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
} }
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
@ -649,7 +677,9 @@ func TestLightningWireProtocol(t *testing.T) {
if r.Int31()%2 == 0 { if r.Int31()%2 == 0 {
scid := NewShortChanIDFromInt(uint64(r.Int63())) scid := NewShortChanIDFromInt(uint64(r.Int63()))
req.AliasScid = &scid req.AliasScid = &scid
req.NextLocalNonce = randLocalNonce(r)
//nolint:lll
req.NextLocalNonce = someLocalNonce[NonceRecordTypeT](r)
} }
v[0] = reflect.ValueOf(*req) v[0] = reflect.ValueOf(*req)
@ -676,9 +706,8 @@ func TestLightningWireProtocol(t *testing.T) {
} }
if r.Int31()%2 == 0 { if r.Int31()%2 == 0 {
req.ShutdownNonce = (*ShutdownNonce)( //nolint:lll
randLocalNonce(r), req.ShutdownNonce = someLocalNonce[ShutdownNonceType](r)
)
} }
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
@ -701,12 +730,7 @@ func TestLightningWireProtocol(t *testing.T) {
} }
if r.Int31()%2 == 0 { if r.Int31()%2 == 0 {
req.PartialSig, err = randPartialSig(r) req.PartialSig = somePartialSig(t, r)
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
} }
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
@ -854,12 +878,7 @@ func TestLightningWireProtocol(t *testing.T) {
// 50/50 chance to attach a partial sig. // 50/50 chance to attach a partial sig.
if r.Int31()%2 == 0 { if r.Int31()%2 == 0 {
req.PartialSig, err = randPartialSigWithNonce(r) req.PartialSig = somePartialSigWithNonce(t, r)
if err != nil {
t.Fatalf("unable to generate sig: %v",
err)
return
}
} }
v[0] = reflect.ValueOf(*req) v[0] = reflect.ValueOf(*req)
@ -883,7 +902,8 @@ func TestLightningWireProtocol(t *testing.T) {
// 50/50 chance to attach a local nonce. // 50/50 chance to attach a local nonce.
if r.Int31()%2 == 0 { if r.Int31()%2 == 0 {
req.LocalNonce = randLocalNonce(r) //nolint:lll
req.LocalNonce = someLocalNonce[NonceRecordTypeT](r)
} }
v[0] = reflect.ValueOf(*req) v[0] = reflect.ValueOf(*req)
@ -1107,7 +1127,8 @@ func TestLightningWireProtocol(t *testing.T) {
return return
} }
req.LocalNonce = randLocalNonce(r) //nolint:lll
req.LocalNonce = someLocalNonce[NonceRecordTypeT](r)
} }
v[0] = reflect.ValueOf(req) v[0] = reflect.ValueOf(req)
@ -1232,7 +1253,7 @@ func TestLightningWireProtocol(t *testing.T) {
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() sig := req.CloserNoClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",
@ -1243,7 +1264,7 @@ func TestLightningWireProtocol(t *testing.T) {
req.CloserNoClosee = tlv.SomeRecordT(sig) req.CloserNoClosee = tlv.SomeRecordT(sig)
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() sig := req.NoCloserClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",
@ -1254,7 +1275,7 @@ func TestLightningWireProtocol(t *testing.T) {
req.NoCloserClosee = tlv.SomeRecordT(sig) req.NoCloserClosee = tlv.SomeRecordT(sig)
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() sig := req.CloserAndClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",
@ -1281,7 +1302,7 @@ func TestLightningWireProtocol(t *testing.T) {
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType1, Sig]() sig := req.CloserNoClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",
@ -1292,7 +1313,7 @@ func TestLightningWireProtocol(t *testing.T) {
req.CloserNoClosee = tlv.SomeRecordT(sig) req.CloserNoClosee = tlv.SomeRecordT(sig)
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType2, Sig]() sig := req.NoCloserClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",
@ -1303,7 +1324,7 @@ func TestLightningWireProtocol(t *testing.T) {
req.NoCloserClosee = tlv.SomeRecordT(sig) req.NoCloserClosee = tlv.SomeRecordT(sig)
} }
if r.Intn(2) == 0 { if r.Intn(2) == 0 {
sig := tlv.ZeroRecordT[tlv.TlvType3, Sig]() sig := req.CloserAndClosee.Zero()
_, err := r.Read(sig.Val.bytes[:]) _, err := r.Read(sig.Val.bytes[:])
if err != nil { if err != nil {
t.Fatalf("unable to generate sig: %v", t.Fatalf("unable to generate sig: %v",

View File

@ -7,21 +7,33 @@ import (
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
const ( // NonceRecordTypeT is the TLV type used to encode a local musig2 nonce.
// NonceRecordType is the TLV type used to encode a local musig2 nonce. type NonceRecordTypeT = tlv.TlvType4
NonceRecordType tlv.Type = 4
)
// Musig2Nonce represents a musig2 public nonce, which is the concatenation of // nonceRecordType is the TLV (integer) type used to encode a local musig2
// two EC points serialized in compressed format. // nonce.
type Musig2Nonce [musig2.PubNonceSize]byte var nonceRecordType tlv.Type = (NonceRecordTypeT)(nil).TypeVal()
type (
// Musig2Nonce represents a musig2 public nonce, which is the
// concatenation of two EC points serialized in compressed format.
Musig2Nonce [musig2.PubNonceSize]byte
// Musig2NonceTLV is a TLV type that can be used to encode/decode a
// musig2 nonce. This is an optional TLV.
Musig2NonceTLV = tlv.RecordT[NonceRecordTypeT, Musig2Nonce]
// OptMusig2NonceTLV is a TLV type that can be used to encode/decode a
// musig2 nonce.
OptMusig2NonceTLV = tlv.OptionalRecordT[NonceRecordTypeT, Musig2Nonce]
)
// Record returns a TLV record that can be used to encode/decode the musig2 // Record returns a TLV record that can be used to encode/decode the musig2
// nonce from a given TLV stream. // nonce from a given TLV stream.
func (m *Musig2Nonce) Record() tlv.Record { func (m *Musig2Nonce) Record() tlv.Record {
return tlv.MakeStaticRecord( return tlv.MakeStaticRecord(
NonceRecordType, m, musig2.PubNonceSize, nonceTypeEncoder, nonceRecordType, m, musig2.PubNonceSize,
nonceTypeDecoder, nonceTypeEncoder, nonceTypeDecoder,
) )
} }
@ -48,3 +60,10 @@ func nonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte,
val, "lnwire.Musig2Nonce", l, musig2.PubNonceSize, val, "lnwire.Musig2Nonce", l, musig2.PubNonceSize,
) )
} }
// SomeMusig2Nonce is a helper function that creates a musig2 nonce TLV.
func SomeMusig2Nonce(nonce Musig2Nonce) OptMusig2NonceTLV {
return tlv.SomeRecordT(
tlv.NewRecordT[NonceRecordTypeT, Musig2Nonce](nonce),
)
}

View File

@ -146,7 +146,7 @@ type OpenChannel struct {
// verify the very first commitment transaction signature. This will // verify the very first commitment transaction signature. This will
// only be populated if the simple taproot channels type was // only be populated if the simple taproot channels type was
// negotiated. // negotiated.
LocalNonce *Musig2Nonce LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -175,9 +175,9 @@ func (o *OpenChannel) Encode(w *bytes.Buffer, pver uint32) error {
if o.LeaseExpiry != nil { if o.LeaseExpiry != nil {
recordProducers = append(recordProducers, o.LeaseExpiry) recordProducers = append(recordProducers, o.LeaseExpiry)
} }
if o.LocalNonce != nil { o.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, o.LocalNonce) recordProducers = append(recordProducers, &localNonce)
} })
err := EncodeMessageExtraData(&o.ExtraData, recordProducers...) err := EncodeMessageExtraData(&o.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err
@ -302,7 +302,7 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error {
var ( var (
chanType ChannelType chanType ChannelType
leaseExpiry LeaseExpiry leaseExpiry LeaseExpiry
localNonce Musig2Nonce localNonce = o.LocalNonce.Zero()
) )
typeMap, err := tlvRecords.ExtractRecords( typeMap, err := tlvRecords.ExtractRecords(
&o.UpfrontShutdownScript, &chanType, &leaseExpiry, &o.UpfrontShutdownScript, &chanType, &leaseExpiry,
@ -319,8 +319,8 @@ func (o *OpenChannel) Decode(r io.Reader, pver uint32) error {
if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil { if val, ok := typeMap[LeaseExpiryRecordType]; ok && val == nil {
o.LeaseExpiry = &leaseExpiry o.LeaseExpiry = &leaseExpiry
} }
if val, ok := typeMap[NonceRecordType]; ok && val == nil { if val, ok := typeMap[o.LocalNonce.TlvType()]; ok && val == nil {
o.LocalNonce = &localNonce o.LocalNonce = tlv.SomeRecordT(localNonce)
} }
o.ExtraData = tlvRecords o.ExtraData = tlvRecords

View File

@ -11,11 +11,20 @@ import (
const ( const (
// PartialSigLen is the length of a musig2 partial signature. // PartialSigLen is the length of a musig2 partial signature.
PartialSigLen = 32 PartialSigLen = 32
)
// PartialSigRecordType is the type of the tlv record for a musig2 type (
// PartialSigType is the type of the tlv record for a musig2
// partial signature. This is an _even_ type, which means it's required // partial signature. This is an _even_ type, which means it's required
// if included. // if included.
PartialSigRecordType tlv.Type = 6 PartialSigType = tlv.TlvType6
// PartialSigTLV is a tlv record for a musig2 partial signature.
PartialSigTLV = tlv.RecordT[PartialSigType, PartialSig]
// OptPartialSigTLV is a tlv record for a musig2 partial signature.
// This is an optional record type.
OptPartialSigTLV = tlv.OptionalRecordT[PartialSigType, PartialSig]
) )
// PartialSig is the base partial sig type. This only encodes the 32-byte // PartialSig is the base partial sig type. This only encodes the 32-byte
@ -36,7 +45,7 @@ func NewPartialSig(sig btcec.ModNScalar) PartialSig {
// Record returns the tlv record for the partial sig. // Record returns the tlv record for the partial sig.
func (p *PartialSig) Record() tlv.Record { func (p *PartialSig) Record() tlv.Record {
return tlv.MakeStaticRecord( return tlv.MakeStaticRecord(
PartialSigRecordType, p, PartialSigLen, (PartialSigType)(nil).TypeVal(), p, PartialSigLen,
partialSigTypeEncoder, partialSigTypeDecoder, partialSigTypeEncoder, partialSigTypeDecoder,
) )
} }
@ -88,16 +97,35 @@ func (p *PartialSig) Decode(r io.Reader) error {
return partialSigTypeDecoder(r, p, nil, PartialSigLen) return partialSigTypeDecoder(r, p, nil, PartialSigLen)
} }
// SomePartialSig is a helper function that returns an otional PartialSig.
func SomePartialSig(sig PartialSig) OptPartialSigTLV {
return tlv.SomeRecordT(tlv.NewRecordT[PartialSigType, PartialSig](sig))
}
const ( const (
// PartialSigWithNonceLen is the length of a serialized // PartialSigWithNonceLen is the length of a serialized
// PartialSigWithNonce. The sig is encoded as the 32 byte S value // PartialSigWithNonce. The sig is encoded as the 32 byte S value
// followed by the 66 nonce value. // followed by the 66 nonce value.
PartialSigWithNonceLen = 98 PartialSigWithNonceLen = 98
)
// PartialSigWithNonceRecordType is the type of the tlv record for a type (
// musig2 partial signature with nonce. This is an _even_ type, which // PartialSigWithNonceType is the type of the tlv record for a musig2
// means it's required if included. // partial signature with nonce. This is an _even_ type, which means
PartialSigWithNonceRecordType tlv.Type = 2 // it's required if included.
PartialSigWithNonceType = tlv.TlvType2
// PartialSigWithNonceTLV is a tlv record for a musig2 partial
// signature.
PartialSigWithNonceTLV = tlv.RecordT[
PartialSigWithNonceType, PartialSigWithNonce,
]
// OptPartialSigWithNonceTLV is a tlv record for a musig2 partial
// signature. This is an optional record type.
OptPartialSigWithNonceTLV = tlv.OptionalRecordT[
PartialSigWithNonceType, PartialSigWithNonce,
]
) )
// PartialSigWithNonce is a partial signature with the nonce that was used to // PartialSigWithNonce is a partial signature with the nonce that was used to
@ -129,8 +157,9 @@ func NewPartialSigWithNonce(nonce [musig2.PubNonceSize]byte,
// Record returns the tlv record for the partial sig with nonce. // Record returns the tlv record for the partial sig with nonce.
func (p *PartialSigWithNonce) Record() tlv.Record { func (p *PartialSigWithNonce) Record() tlv.Record {
return tlv.MakeStaticRecord( return tlv.MakeStaticRecord(
PartialSigWithNonceRecordType, p, PartialSigWithNonceLen, (PartialSigWithNonceType)(nil).TypeVal(), p,
partialSigWithNonceTypeEncoder, partialSigWithNonceTypeDecoder, PartialSigWithNonceLen, partialSigWithNonceTypeEncoder,
partialSigWithNonceTypeDecoder,
) )
} }
@ -199,3 +228,20 @@ func (p *PartialSigWithNonce) Decode(r io.Reader) error {
r, p, nil, PartialSigWithNonceLen, r, p, nil, PartialSigWithNonceLen,
) )
} }
// MaybePartialSigWithNonce is a helper function that returns an optional
// PartialSigWithNonceTLV.
func MaybePartialSigWithNonce(sig *PartialSigWithNonce,
) OptPartialSigWithNonceTLV {
if sig == nil {
var none OptPartialSigWithNonceTLV
return none
}
return tlv.SomeRecordT(
tlv.NewRecordT[PartialSigWithNonceType, PartialSigWithNonce](
*sig,
),
)
}

View File

@ -36,7 +36,7 @@ type RevokeAndAck struct {
// LocalNonce is the next _local_ nonce for the sending party. This // LocalNonce is the next _local_ nonce for the sending party. This
// allows the receiving party to propose a new commitment using their // allows the receiving party to propose a new commitment using their
// remote nonce and the sender's local nonce. // remote nonce and the sender's local nonce.
LocalNonce *Musig2Nonce LocalNonce OptMusig2NonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -74,15 +74,15 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var musigNonce Musig2Nonce localNonce := c.LocalNonce.Zero()
typeMap, err := tlvRecords.ExtractRecords(&musigNonce) typeMap, err := tlvRecords.ExtractRecords(&localNonce)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[NonceRecordType]; ok && val == nil { if val, ok := typeMap[c.LocalNonce.TlvType()]; ok && val == nil {
c.LocalNonce = &musigNonce c.LocalNonce = tlv.SomeRecordT(localNonce)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {
@ -98,9 +98,9 @@ func (c *RevokeAndAck) Decode(r io.Reader, pver uint32) error {
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error { func (c *RevokeAndAck) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if c.LocalNonce != nil { c.LocalNonce.WhenSome(func(localNonce Musig2NonceTLV) {
recordProducers = append(recordProducers, c.LocalNonce) recordProducers = append(recordProducers, &localNonce)
} })
err := EncodeMessageExtraData(&c.ExtraData, recordProducers...) err := EncodeMessageExtraData(&c.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err

View File

@ -4,52 +4,21 @@ import (
"bytes" "bytes"
"io" "io"
"github.com/btcsuite/btcd/btcec/v2/schnorr/musig2"
"github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/tlv"
) )
const ( type (
// ShutdownNonceRecordType is the type of the shutdown nonce TLV record. // ShutdownNonceType is the type of the shutdown nonce TLV record.
ShutdownNonceRecordType = 8 ShutdownNonceType = tlv.TlvType8
// ShutdownNonceTLV is the TLV record that contains the shutdown nonce.
ShutdownNonceTLV = tlv.OptionalRecordT[ShutdownNonceType, Musig2Nonce]
) )
// ShutdownNonce is the type of the nonce we send during the shutdown flow. // SomeShutdownNonce returns a ShutdownNonceTLV with the given nonce.
// Unlike the other nonces, this nonce is symmetric w.r.t the message being func SomeShutdownNonce(nonce Musig2Nonce) ShutdownNonceTLV {
// signed (there's only one message for shutdown: the co-op close txn). return tlv.SomeRecordT(
type ShutdownNonce Musig2Nonce tlv.NewRecordT[ShutdownNonceType, Musig2Nonce](nonce),
// Record returns a TLV record that can be used to encode/decode the musig2
// nonce from a given TLV stream.
func (s *ShutdownNonce) Record() tlv.Record {
return tlv.MakeStaticRecord(
ShutdownNonceRecordType, s, musig2.PubNonceSize,
shutdownNonceTypeEncoder, shutdownNonceTypeDecoder,
)
}
// shutdownNonceTypeEncoder is a custom TLV encoder for the Musig2Nonce type.
func shutdownNonceTypeEncoder(w io.Writer, val interface{},
_ *[8]byte) error {
if v, ok := val.(*ShutdownNonce); ok {
_, err := w.Write(v[:])
return err
}
return tlv.NewTypeForEncodingErr(val, "lnwire.Musig2Nonce")
}
// shutdownNonceTypeDecoder is a custom TLV decoder for the Musig2Nonce record.
func shutdownNonceTypeDecoder(r io.Reader, val interface{}, _ *[8]byte,
l uint64) error {
if v, ok := val.(*ShutdownNonce); ok {
_, err := io.ReadFull(r, v[:])
return err
}
return tlv.NewTypeForDecodingErr(
val, "lnwire.ShutdownNonce", l, musig2.PubNonceSize,
) )
} }
@ -67,7 +36,7 @@ type Shutdown struct {
// ShutdownNonce is the nonce the sender will use to sign the first // ShutdownNonce is the nonce the sender will use to sign the first
// co-op sign offer. // co-op sign offer.
ShutdownNonce *ShutdownNonce ShutdownNonce ShutdownNonceTLV
// ExtraData is the set of data that was appended to this message to // ExtraData is the set of data that was appended to this message to
// fill out the full maximum transport message size. These fields can // fill out the full maximum transport message size. These fields can
@ -102,15 +71,15 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
return err return err
} }
var musigNonce ShutdownNonce musigNonce := s.ShutdownNonce.Zero()
typeMap, err := tlvRecords.ExtractRecords(&musigNonce) typeMap, err := tlvRecords.ExtractRecords(&musigNonce)
if err != nil { if err != nil {
return err return err
} }
// Set the corresponding TLV types if they were included in the stream. // Set the corresponding TLV types if they were included in the stream.
if val, ok := typeMap[ShutdownNonceRecordType]; ok && val == nil { if val, ok := typeMap[s.ShutdownNonce.TlvType()]; ok && val == nil {
s.ShutdownNonce = &musigNonce s.ShutdownNonce = tlv.SomeRecordT(musigNonce)
} }
if len(tlvRecords) != 0 { if len(tlvRecords) != 0 {
@ -126,9 +95,11 @@ func (s *Shutdown) Decode(r io.Reader, pver uint32) error {
// This is part of the lnwire.Message interface. // This is part of the lnwire.Message interface.
func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error { func (s *Shutdown) Encode(w *bytes.Buffer, pver uint32) error {
recordProducers := make([]tlv.RecordProducer, 0, 1) recordProducers := make([]tlv.RecordProducer, 0, 1)
if s.ShutdownNonce != nil { s.ShutdownNonce.WhenSome(
recordProducers = append(recordProducers, s.ShutdownNonce) func(nonce tlv.RecordT[ShutdownNonceType, Musig2Nonce]) {
} recordProducers = append(recordProducers, &nonce)
},
)
err := EncodeMessageExtraData(&s.ExtraData, recordProducers...) err := EncodeMessageExtraData(&s.ExtraData, recordProducers...)
if err != nil { if err != nil {
return err return err