diff --git a/channeldb/edge_policy.go b/channeldb/edge_policy.go index 49edd18e1..8d507a49d 100644 --- a/channeldb/edge_policy.go +++ b/channeldb/edge_policy.go @@ -1,6 +1,7 @@ package channeldb import ( + "bufio" "bytes" "encoding/binary" "errors" @@ -12,6 +13,30 @@ import ( "github.com/lightningnetwork/lnd/channeldb/models" "github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" +) + +const ( + EdgePolicy2MsgType = tlv.Type(0) + EdgePolicy2ToNode = tlv.Type(1) + + // chanEdgePolicyNewEncodingPrefix is a byte used in the channel edge + // policy encoding to signal that the new style encoding which is + // prefixed with a type byte is being used instead of the legacy + // encoding which would start with 0x02 due to the fact that the + // encoding would start with a DER encoded ecdsa signature. + chanEdgePolicyNewEncodingPrefix = 0xff +) + +// edgePolicyEncoding indicates how the bytes for a channel edge policy have +// been serialised. +type edgePolicyEncodingType uint8 + +const ( + // edgePolicy2EncodingType will be used as a prefix for edge policies + // advertised using the ChannelUpdate2 message. The type indicates how + // the bytes following should be deserialized. + edgePolicy2EncodingType edgePolicyEncodingType = 0 ) func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, @@ -63,7 +88,14 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, return err } - oldUpdateTime := uint64(oldEdgePolicy.LastUpdate.Unix()) + oldPol, ok := oldEdgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", + oldEdgePolicy) + } + + oldUpdateTime := uint64(oldPol.LastUpdate.Unix()) var oldIndexKey [8 + 8]byte byteOrder.PutUint64(oldIndexKey[:8], oldUpdateTime) @@ -169,7 +201,13 @@ func fetchChanEdgePolicy(edges kvdb.RBucket, chanID []byte, return nil, err } - return ep, nil + pol, ok := ep.(*models.ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected *models.ChannelEdgePolicy1, "+ + "got: %T", ep) + } + + return pol, nil } func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, @@ -201,8 +239,56 @@ func fetchChanEdgePolicies(edgeIndex kvdb.RBucket, edges kvdb.RBucket, return edge1, edge2, nil } -func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, - to []byte) error { +func serializeChanEdgePolicy(w io.Writer, + edgePolicy models.ChannelEdgePolicy, toNode []byte) error { + + var ( + withTypeByte bool + typeByte edgePolicyEncodingType + serialize func(w io.Writer) error + ) + + switch policy := edgePolicy.(type) { + case *models.ChannelEdgePolicy1: + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy1(w, policy) + } + case *models.ChannelEdgePolicy2: + withTypeByte = true + typeByte = edgePolicy2EncodingType + + serialize = func(w io.Writer) error { + copy(policy.ToNode[:], toNode) + + return serializeChanEdgePolicy2(w, policy) + } + default: + return fmt.Errorf("unhandled implementation of "+ + "ChannelEdgePolicy: %T", edgePolicy) + } + + if withTypeByte { + // First, write the identifying encoding byte to signal that + // this is not using the legacy encoding. + _, err := w.Write([]byte{chanEdgePolicyNewEncodingPrefix}) + if err != nil { + return err + } + + // Now, write the encoding type. + _, err = w.Write([]byte{byte(typeByte)}) + if err != nil { + return err + } + } + + return serialize(w) +} + +func serializeChanEdgePolicy1(w io.Writer, + edge *models.ChannelEdgePolicy1) error { err := wire.WriteVarBytes(w, 0, edge.SigBytes) if err != nil { @@ -241,7 +327,7 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, return err } - if _, err := w.Write(to); err != nil { + if _, err := w.Write(edge.ToNode[:]); err != nil { return err } @@ -271,7 +357,36 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy1, return nil } -func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) { +func serializeChanEdgePolicy2(w io.Writer, + edge *models.ChannelEdgePolicy2) error { + + if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { + return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) + } + + var b bytes.Buffer + if err := edge.Encode(&b, 0); err != nil { + return err + } + + msg := b.Bytes() + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msg), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &edge.ToNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return err + } + + return stream.Encode(w) +} + +func deserializeChanEdgePolicy(r io.Reader) (models.ChannelEdgePolicy, + error) { + // Deserialize the policy. Note that in case an optional field is not // found, both an error and a populated policy object are returned. edge, deserializeErr := deserializeChanEdgePolicyRaw(r) @@ -284,7 +399,45 @@ func deserializeChanEdgePolicy(r io.Reader) (*models.ChannelEdgePolicy1, error) return edge, deserializeErr } -func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1, +func deserializeChanEdgePolicyRaw(reader io.Reader) (models.ChannelEdgePolicy, + error) { + + // Wrap the io.Reader in a bufio.Reader so that we can peak the first + // byte of the stream without actually consuming from the stream. + r := bufio.NewReader(reader) + + firstByte, err := r.Peek(1) + if err != nil { + return nil, err + } + + if firstByte[0] != chanEdgePolicyNewEncodingPrefix { + return deserializeChanEdgePolicy1Raw(r) + } + + // Pop the encoding type byte. + var scratch [1]byte + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + // Now, read the encoding type byte. + if _, err = r.Read(scratch[:]); err != nil { + return nil, err + } + + encoding := edgePolicyEncodingType(scratch[0]) + switch encoding { + case edgePolicy2EncodingType: + return deserializeChanEdgePolicy2Raw(r) + + default: + return nil, fmt.Errorf("unknown edge policy encoding type: %d", + encoding) + } +} + +func deserializeChanEdgePolicy1Raw(r io.Reader) (*models.ChannelEdgePolicy1, error) { edge := &models.ChannelEdgePolicy1{} @@ -370,3 +523,41 @@ func deserializeChanEdgePolicyRaw(r io.Reader) (*models.ChannelEdgePolicy1, return edge, nil } + +func deserializeChanEdgePolicy2Raw(r io.Reader) (*models.ChannelEdgePolicy2, + error) { + + var ( + msgBytes []byte + toNode [33]byte + ) + + records := []tlv.Record{ + tlv.MakePrimitiveRecord(EdgePolicy2MsgType, &msgBytes), + tlv.MakePrimitiveRecord(EdgePolicy2ToNode, &toNode), + } + + stream, err := tlv.NewStream(records...) + if err != nil { + return nil, err + } + + err = stream.Decode(r) + if err != nil { + return nil, err + } + + var ( + chanUpdate lnwire.ChannelUpdate2 + reader = bytes.NewReader(msgBytes) + ) + err = chanUpdate.Decode(reader, 0) + if err != nil { + return nil, err + } + + return &models.ChannelEdgePolicy2{ + ChannelUpdate2: chanUpdate, + ToNode: toNode, + }, nil +} diff --git a/channeldb/edge_policy_test.go b/channeldb/edge_policy_test.go new file mode 100644 index 000000000..16f03c554 --- /dev/null +++ b/channeldb/edge_policy_test.go @@ -0,0 +1,174 @@ +package channeldb + +import ( + "bytes" + "math/rand" + "reflect" + "testing" + "testing/quick" + "time" + + "github.com/lightningnetwork/lnd/channeldb/models" + "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/tlv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestEdgePolicySerialisation tests the serialisation and deserialization logic +// for models.ChannelEdgePolicy. +func TestEdgePolicySerialisation(t *testing.T) { + t.Parallel() + + mainScenario := func(info models.ChannelEdgePolicy) bool { + var ( + b bytes.Buffer + toNode = info.GetToNode() + ) + + err := serializeChanEdgePolicy(&b, info, toNode[:]) + require.NoError(t, err) + + newInfo, err := deserializeChanEdgePolicy(&b) + require.NoError(t, err) + + return assert.Equal(t, info, newInfo) + } + + tests := []struct { + name string + genValue func([]reflect.Value, *rand.Rand) + scenario any + }{ + { + name: "ChannelEdgePolicy1", + scenario: func(m models.ChannelEdgePolicy1) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + //nolint:lll + policy := &models.ChannelEdgePolicy1{ + ChannelID: r.Uint64(), + LastUpdate: time.Unix(r.Int63(), 0), + MessageFlags: lnwire.ChanUpdateMsgFlags(r.Uint32()), + ChannelFlags: lnwire.ChanUpdateChanFlags(r.Uint32()), + TimeLockDelta: uint16(r.Uint32()), + MinHTLC: lnwire.MilliSatoshi(r.Uint64()), + FeeBaseMSat: lnwire.MilliSatoshi(r.Uint64()), + FeeProportionalMillionths: lnwire.MilliSatoshi(r.Uint64()), + ExtraOpaqueData: make([]byte, 0), + } + + policy.SigBytes = make([]byte, r.Intn(80)) + _, err := r.Read(policy.SigBytes) + require.NoError(t, err) + + _, err = r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + // Sometimes add an MaxHTLC. + if r.Intn(2)%2 == 0 { + policy.MessageFlags |= + lnwire.ChanUpdateRequiredMaxHtlc + policy.MaxHTLC = lnwire.MilliSatoshi( + r.Uint64(), + ) + } else { + policy.MessageFlags ^= + lnwire.ChanUpdateRequiredMaxHtlc + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + { + name: "ChannelEdgePolicy2", + scenario: func(m models.ChannelEdgePolicy2) bool { + return mainScenario(&m) + }, + genValue: func(v []reflect.Value, r *rand.Rand) { + policy := &models.ChannelEdgePolicy2{ + //nolint:lll + ChannelUpdate2: lnwire.ChannelUpdate2{ + Signature: testSchnorrSig, + ExtraOpaqueData: make([]byte, 0), + }, + ToNode: [33]byte{}, + } + + policy.ShortChannelID.Val = lnwire.NewShortChanIDFromInt( //nolint:lll + uint64(r.Int63()), + ) + policy.BlockHeight.Val = r.Uint32() + policy.HTLCMaximumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.HTLCMinimumMsat.Val = lnwire.MilliSatoshi( //nolint:lll + r.Uint64(), + ) + policy.CLTVExpiryDelta.Val = uint16(r.Int31()) + policy.FeeBaseMsat.Val = r.Uint32() + policy.FeeProportionalMillionths.Val = r.Uint32() //nolint:lll + + if r.Intn(2) == 0 { + policy.SecondPeer = tlv.SomeRecordT( + tlv.ZeroRecordT[tlv.TlvType8, lnwire.TrueBoolean](), //nolint:lll + ) + } + + // Sometimes set the incoming disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableIncoming + } + + // Sometimes set the outgoing disabled flag. + if r.Int31()%2 == 0 { + policy.DisabledFlags.Val |= + lnwire.ChanUpdateDisableOutgoing + } + + _, err := r.Read(policy.ToNode[:]) + require.NoError(t, err) + + numExtraBytes := r.Int31n(1000) + if numExtraBytes > 0 { + policy.ExtraOpaqueData = make( + []byte, numExtraBytes, + ) + _, err := r.Read( + policy.ExtraOpaqueData, + ) + require.NoError(t, err) + } + + v[0] = reflect.ValueOf(*policy) + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + config := &quick.Config{ + Values: test.genValue, + } + + err := quick.Check(test.scenario, config) + require.NoError(t, err) + }) + } +} diff --git a/channeldb/graph.go b/channeldb/graph.go index 89e777628..f855fa2fc 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -311,7 +311,13 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } - channelMap[key] = edge + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", edge) + } + + channelMap[key] = e return nil }) @@ -2387,7 +2393,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node1UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node1UpdateTimestamp = e.LastUpdate } rawPolicy = edges.Get(node2Key) @@ -2402,7 +2415,14 @@ func (c *ChannelGraph) FilterChannelRange(startHeight, return err } - chanInfo.Node2UpdateTimestamp = edge.LastUpdate + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, "+ + "got %T", edge) + } + + chanInfo.Node2UpdateTimestamp = e.LastUpdate } channelsPerBlock[cid.BlockHeight] = append(