channeldb: add encoding for ChannelEdgePolicy2

Similarly to the previous commit, here we add the encoding for the new
ChannelEdgePolicy2. This is done in the same was as for
ChannelEdgeInfo2:
- a 0xff prefix
- followed by a type-byte
- followed by the TLV encoding of the ChannelEdgePolicy2.
This commit is contained in:
Elle Mouton
2024-08-21 12:58:01 +02:00
parent 7c9b38bcc6
commit 18445db7fa
3 changed files with 395 additions and 10 deletions

View File

@@ -1,6 +1,7 @@
package channeldb
import (
"bufio"
"bytes"
"encoding/binary"
"errors"
@@ -12,6 +13,30 @@ import (
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/kvdb"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
)
const (
EdgePolicy2MsgType = tlv.Type(0)
EdgePolicy2ToNode = tlv.Type(1)
// chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge
// policy encoding to signal that the new style encoding which is
// prefixed with a type byte is being used instead of the legacy
// encoding which would start with 0x02 due to the fact that the
// encoding would start with a DER encoded ecdsa signature.
chanEdgePolicyNewEncodingPrefix = 0xff
)
// edgePolicyEncoding indicates how the bytes for a channel edge policy have
// been serialised.
type edgePolicyEncodingType uint8
const (
// edgePolicy2EncodingType will be used as a prefix for edge policies
// advertised using the ChannelUpdate2 message. The type indicates how
// the bytes following should be deserialized.
edgePolicy2EncodingType edgePolicyEncodingType = 0
)
func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1,
@@ -63,7 +88,14 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1,
return err
}
oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix())
oldPol, ok := oldEdgePolicy.(*models.ChannelEdgePolicy1)
if !ok {
return fmt.Errorf("expected "+
"*models.ChannelEdgePolicy1, got: %T",
oldEdgePolicy)
}
oldUpdateTime := uint64(oldPol.LastUpdate.Unix())
var oldIndexKey [8 + 8]byte
byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime)
@@ -169,7 +201,13 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
return nil, err
}
return ep, nil
pol, ok := ep.(*models.ChannelEdgePolicy1)
if !ok {
return nil, fmt.Errorf("expected *models.ChannelEdgePolicy1, "+
"got: %T", ep)
}
return pol, nil
}
func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
@@ -201,8 +239,56 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
return edge1, edge2, nil
}
func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
to []byte) error {
func serializeChanEdgePolicy(w io.Writer,
edgePolicy models.ChannelEdgePolicy, toNode []byte) error {
var (
withTypeByte bool
typeByte edgePolicyEncodingType
serialize func(w io.Writer) error
)
switch policy := edgePolicy.(type) {
case *models.ChannelEdgePolicy1:
serialize = func(w io.Writer) error {
copy(policy.ToNode[:], toNode)
return serializeChanEdgePolicy1(w, policy)
}
case *models.ChannelEdgePolicy2:
withTypeByte = true
typeByte = edgePolicy2EncodingType
serialize = func(w io.Writer) error {
copy(policy.ToNode[:], toNode)
return serializeChanEdgePolicy2(w, policy)
}
default:
return fmt.Errorf("unhandled implementation of "+
"ChannelEdgePolicy: %T", edgePolicy)
}
if withTypeByte {
// First, write the identifying encoding byte to signal that
// this is not using the legacy encoding.
_, err := w.Write([]byte{chanEdgePolicyNewEncodingPrefix})
if err != nil {
return err
}
// Now, write the encoding type.
_, err = w.Write([]byte{byte(typeByte)})
if err != nil {
return err
}
}
return serialize(w)
}
func serializeChanEdgePolicy1(w io.Writer,
edge *models.ChannelEdgePolicy1) error {
err := wire.WriteVarBytes(w, 0, edge.SigBytes)
if err != nil {
@@ -241,7 +327,7 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
return err
}
if _, err := w.Write(to); err != nil {
if _, err := w.Write(edge.ToNode[:]); err != nil {
return err
}
@@ -271,7 +357,36 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
return nil
}
func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) {
func serializeChanEdgePolicy2(w io.Writer,
edge *models.ChannelEdgePolicy2) error {
if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData))
}
var b bytes.Buffer
if err := edge.Encode(&b, 0); err != nil {
return err
}
msg := b.Bytes()
records := []tlv.Record{
tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg),
tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &edge.ToNode),
}
stream, err := tlv.NewStream(records...)
if err != nil {
return err
}
return stream.Encode(w)
}
func deserializeChanEdgePolicy(r io.Reader) (models.ChannelEdgePolicy,
error) {
// Deserialize the policy. Note that in case an optional field is not
// found, both an error and a populated policy object are returned.
edge, deserializeErr := deserializeChanEdgePolicyRaw(r)
@@ -284,7 +399,45 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error)
return edge, deserializeErr
}
func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1,
func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy,
error) {
// Wrap the io.Reader in a bufio.Reader so that we can peak the first
// byte of the stream without actually consuming from the stream.
r := bufio.NewReader(reader)
firstByte, err := r.Peek(1)
if err != nil {
return nil, err
}
if firstByte[0] != chanEdgePolicyNewEncodingPrefix {
return deserializeChanEdgePolicy1Raw(r)
}
// Pop the encoding type byte.
var scratch [1]byte
if _, err = r.Read(scratch[:]); err != nil {
return nil, err
}
// Now, read the encoding type byte.
if _, err = r.Read(scratch[:]); err != nil {
return nil, err
}
encoding := edgePolicyEncodingType(scratch[0])
switch encoding {
case edgePolicy2EncodingType:
return deserializeChanEdgePolicy2Raw(r)
default:
return nil, fmt.Errorf("unknown edge policy encoding type: %d",
encoding)
}
}
func deserializeChanEdgePolicy1Raw(r io.Reader) (*models.ChannelEdgePolicy1,
error) {
edge := &models.ChannelEdgePolicy1{}
@@ -370,3 +523,41 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1,
return edge, nil
}
func deserializeChanEdgePolicy2Raw(r io.Reader) (*models.ChannelEdgePolicy2,
error) {
var (
msgBytes []byte
toNode [33]byte
)
records := []tlv.Record{
tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes),
tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode),
}
stream, err := tlv.NewStream(records...)
if err != nil {
return nil, err
}
err = stream.Decode(r)
if err != nil {
return nil, err
}
var (
chanUpdate lnwire.ChannelUpdate2
reader = bytes.NewReader(msgBytes)
)
err = chanUpdate.Decode(reader, 0)
if err != nil {
return nil, err
}
return &models.ChannelEdgePolicy2{
ChannelUpdate2: chanUpdate,
ToNode: toNode,
}, nil
}

