From edf959d39f430858f974b5133805cb0d6e0bc533 Mon Sep 17 00:00:00 2001 From: Olaoluwa Osuntokun Date: Wed, 13 Mar 2024 09:15:54 -0400 Subject: [PATCH] channeldb: add optional TapscriptRoot field + feature bit --- channeldb/channel.go | 33 +++++++++++++++++++++++++++++++++ channeldb/channel_test.go | 8 +++++++- 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/channeldb/channel.go b/channeldb/channel.go index ec449613e..f29a9ec19 100644 --- a/channeldb/channel.go +++ b/channeldb/channel.go @@ -252,6 +252,10 @@ type openChannelTlvData struct { // memo is an optional text field that gives context to the user about // the channel. memo tlv.OptionalRecordT[tlv.TlvType5, []byte] + + // tapscriptRoot is the optional Tapscript root the channel funding + // output commits to. + tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte] } // encode serializes the openChannelTlvData to the given io.Writer. @@ -265,6 +269,11 @@ func (c *openChannelTlvData) encode(w io.Writer) error { c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { tlvRecords = append(tlvRecords, memo.Record()) }) + c.tapscriptRoot.WhenSome( + func(root tlv.RecordT[tlv.TlvType6, [32]byte]) { + tlvRecords = append(tlvRecords, root.Record()) + }, + ) // Create the tlv stream. tlvStream, err := tlv.NewStream(tlvRecords...) @@ -278,6 +287,7 @@ func (c *openChannelTlvData) encode(w io.Writer) error { // decode deserializes the openChannelTlvData from the given io.Reader. func (c *openChannelTlvData) decode(r io.Reader) error { memo := c.memo.Zero() + tapscriptRoot := c.tapscriptRoot.Zero() // Create the tlv stream. tlvStream, err := tlv.NewStream( @@ -286,6 +296,7 @@ func (c *openChannelTlvData) decode(r io.Reader) error { c.initialRemoteBalance.Record(), c.realScid.Record(), memo.Record(), + tapscriptRoot.Record(), ) if err != nil { return err @@ -299,6 +310,9 @@ func (c *openChannelTlvData) decode(r io.Reader) error { if _, ok := tlvs[memo.TlvType()]; ok { c.memo = tlv.SomeRecordT(memo) } + if _, ok := tlvs[tapscriptRoot.TlvType()]; ok { + c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot) + } return nil } @@ -380,6 +394,11 @@ const ( // SimpleTaprootFeatureBit indicates that the simple-taproot-chans // feature bit was negotiated during the lifetime of the channel. SimpleTaprootFeatureBit ChannelType = 1 << 10 + + // TapscriptRootBit indicates that this is a MuSig2 channel with a top + // level tapscript commitment. This MUST be set along with the + // SimpleTaprootFeatureBit. + TapscriptRootBit ChannelType = 1 << 11 ) // IsSingleFunder returns true if the channel type if one of the known single @@ -450,6 +469,12 @@ func (c ChannelType) IsTaproot() bool { return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit } +// HasTapscriptRoot returns true if the channel is using a top level tapscript +// root commitment. +func (c ChannelType) HasTapscriptRoot() bool { + return c&TapscriptRootBit == TapscriptRootBit +} + // ChannelStateBounds are the parameters from OpenChannel and AcceptChannel // that are responsible for providing bounds on the state space of the abstract // channel state. These values must be remembered for normal channel operation @@ -1098,6 +1123,9 @@ func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) { auxData.memo.WhenSomeV(func(memo []byte) { c.Memo = memo }) + auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) { + c.TapscriptRoot = fn.Some[chainhash.Hash](h) + }) } // extractTlvData creates a new openChannelTlvData from the given channel. @@ -1122,6 +1150,11 @@ func (c *OpenChannel) extractTlvData() openChannelTlvData { tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), ) } + c.TapscriptRoot.WhenSome(func(h chainhash.Hash) { + auxData.tapscriptRoot = tlv.SomeRecordT( + tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h), + ) + }) return auxData } diff --git a/channeldb/channel_test.go b/channeldb/channel_test.go index a7f3c1ebe..84aae9622 100644 --- a/channeldb/channel_test.go +++ b/channeldb/channel_test.go @@ -17,6 +17,7 @@ import ( "github.com/davecgh/go-spew/spew" "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/clock" + "github.com/lightningnetwork/lnd/fn" "github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnmock" @@ -173,7 +174,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption { } // channelIDOption is an option which sets the short channel ID of the channel. -var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { +func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption { return func(params *testChannelParams) { params.channel.ShortChannelID = chanID } @@ -326,6 +327,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { uniqueOutputIndex.Add(1) op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()} + var tapscriptRoot chainhash.Hash + copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32)) + return &OpenChannel{ ChanType: SingleFunderBit | FrozenBit, ChainHash: key, @@ -368,6 +372,8 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel { ThawHeight: uint32(defaultPendingHeight), InitialLocalBalance: lnwire.MilliSatoshi(9000), InitialRemoteBalance: lnwire.MilliSatoshi(3000), + Memo: []byte("test"), + TapscriptRoot: fn.Some(tapscriptRoot), } }