multi: do not use tlv.Record outside wire format handling

This commit prepares for more manipulation of custom records. A list of
tlv.Record types is more difficult to use than the more basic
map[uint64][]byte.

Furthermore fields and variables are renamed to make them more
consistent.
This commit is contained in:
Joost Jager 2019-12-11 10:52:27 +01:00
parent 8b5bb0ac63
commit d02de70d20
No known key found for this signature in database
GPG Key ID: A61B9D4C393C59C7
9 changed files with 46 additions and 127 deletions

View File

@ -606,7 +606,9 @@ func serializeHop(w io.Writer, h *route.Hop) error {
if h.MPP != nil { if h.MPP != nil {
records = append(records, h.MPP.Record()) records = append(records, h.MPP.Record())
} }
records = append(records, h.TLVRecords...)
tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
// Otherwise, we'll transform our slice of records into a map of the // Otherwise, we'll transform our slice of records into a map of the
// raw bytes, then serialize them in-line with a length (number of // raw bytes, then serialize them in-line with a length (number of
@ -710,7 +712,7 @@ func deserializeHop(r io.Reader) (*route.Hop, error) {
h.MPP = mpp h.MPP = mpp
} }
h.TLVRecords = tlv.MapToRecords(tlvMap) h.CustomRecords = tlvMap
return h, nil return h, nil
} }

View File

