mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-20 13:04:28 +02:00
channeldb: add encoding for ChannelEdgePolicy2
Similarly to the previous commit, here we add the encoding for the new ChannelEdgePolicy2. This is done in the same was as for ChannelEdgeInfo2: - a 0xff prefix - followed by a type-byte - followed by the TLV encoding of the ChannelEdgePolicy2.
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
@@ -12,6 +13,30 @@ import (
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/kvdb"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
)
|
||||
|
||||
const (
|
||||
EdgePolicy2MsgType = tlv.Type(0)
|
||||
EdgePolicy2ToNode = tlv.Type(1)
|
||||
|
||||
// chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge
|
||||
// policy encoding to signal that the new style encoding which is
|
||||
// prefixed with a type byte is being used instead of the legacy
|
||||
// encoding which would start with 0x02 due to the fact that the
|
||||
// encoding would start with a DER encoded ecdsa signature.
|
||||
chanEdgePolicyNewEncodingPrefix = 0xff
|
||||
)
|
||||
|
||||
// edgePolicyEncoding indicates how the bytes for a channel edge policy have
|
||||
// been serialised.
|
||||
type edgePolicyEncodingType uint8
|
||||
|
||||
const (
|
||||
// edgePolicy2EncodingType will be used as a prefix for edge policies
|
||||
// advertised using the ChannelUpdate2 message. The type indicates how
|
||||
// the bytes following should be deserialized.
|
||||
edgePolicy2EncodingType edgePolicyEncodingType = 0
|
||||
)
|
||||
|
||||
func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1,
|
||||
@@ -63,7 +88,14 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1,
|
||||
return err
|
||||
}
|
||||
|
||||
oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix())
|
||||
oldPol, ok := oldEdgePolicy.(*models.ChannelEdgePolicy1)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected "+
|
||||
"*models.ChannelEdgePolicy1, got: %T",
|
||||
oldEdgePolicy)
|
||||
}
|
||||
|
||||
oldUpdateTime := uint64(oldPol.LastUpdate.Unix())
|
||||
|
||||
var oldIndexKey [8 + 8]byte
|
||||
byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime)
|
||||
@@ -169,7 +201,13 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte,
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ep, nil
|
||||
pol, ok := ep.(*models.ChannelEdgePolicy1)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("expected *models.ChannelEdgePolicy1, "+
|
||||
"got: %T", ep)
|
||||
}
|
||||
|
||||
return pol, nil
|
||||
}
|
||||
|
||||
func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
|
||||
@@ -201,8 +239,56 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket,
|
||||
return edge1, edge2, nil
|
||||
}
|
||||
|
||||
func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
|
||||
to []byte) error {
|
||||
func serializeChanEdgePolicy(w io.Writer,
|
||||
edgePolicy models.ChannelEdgePolicy, toNode []byte) error {
|
||||
|
||||
var (
|
||||
withTypeByte bool
|
||||
typeByte edgePolicyEncodingType
|
||||
serialize func(w io.Writer) error
|
||||
)
|
||||
|
||||
switch policy := edgePolicy.(type) {
|
||||
case *models.ChannelEdgePolicy1:
|
||||
serialize = func(w io.Writer) error {
|
||||
copy(policy.ToNode[:], toNode)
|
||||
|
||||
return serializeChanEdgePolicy1(w, policy)
|
||||
}
|
||||
case *models.ChannelEdgePolicy2:
|
||||
withTypeByte = true
|
||||
typeByte = edgePolicy2EncodingType
|
||||
|
||||
serialize = func(w io.Writer) error {
|
||||
copy(policy.ToNode[:], toNode)
|
||||
|
||||
return serializeChanEdgePolicy2(w, policy)
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("unhandled implementation of "+
|
||||
"ChannelEdgePolicy: %T", edgePolicy)
|
||||
}
|
||||
|
||||
if withTypeByte {
|
||||
// First, write the identifying encoding byte to signal that
|
||||
// this is not using the legacy encoding.
|
||||
_, err := w.Write([]byte{chanEdgePolicyNewEncodingPrefix})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Now, write the encoding type.
|
||||
_, err = w.Write([]byte{byte(typeByte)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return serialize(w)
|
||||
}
|
||||
|
||||
func serializeChanEdgePolicy1(w io.Writer,
|
||||
edge *models.ChannelEdgePolicy1) error {
|
||||
|
||||
err := wire.WriteVarBytes(w, 0, edge.SigBytes)
|
||||
if err != nil {
|
||||
@@ -241,7 +327,7 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := w.Write(to); err != nil {
|
||||
if _, err := w.Write(edge.ToNode[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -271,7 +357,36 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1,
|
||||
return nil
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) {
|
||||
func serializeChanEdgePolicy2(w io.Writer,
|
||||
edge *models.ChannelEdgePolicy2) error {
|
||||
|
||||
if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
|
||||
return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData))
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := edge.Encode(&b, 0); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
msg := b.Bytes()
|
||||
|
||||
records := []tlv.Record{
|
||||
tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg),
|
||||
tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &edge.ToNode),
|
||||
}
|
||||
|
||||
stream, err := tlv.NewStream(records...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return stream.Encode(w)
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicy(r io.Reader) (models.ChannelEdgePolicy,
|
||||
error) {
|
||||
|
||||
// Deserialize the policy. Note that in case an optional field is not
|
||||
// found, both an error and a populated policy object are returned.
|
||||
edge, deserializeErr := deserializeChanEdgePolicyRaw(r)
|
||||
@@ -284,7 +399,45 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error)
|
||||
return edge, deserializeErr
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1,
|
||||
func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy,
|
||||
error) {
|
||||
|
||||
// Wrap the io.Reader in a bufio.Reader so that we can peak the first
|
||||
// byte of the stream without actually consuming from the stream.
|
||||
r := bufio.NewReader(reader)
|
||||
|
||||
firstByte, err := r.Peek(1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if firstByte[0] != chanEdgePolicyNewEncodingPrefix {
|
||||
return deserializeChanEdgePolicy1Raw(r)
|
||||
}
|
||||
|
||||
// Pop the encoding type byte.
|
||||
var scratch [1]byte
|
||||
if _, err = r.Read(scratch[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now, read the encoding type byte.
|
||||
if _, err = r.Read(scratch[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encoding := edgePolicyEncodingType(scratch[0])
|
||||
switch encoding {
|
||||
case edgePolicy2EncodingType:
|
||||
return deserializeChanEdgePolicy2Raw(r)
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown edge policy encoding type: %d",
|
||||
encoding)
|
||||
}
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicy1Raw(r io.Reader) (*models.ChannelEdgePolicy1,
|
||||
error) {
|
||||
|
||||
edge := &models.ChannelEdgePolicy1{}
|
||||
@@ -370,3 +523,41 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1,
|
||||
|
||||
return edge, nil
|
||||
}
|
||||
|
||||
func deserializeChanEdgePolicy2Raw(r io.Reader) (*models.ChannelEdgePolicy2,
|
||||
error) {
|
||||
|
||||
var (
|
||||
msgBytes []byte
|
||||
toNode [33]byte
|
||||
)
|
||||
|
||||
records := []tlv.Record{
|
||||
tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes),
|
||||
tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode),
|
||||
}
|
||||
|
||||
stream, err := tlv.NewStream(records...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = stream.Decode(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var (
|
||||
chanUpdate lnwire.ChannelUpdate2
|
||||
reader = bytes.NewReader(msgBytes)
|
||||
)
|
||||
err = chanUpdate.Decode(reader, 0)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &models.ChannelEdgePolicy2{
|
||||
ChannelUpdate2: chanUpdate,
|
||||
ToNode: toNode,
|
||||
}, nil
|
||||
}
|
||||
|
174
channeldb/edge_policy_test.go
Normal file
174
channeldb/edge_policy_test.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/lightningnetwork/lnd/channeldb/models"
|
||||
"github.com/lightningnetwork/lnd/lnwire"
|
||||
"github.com/lightningnetwork/lnd/tlv"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
// TestEdgePolicySerialisation tests the serialisation and deserialization logic
|
||||
// for models.ChannelEdgePolicy.
|
||||
func TestEdgePolicySerialisation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mainScenario := func(info models.ChannelEdgePolicy) bool {
|
||||
var (
|
||||
b bytes.Buffer
|
||||
toNode = info.GetToNode()
|
||||
)
|
||||
|
||||
err := serializeChanEdgePolicy(&b, info, toNode[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
newInfo, err := deserializeChanEdgePolicy(&b)
|
||||
require.NoError(t, err)
|
||||
|
||||
return assert.Equal(t, info, newInfo)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
genValue func([]reflect.Value, *rand.Rand)
|
||||
scenario any
|
||||
}{
|
||||
{
|
||||
name: "ChannelEdgePolicy1",
|
||||
scenario: func(m models.ChannelEdgePolicy1) bool {
|
||||
return mainScenario(&m)
|
||||
},
|
||||
genValue: func(v []reflect.Value, r *rand.Rand) {
|
||||
//nolint:lll
|
||||
policy := &models.ChannelEdgePolicy1{
|
||||
ChannelID: r.Uint64(),
|
||||
LastUpdate: time.Unix(r.Int63(), 0),
|
||||
MessageFlags: lnwire.ChanUpdateMsgFlags(r.Uint32()),
|
||||
ChannelFlags: lnwire.ChanUpdateChanFlags(r.Uint32()),
|
||||
TimeLockDelta: uint16(r.Uint32()),
|
||||
MinHTLC: lnwire.MilliSatoshi(r.Uint64()),
|
||||
FeeBaseMSat: lnwire.MilliSatoshi(r.Uint64()),
|
||||
FeeProportionalMillionths: lnwire.MilliSatoshi(r.Uint64()),
|
||||
ExtraOpaqueData: make([]byte, 0),
|
||||
}
|
||||
|
||||
policy.SigBytes = make([]byte, r.Intn(80))
|
||||
_, err := r.Read(policy.SigBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = r.Read(policy.ToNode[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
numExtraBytes := r.Int31n(1000)
|
||||
if numExtraBytes > 0 {
|
||||
policy.ExtraOpaqueData = make(
|
||||
[]byte, numExtraBytes,
|
||||
)
|
||||
_, err := r.Read(
|
||||
policy.ExtraOpaqueData,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Sometimes add an MaxHTLC.
|
||||
if r.Intn(2)%2 == 0 {
|
||||
policy.MessageFlags |=
|
||||
lnwire.ChanUpdateRequiredMaxHtlc
|
||||
policy.MaxHTLC = lnwire.MilliSatoshi(
|
||||
r.Uint64(),
|
||||
)
|
||||
} else {
|
||||
policy.MessageFlags ^=
|
||||
lnwire.ChanUpdateRequiredMaxHtlc
|
||||
}
|
||||
|
||||
v[0] = reflect.ValueOf(*policy)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ChannelEdgePolicy2",
|
||||
scenario: func(m models.ChannelEdgePolicy2) bool {
|
||||
return mainScenario(&m)
|
||||
},
|
||||
genValue: func(v []reflect.Value, r *rand.Rand) {
|
||||
policy := &models.ChannelEdgePolicy2{
|
||||
//nolint:lll
|
||||
ChannelUpdate2: lnwire.ChannelUpdate2{
|
||||
Signature: testSchnorrSig,
|
||||
ExtraOpaqueData: make([]byte, 0),
|
||||
},
|
||||
ToNode: [33]byte{},
|
||||
}
|
||||
|
||||
policy.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( //nolint:lll
|
||||
uint64(r.Int63()),
|
||||
)
|
||||
policy.BlockHeight.Val = r.Uint32()
|
||||
policy.HTLCMaximumMsat.Val = lnwire.MilliSatoshi( //nolint:lll
|
||||
r.Uint64(),
|
||||
)
|
||||
policy.HTLCMinimumMsat.Val = lnwire.MilliSatoshi( //nolint:lll
|
||||
r.Uint64(),
|
||||
)
|
||||
policy.CLTVExpiryDelta.Val = uint16(r.Int31())
|
||||
policy.FeeBaseMsat.Val = r.Uint32()
|
||||
policy.FeeProportionalMillionths.Val = r.Uint32() //nolint:lll
|
||||
|
||||
if r.Intn(2) == 0 {
|
||||
policy.SecondPeer = tlv.SomeRecordT(
|
||||
tlv.ZeroRecordT[tlv.TlvType8, lnwire.TrueBoolean](), //nolint:lll
|
||||
)
|
||||
}
|
||||
|
||||
// Sometimes set the incoming disabled flag.
|
||||
if r.Int31()%2 == 0 {
|
||||
policy.DisabledFlags.Val |=
|
||||
lnwire.ChanUpdateDisableIncoming
|
||||
}
|
||||
|
||||
// Sometimes set the outgoing disabled flag.
|
||||
if r.Int31()%2 == 0 {
|
||||
policy.DisabledFlags.Val |=
|
||||
lnwire.ChanUpdateDisableOutgoing
|
||||
}
|
||||
|
||||
_, err := r.Read(policy.ToNode[:])
|
||||
require.NoError(t, err)
|
||||
|
||||
numExtraBytes := r.Int31n(1000)
|
||||
if numExtraBytes > 0 {
|
||||
policy.ExtraOpaqueData = make(
|
||||
[]byte, numExtraBytes,
|
||||
)
|
||||
_, err := r.Read(
|
||||
policy.ExtraOpaqueData,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
v[0] = reflect.ValueOf(*policy)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
test := test
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
config := &quick.Config{
|
||||
Values: test.genValue,
|
||||
}
|
||||
|
||||
err := quick.Check(test.scenario, config)
|
||||
require.NoError(t, err)
|
||||
})
|
||||
}
|
||||
}
|
@@ -311,7 +311,13 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) (
|
||||
return err
|
||||
}
|
||||
|
||||
channelMap[key] = edge
|
||||
e, ok := edge.(*models.ChannelEdgePolicy1)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected "+
|
||||
"*models.ChannelEdgePolicy1, got: %T", edge)
|
||||
}
|
||||
|
||||
channelMap[key] = e
|
||||
|
||||
return nil
|
||||
})
|
||||
@@ -2387,7 +2393,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
|
||||
return err
|
||||
}
|
||||
|
||||
chanInfo.Node1UpdateTimestamp = edge.LastUpdate
|
||||
e, ok := edge.(*models.ChannelEdgePolicy1)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected "+
|
||||
"*models.ChannelEdgePolicy1, "+
|
||||
"got %T", edge)
|
||||
}
|
||||
|
||||
chanInfo.Node1UpdateTimestamp = e.LastUpdate
|
||||
}
|
||||
|
||||
rawPolicy = edges.Get(node2Key)
|
||||
@@ -2402,7 +2415,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight,
|
||||
return err
|
||||
}
|
||||
|
||||
chanInfo.Node2UpdateTimestamp = edge.LastUpdate
|
||||
e, ok := edge.(*models.ChannelEdgePolicy1)
|
||||
if !ok {
|
||||
return fmt.Errorf("expected "+
|
||||
"*models.ChannelEdgePolicy1, "+
|
||||
"got %T", edge)
|
||||
}
|
||||
|
||||
chanInfo.Node2UpdateTimestamp = e.LastUpdate
|
||||
}
|
||||
|
||||
channelsPerBlock[cid.BlockHeight] = append(
|
||||
|
Reference in New Issue
Block a user