channeldb: add optional TapscriptRoot field + feature bit

This commit is contained in:
Olaoluwa Osuntokun 2024-03-13 09:15:54 -04:00 committed by Oliver Gugger
parent ecca095a9b
commit edf959d39f
No known key found for this signature in database
GPG Key ID: 8E4256593F177720
2 changed files with 40 additions and 1 deletions

View File

@ -252,6 +252,10 @@ type openChannelTlvData struct {
// memo is an optional text field that gives context to the user about // memo is an optional text field that gives context to the user about
// the channel. // the channel.
memo tlv.OptionalRecordT[tlv.TlvType5, []byte] memo tlv.OptionalRecordT[tlv.TlvType5, []byte]
// tapscriptRoot is the optional Tapscript root the channel funding
// output commits to.
tapscriptRoot tlv.OptionalRecordT[tlv.TlvType6, [32]byte]
} }
// encode serializes the openChannelTlvData to the given io.Writer. // encode serializes the openChannelTlvData to the given io.Writer.
@ -265,6 +269,11 @@ func (c *openChannelTlvData) encode(w io.Writer) error {
c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) { c.memo.WhenSome(func(memo tlv.RecordT[tlv.TlvType5, []byte]) {
tlvRecords = append(tlvRecords, memo.Record()) tlvRecords = append(tlvRecords, memo.Record())
}) })
c.tapscriptRoot.WhenSome(
func(root tlv.RecordT[tlv.TlvType6, [32]byte]) {
tlvRecords = append(tlvRecords, root.Record())
},
)
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := tlv.NewStream(tlvRecords...) tlvStream, err := tlv.NewStream(tlvRecords...)
@ -278,6 +287,7 @@ func (c *openChannelTlvData) encode(w io.Writer) error {
// decode deserializes the openChannelTlvData from the given io.Reader. // decode deserializes the openChannelTlvData from the given io.Reader.
func (c *openChannelTlvData) decode(r io.Reader) error { func (c *openChannelTlvData) decode(r io.Reader) error {
memo := c.memo.Zero() memo := c.memo.Zero()
tapscriptRoot := c.tapscriptRoot.Zero()
// Create the tlv stream. // Create the tlv stream.
tlvStream, err := tlv.NewStream( tlvStream, err := tlv.NewStream(
@ -286,6 +296,7 @@ func (c *openChannelTlvData) decode(r io.Reader) error {
c.initialRemoteBalance.Record(), c.initialRemoteBalance.Record(),
c.realScid.Record(), c.realScid.Record(),
memo.Record(), memo.Record(),
tapscriptRoot.Record(),
) )
if err != nil { if err != nil {
return err return err
@ -299,6 +310,9 @@ func (c *openChannelTlvData) decode(r io.Reader) error {
if _, ok := tlvs[memo.TlvType()]; ok { if _, ok := tlvs[memo.TlvType()]; ok {
c.memo = tlv.SomeRecordT(memo) c.memo = tlv.SomeRecordT(memo)
} }
if _, ok := tlvs[tapscriptRoot.TlvType()]; ok {
c.tapscriptRoot = tlv.SomeRecordT(tapscriptRoot)
}
return nil return nil
} }
@ -380,6 +394,11 @@ const (
// SimpleTaprootFeatureBit indicates that the simple-taproot-chans // SimpleTaprootFeatureBit indicates that the simple-taproot-chans
// feature bit was negotiated during the lifetime of the channel. // feature bit was negotiated during the lifetime of the channel.
SimpleTaprootFeatureBit ChannelType = 1 << 10 SimpleTaprootFeatureBit ChannelType = 1 << 10
// TapscriptRootBit indicates that this is a MuSig2 channel with a top
// level tapscript commitment. This MUST be set along with the
// SimpleTaprootFeatureBit.
TapscriptRootBit ChannelType = 1 << 11
) )
// IsSingleFunder returns true if the channel type if one of the known single // IsSingleFunder returns true if the channel type if one of the known single
@ -450,6 +469,12 @@ func (c ChannelType) IsTaproot() bool {
return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit return c&SimpleTaprootFeatureBit == SimpleTaprootFeatureBit
} }
// HasTapscriptRoot returns true if the channel is using a top level tapscript
// root commitment.
func (c ChannelType) HasTapscriptRoot() bool {
return c&TapscriptRootBit == TapscriptRootBit
}
// ChannelStateBounds are the parameters from OpenChannel and AcceptChannel // ChannelStateBounds are the parameters from OpenChannel and AcceptChannel
// that are responsible for providing bounds on the state space of the abstract // that are responsible for providing bounds on the state space of the abstract
// channel state. These values must be remembered for normal channel operation // channel state. These values must be remembered for normal channel operation
@ -1098,6 +1123,9 @@ func (c *OpenChannel) amendTlvData(auxData openChannelTlvData) {
auxData.memo.WhenSomeV(func(memo []byte) { auxData.memo.WhenSomeV(func(memo []byte) {
c.Memo = memo c.Memo = memo
}) })
auxData.tapscriptRoot.WhenSomeV(func(h [32]byte) {
c.TapscriptRoot = fn.Some[chainhash.Hash](h)
})
} }
// extractTlvData creates a new openChannelTlvData from the given channel. // extractTlvData creates a new openChannelTlvData from the given channel.
@ -1122,6 +1150,11 @@ func (c *OpenChannel) extractTlvData() openChannelTlvData {
tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo), tlv.NewPrimitiveRecord[tlv.TlvType5](c.Memo),
) )
} }
c.TapscriptRoot.WhenSome(func(h chainhash.Hash) {
auxData.tapscriptRoot = tlv.SomeRecordT(
tlv.NewPrimitiveRecord[tlv.TlvType6, [32]byte](h),
)
})
return auxData return auxData
} }