@ -2,7 +2,6 @@ package channeldb
import ( import (
"bytes" "bytes"
"errors"
"fmt" "fmt"
"math/rand" "math/rand"
"reflect" "reflect"
@ -28,9 +27,9 @@ var (
ChannelID: 12345, ChannelID: 12345,
OutgoingTimeLock: 111, OutgoingTimeLock: 111,
AmtToForward: 555, AmtToForward: 555,
TLVRecords: []tlv.Record{ CustomRecords: record.CustomSet{
tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), 1: []byte{},
tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), 2: []byte{},
}, },
MPP: record.NewMPP(32, [32]byte{0x42}), MPP: record.NewMPP(32, [32]byte{0x42}),
} }
@ -144,25 +143,7 @@ func TestSentPaymentSerialization(t *testing.T) {
// assertRouteEquals compares to routes for equality and returns an error if // assertRouteEquals compares to routes for equality and returns an error if
// they are not equal. // they are not equal.
func assertRouteEqual(a, b *route.Route) error { func assertRouteEqual(a, b *route.Route) error {
err := assertRouteHopRecordsEqual(a, b) if !reflect.DeepEqual(a, b) {
if err != nil {
return err
}
// TLV records have already been compared and need to be cleared to
// properly compare the remaining fields using DeepEqual.
copyRouteNoHops := func(r *route.Route) *route.Route {
copy := *r
copy.Hops = make([]*route.Hop, len(r.Hops))
for i, hop := range r.Hops {
hopCopy := *hop
hopCopy.TLVRecords = nil
copy.Hops[i] = &hopCopy
}
return &copy
}
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
spew.Sdump(a), spew.Sdump(b)) spew.Sdump(a), spew.Sdump(b))
} }
@ -170,57 +151,6 @@ func assertRouteEqual(a, b *route.Route) error {
return nil return nil
} }
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
if len(r1.Hops) != len(r2.Hops) {
return errors.New("route hop count mismatch")
}
for i := 0; i < len(r1.Hops); i++ {
records1 := r1.Hops[i].TLVRecords
records2 := r2.Hops[i].TLVRecords
if len(records1) != len(records2) {
return fmt.Errorf("route record count for hop %v "+
"mismatch", i)
}
for j := 0; j < len(records1); j++ {
expectedRecord := records1[j]
newRecord := records2[j]
err := assertHopRecordsEqual(expectedRecord, newRecord)
if err != nil {
return fmt.Errorf("route record mismatch: %v", err)
}
}
}
return nil
}
func assertHopRecordsEqual(h1, h2 tlv.Record) error {
if h1.Type() != h2.Type() {
return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(),
h2.Type())
}
var b bytes.Buffer
if err := h2.Encode(&b); err != nil {
return fmt.Errorf("unable to encode record: %v", err)
}
if !bytes.Equal(b.Bytes(), tlvBytes) {
return fmt.Errorf("wrong raw record: expected %x, got %x",
tlvBytes, b.Bytes())
}
if h1.Size() != h2.Size() {
return fmt.Errorf("wrong size: expected %v, "+
"got %v", h1.Size(), h2.Size())
}
return nil
}
func TestRouteSerialization(t *testing.T) { func TestRouteSerialization(t *testing.T) {
t.Parallel() t.Parallel()

View File

@ -19,7 +19,6 @@ import (
"github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
) )
@ -45,7 +44,7 @@ type RouterBackend struct {
// routes. // routes.
FindRoute func(source, target route.Vertex, FindRoute func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
destTlvRecords []tlv.Record, destCustomRecords record.CustomSet,
finalExpiry ...uint16) (*route.Route, error) finalExpiry ...uint16) (*route.Route, error)
MissionControl MissionControl MissionControl MissionControl
@ -232,7 +231,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// If we have any TLV records destined for the final hop, then we'll // If we have any TLV records destined for the final hop, then we'll
// attempt to decode them now into a form that the router can more // attempt to decode them now into a form that the router can more
// easily manipulate. // easily manipulate.
destTlvRecords, err := UnmarshallCustomRecords(in.DestCustomRecords) err = ValidateCustomRecords(in.DestCustomRecords)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -242,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context,
// the route. // the route.
route, err := r.FindRoute( route, err := r.FindRoute(
sourcePubKey, targetPubKey, amt, restrictions, sourcePubKey, targetPubKey, amt, restrictions,
destTlvRecords, finalCLTVDelta, in.DestCustomRecords, finalCLTVDelta,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -346,11 +345,6 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
} }
} }
tlvMap, err := tlv.RecordsToMap(hop.TLVRecords)
if err != nil {
return nil, err
}
resp.Hops[i] = &lnrpc.Hop{ resp.Hops[i] = &lnrpc.Hop{
ChanId: hop.ChannelID, ChanId: hop.ChannelID,
ChanCapacity: int64(chanCapacity), ChanCapacity: int64(chanCapacity),
@ -362,7 +356,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
PubKey: hex.EncodeToString( PubKey: hex.EncodeToString(
hop.PubKeyBytes[:], hop.PubKeyBytes[:],
), ),
CustomRecords: tlvMap, CustomRecords: hop.CustomRecords,
TlvPayload: !hop.LegacyPayload, TlvPayload: !hop.LegacyPayload,
MppRecord: mpp, MppRecord: mpp,
} }
@ -372,24 +366,16 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error)
return resp, nil return resp, nil
} }
// UnmarshallCustomRecords unmarshall rpc custom records to tlv records. // ValidateCustomRecords checks that all custom records are in the custom type
func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record, // range.
error) { func ValidateCustomRecords(rpcRecords map[uint64][]byte) error {
for key := range rpcRecords {
if len(rpcRecords) == 0 { if key < record.CustomTypeStart {
return nil, nil return fmt.Errorf("no custom records with types "+
}
tlvRecords := tlv.MapToRecords(rpcRecords)
// tlvRecords is sorted, so we only need to check that the first
// element is within the custom range.
if uint64(tlvRecords[0].Type()) < record.CustomTypeStart {
return nil, fmt.Errorf("no custom records with types "+
"below %v allowed", record.CustomTypeStart) "below %v allowed", record.CustomTypeStart)
} }
}
return tlvRecords, nil return nil
} }
// UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has // UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has
@ -397,7 +383,7 @@ func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record,
func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop, func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop,
error) { error) {
tlvRecords, err := UnmarshallCustomRecords(rpcHop.CustomRecords) err := ValidateCustomRecords(rpcHop.CustomRecords)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -412,7 +398,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop
AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat), AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat),
PubKeyBytes: pubkey, PubKeyBytes: pubkey,
ChannelID: rpcHop.ChanId, ChannelID: rpcHop.ChanId,
TLVRecords: tlvRecords, CustomRecords: rpcHop.CustomRecords,
LegacyPayload: !rpcHop.TlvPayload, LegacyPayload: !rpcHop.TlvPayload,
MPP: mpp, MPP: mpp,
}, nil }, nil
@ -540,12 +526,11 @@ func (r *RouterBackend) extractIntentFromSendRequest(
return nil, errors.New("timeout_seconds must be specified") return nil, errors.New("timeout_seconds must be specified")
} }
payIntent.FinalDestRecords, err = UnmarshallCustomRecords( err = ValidateCustomRecords(rpcPayReq.DestCustomRecords)
rpcPayReq.DestCustomRecords,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords
payIntent.PayAttemptTimeout = time.Second * payIntent.PayAttemptTimeout = time.Second *
time.Duration(rpcPayReq.TimeoutSeconds) time.Duration(rpcPayReq.TimeoutSeconds)

View File

@ -8,9 +8,9 @@ import (
"github.com/btcsuite/btcutil" "github.com/btcsuite/btcutil"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/lnrpc" "github.com/lightningnetwork/lnd/lnrpc"
) )
@ -92,7 +92,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool) {
findRoute := func(source, target route.Vertex, findRoute := func(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams,
_ []tlv.Record, _ record.CustomSet,
finalExpiry ...uint16) (*route.Route, error) { finalExpiry ...uint16) (*route.Route, error) {
if int64(amt) != amtSat*1000 { if int64(amt) != amtSat*1000 {

View File

@ -11,8 +11,8 @@ import (
"github.com/coreos/bbolt" "github.com/coreos/bbolt"
"github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/channeldb"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/tlv"
) )
const ( const (
@ -100,7 +100,7 @@ type edgePolicyWithSource struct {
func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex, func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex,
pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32,
finalCLTVDelta uint16, finalCLTVDelta uint16,
finalDestRecords []tlv.Record) (*route.Route, error) { destCustomRecords record.CustomSet) (*route.Route, error) {
var ( var (
hops []*route.Hop hops []*route.Hop
@ -198,8 +198,8 @@ func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex,
// If this is the last hop, then we'll populate any TLV records // If this is the last hop, then we'll populate any TLV records
// destined for it. // destined for it.
if i == len(pathEdges)-1 && len(finalDestRecords) != 0 { if i == len(pathEdges)-1 && len(destCustomRecords) != 0 {
currentHop.TLVRecords = finalDestRecords currentHop.CustomRecords = destCustomRecords
} }
hops = append([]*route.Hop{currentHop}, hops...) hops = append([]*route.Hop{currentHop}, hops...)

View File

@ -129,7 +129,7 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment,
sourceVertex := route.Vertex(ss.SelfNode.PubKeyBytes) sourceVertex := route.Vertex(ss.SelfNode.PubKeyBytes)
route, err := newRoute( route, err := newRoute(
payment.Amount, sourceVertex, path, height, finalCltvDelta, payment.Amount, sourceVertex, path, height, finalCltvDelta,
payment.FinalDestRecords, payment.DestCustomRecords,
) )
if err != nil { if err != nil {
// TODO(roasbeef): return which edge/vertex didn't work // TODO(roasbeef): return which edge/vertex didn't work

View File

@ -107,9 +107,9 @@ type Hop struct {
// only be set for the final hop. // only be set for the final hop.
MPP *record.MPP MPP *record.MPP
// TLVRecords if non-nil are a set of additional TLV records that // CustomRecords if non-nil are a set of additional TLV records that
// should be included in the forwarding instructions for this node. // should be included in the forwarding instructions for this node.
TLVRecords []tlv.Record CustomRecords record.CustomSet
// LegacyPayload if true, then this signals that this node doesn't // LegacyPayload if true, then this signals that this node doesn't
// understand the new TLV payload, so we must instead use the legacy // understand the new TLV payload, so we must instead use the legacy
@ -165,7 +165,8 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error {
} }
// Append any custom types destined for this hop. // Append any custom types destined for this hop.
records = append(records, h.TLVRecords...) tlvRecords := tlv.MapToRecords(h.CustomRecords)
records = append(records, tlvRecords...)
// To ensure we produce a canonical stream, we'll sort the records // To ensure we produce a canonical stream, we'll sort the records
// before encoding them as a stream in the hop payload. // before encoding them as a stream in the hop payload.

View File

@ -24,10 +24,10 @@ import (
"github.com/lightningnetwork/lnd/lnwallet/chanvalidate" "github.com/lightningnetwork/lnd/lnwallet/chanvalidate"
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/multimutex" "github.com/lightningnetwork/lnd/multimutex"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing/chainview" "github.com/lightningnetwork/lnd/routing/chainview"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/ticker" "github.com/lightningnetwork/lnd/ticker"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
) )
@ -1401,7 +1401,7 @@ type routingMsg struct {
// factoring in channel capacities and cumulative fees along the route. // factoring in channel capacities and cumulative fees along the route.
func (r *ChannelRouter) FindRoute(source, target route.Vertex, func (r *ChannelRouter) FindRoute(source, target route.Vertex,
amt lnwire.MilliSatoshi, restrictions *RestrictParams, amt lnwire.MilliSatoshi, restrictions *RestrictParams,
destTlvRecords []tlv.Record, destCustomRecords record.CustomSet,
finalExpiry ...uint16) (*route.Route, error) { finalExpiry ...uint16) (*route.Route, error) {
var finalCLTVDelta uint16 var finalCLTVDelta uint16
@ -1455,7 +1455,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex,
// Create the route with absolute time lock values. // Create the route with absolute time lock values.
route, err := newRoute( route, err := newRoute(
amt, source, path, uint32(currentHeight), finalCLTVDelta, amt, source, path, uint32(currentHeight), finalCLTVDelta,
destTlvRecords, destCustomRecords,
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -1608,11 +1608,11 @@ type LightningPayment struct {
// attempting to complete. // attempting to complete.
PaymentRequest []byte PaymentRequest []byte
// FinalDestRecords are TLV records that are to be sent to the final // DestCustomRecords are TLV records that are to be sent to the final
// hop in the new onion payload format. If the destination does not // hop in the new onion payload format. If the destination does not
// understand this new onion payload format, then the payment will // understand this new onion payload format, then the payment will
// fail. // fail.
FinalDestRecords []tlv.Record DestCustomRecords record.CustomSet
} }
// SendPayment attempts to send a payment as described within the passed // SendPayment attempts to send a payment as described within the passed

View File

@ -50,11 +50,11 @@ import (
"github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/lnwire"
"github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/macaroons"
"github.com/lightningnetwork/lnd/monitoring" "github.com/lightningnetwork/lnd/monitoring"
"github.com/lightningnetwork/lnd/record"
"github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing"
"github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/signal"
"github.com/lightningnetwork/lnd/sweep" "github.com/lightningnetwork/lnd/sweep"
"github.com/lightningnetwork/lnd/tlv"
"github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/watchtower"
"github.com/lightningnetwork/lnd/zpay32" "github.com/lightningnetwork/lnd/zpay32"
"github.com/tv42/zbase32" "github.com/tv42/zbase32"
@ -3099,7 +3099,7 @@ type rpcPaymentIntent struct {
lastHop *route.Vertex lastHop *route.Vertex
payReq []byte payReq []byte
destTLV []tlv.Record destCustomRecords record.CustomSet
route *route.Route route *route.Route
} }
@ -3158,12 +3158,13 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme
} }
payIntent.cltvLimit = cltvLimit payIntent.cltvLimit = cltvLimit
payIntent.destTLV, err = routerrpc.UnmarshallCustomRecords( err = routerrpc.ValidateCustomRecords(
rpcPayReq.DestCustomRecords, rpcPayReq.DestCustomRecords,
) )
if err != nil { if err != nil {
return payIntent, err return payIntent, err
} }
payIntent.destCustomRecords = rpcPayReq.DestCustomRecords
validateDest := func(dest route.Vertex) error { validateDest := func(dest route.Vertex) error {
if rpcPayReq.AllowSelfPayment { if rpcPayReq.AllowSelfPayment {
@ -3348,7 +3349,7 @@ func (r *rpcServer) dispatchPaymentIntent(
LastHop: payIntent.lastHop, LastHop: payIntent.lastHop,
PaymentRequest: payIntent.payReq, PaymentRequest: payIntent.payReq,
PayAttemptTimeout: routing.DefaultPayAttemptTimeout, PayAttemptTimeout: routing.DefaultPayAttemptTimeout,
FinalDestRecords: payIntent.destTLV, DestCustomRecords: payIntent.destCustomRecords,
} }
preImage, route, routerErr = r.server.chanRouter.SendPayment( preImage, route, routerErr = r.server.chanRouter.SendPayment(