View File

@@ -0,0 +1,174 @@
package channeldb
import (
"bytes"
"math/rand"
"reflect"
"testing"
"testing/quick"
"time"
"github.com/lightningnetwork/lnd/channeldb/models"
"github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/tlv"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// TestEdgePolicySerialisation tests the serialisation and deserialization logic
// for models.ChannelEdgePolicy.
func TestEdgePolicySerialisation(t *testing.T) {
t.Parallel()
mainScenario := func(info models.ChannelEdgePolicy) bool {
var (
b bytes.Buffer
toNode = info.GetToNode()
)
err := serializeChanEdgePolicy(&b, info, toNode[:])
require.NoError(t, err)
newInfo, err := deserializeChanEdgePolicy(&b)
require.NoError(t, err)
return assert.Equal(t, info, newInfo)
}
tests := []struct {
name string
genValue func([]reflect.Value, *rand.Rand)
scenario any
}{
{
name: "ChannelEdgePolicy1",
scenario: func(m models.ChannelEdgePolicy1) bool {
return mainScenario(&m)
},
genValue: func(v []reflect.Value, r *rand.Rand) {
//nolint:lll
policy := &models.ChannelEdgePolicy1{
ChannelID: r.Uint64(),
LastUpdate: time.Unix(r.Int63(), 0),
MessageFlags: lnwire.ChanUpdateMsgFlags(r.Uint32()),
ChannelFlags: lnwire.ChanUpdateChanFlags(r.Uint32()),
TimeLockDelta: uint16(r.Uint32()),
MinHTLC: lnwire.MilliSatoshi(r.Uint64()),
FeeBaseMSat: lnwire.MilliSatoshi(r.Uint64()),
FeeProportionalMillionths: lnwire.MilliSatoshi(r.Uint64()),
ExtraOpaqueData: make([]byte, 0),
}
policy.SigBytes = make([]byte, r.Intn(80))
_, err := r.Read(policy.SigBytes)
require.NoError(t, err)
_, err = r.Read(policy.ToNode[:])
require.NoError(t, err)
numExtraBytes := r.Int31n(1000)
if numExtraBytes > 0 {
policy.ExtraOpaqueData = make(
[]byte, numExtraBytes,
)
_, err := r.Read(
policy.ExtraOpaqueData,
)
require.NoError(t, err)
}
// Sometimes add an MaxHTLC.
if r.Intn(2)%2 == 0 {
policy.MessageFlags |=
lnwire.ChanUpdateRequiredMaxHtlc
policy.MaxHTLC = lnwire.MilliSatoshi(
r.Uint64(),
)
} else {
policy.MessageFlags ^=
lnwire.ChanUpdateRequiredMaxHtlc
}
v[0] = reflect.ValueOf(*policy)
},
},
{
name: "ChannelEdgePolicy2",
scenario: func(m models.ChannelEdgePolicy2) bool {
return mainScenario(&m)
},
genValue: func(v []reflect.Value, r *rand.Rand) {
policy := &models.ChannelEdgePolicy2{
//nolint:lll
ChannelUpdate2: lnwire.ChannelUpdate2{
Signature: testSchnorrSig,
ExtraOpaqueData: make([]byte, 0),
},
ToNode: [33]byte{},
}
policy.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( //nolint:lll
uint64(r.Int63()),
)
policy.BlockHeight.Val = r.Uint32()
policy.HTLCMaximumMsat.Val = lnwire.MilliSatoshi( //nolint:lll
r.Uint64(),
)
policy.HTLCMinimumMsat.Val = lnwire.MilliSatoshi( //nolint:lll
r.Uint64(),
)
policy.CLTVExpiryDelta.Val = uint16(r.Int31())
policy.FeeBaseMsat.Val = r.Uint32()
policy.FeeProportionalMillionths.Val = r.Uint32() //nolint:lll
if r.Intn(2) == 0 {
policy.SecondPeer = tlv.SomeRecordT(
tlv.ZeroRecordT[tlv.TlvType8, lnwire.TrueBoolean](), //nolint:lll
)
}
// Sometimes set the incoming disabled flag.
if r.Int31()%2 == 0 {
policy.DisabledFlags.Val |=
lnwire.ChanUpdateDisableIncoming
}
// Sometimes set the outgoing disabled flag.
if r.Int31()%2 == 0 {
policy.DisabledFlags.Val |=
lnwire.ChanUpdateDisableOutgoing
}
_, err := r.Read(policy.ToNode[:])
require.NoError(t, err)
numExtraBytes := r.Int31n(1000)
if numExtraBytes > 0 {
policy.ExtraOpaqueData = make(
[]byte, numExtraBytes,
)
_, err := r.Read(
policy.ExtraOpaqueData,
)
require.NoError(t, err)
}
v[0] = reflect.ValueOf(*policy)
},
},
}
for _, test := range tests {
test := test
t.Run(test.name, func(t *testing.T) {
t.Parallel()
config := &quick.Config{
Values: test.genValue,
}
err := quick.Check(test.scenario, config)
require.NoError(t, err)
})
}
}

View File

@@ -311,7 +311,13 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) (
return err
}
channelMap[key] = edge
e, ok := edge.(*models.ChannelEdgePolicy1)
if !ok {
return fmt.Errorf("expected "+
"*models.ChannelEdgePolicy1, got: %T", edge)
}
channelMap[key] = e
return nil
})
@@ -2387,7 +2393,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
return err
}
chanInfo.Node1UpdateTimestamp = edge.LastUpdate
e, ok := edge.(*models.ChannelEdgePolicy1)
if !ok {
return fmt.Errorf("expected "+
"*models.ChannelEdgePolicy1, "+
"got %T", edge)
}
chanInfo.Node1UpdateTimestamp = e.LastUpdate
}
rawPolicy = edges.Get(node2Key)
@@ -2402,7 +2415,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
return err
}
chanInfo.Node2UpdateTimestamp = edge.LastUpdate
e, ok := edge.(*models.ChannelEdgePolicy1)
if !ok {
return fmt.Errorf("expected "+
"*models.ChannelEdgePolicy1, "+
"got %T", edge)
}
chanInfo.Node2UpdateTimestamp = e.LastUpdate
}
channelsPerBlock[cid.BlockHeight] = append(