View File

@ -17,6 +17,7 @@ import (
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/clock" "github.com/lightningnetwork/lnd/clock"
"github.com/lightningnetwork/lnd/fn"
"github.com/lightningnetwork/lnd/keychain" "github.com/lightningnetwork/lnd/keychain"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnmock" "github.com/lightningnetwork/lnd/lnmock"
@ -173,7 +174,7 @@ func fundingPointOption(chanPoint wire.OutPoint) testChannelOption {
} }
// channelIDOption is an option which sets the short channel ID of the channel. // channelIDOption is an option which sets the short channel ID of the channel.
var channelIDOption = func(chanID lnwire.ShortChannelID) testChannelOption { func channelIDOption(chanID lnwire.ShortChannelID) testChannelOption {
return func(params *testChannelParams) { return func(params *testChannelParams) {
params.channel.ShortChannelID = chanID params.channel.ShortChannelID = chanID
} }
@ -326,6 +327,9 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
uniqueOutputIndex.Add(1) uniqueOutputIndex.Add(1)
op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()} op := wire.OutPoint{Hash: key, Index: uniqueOutputIndex.Load()}
var tapscriptRoot chainhash.Hash
copy(tapscriptRoot[:], bytes.Repeat([]byte{1}, 32))
return &OpenChannel{ return &OpenChannel{
ChanType: SingleFunderBit | FrozenBit, ChanType: SingleFunderBit | FrozenBit,
ChainHash: key, ChainHash: key,
@ -368,6 +372,8 @@ func createTestChannelState(t *testing.T, cdb *ChannelStateDB) *OpenChannel {
ThawHeight: uint32(defaultPendingHeight), ThawHeight: uint32(defaultPendingHeight),
InitialLocalBalance: lnwire.MilliSatoshi(9000), InitialLocalBalance: lnwire.MilliSatoshi(9000),
InitialRemoteBalance: lnwire.MilliSatoshi(3000), InitialRemoteBalance: lnwire.MilliSatoshi(3000),
Memo: []byte("test"),
TapscriptRoot: fn.Some(tapscriptRoot),
} }
} }