diff --git a/channeldb/payments.go b/channeldb/payments.go index 3d8c6e34e..0e2b865df 100644 --- a/channeldb/payments.go +++ b/channeldb/payments.go @@ -606,7 +606,9 @@ func serializeHop(w io.Writer, h *route.Hop) error { if h.MPP != nil { records = append(records, h.MPP.Record()) } - records = append(records, h.TLVRecords...) + + tlvRecords := tlv.MapToRecords(h.CustomRecords) + records = append(records, tlvRecords...) // Otherwise, we'll transform our slice of records into a map of the // raw bytes, then serialize them in-line with a length (number of @@ -710,7 +712,7 @@ func deserializeHop(r io.Reader) (*route.Hop, error) { h.MPP = mpp } - h.TLVRecords = tlv.MapToRecords(tlvMap) + h.CustomRecords = tlvMap return h, nil } diff --git a/channeldb/payments_test.go b/channeldb/payments_test.go index bec8e528c..8eba5e5ac 100644 --- a/channeldb/payments_test.go +++ b/channeldb/payments_test.go @@ -2,7 +2,6 @@ package channeldb import ( "bytes" - "errors" "fmt" "math/rand" "reflect" @@ -28,9 +27,9 @@ var ( ChannelID: 12345, OutgoingTimeLock: 111, AmtToForward: 555, - TLVRecords: []tlv.Record{ - tlv.MakeStaticRecord(1, nil, 3, tlvEncoder, nil), - tlv.MakeStaticRecord(2, nil, 3, tlvEncoder, nil), + CustomRecords: record.CustomSet{ + 1: []byte{}, + 2: []byte{}, }, MPP: record.NewMPP(32, [32]byte{0x42}), } @@ -144,25 +143,7 @@ func TestSentPaymentSerialization(t *testing.T) { // assertRouteEquals compares to routes for equality and returns an error if // they are not equal. func assertRouteEqual(a, b *route.Route) error { - err := assertRouteHopRecordsEqual(a, b) - if err != nil { - return err - } - - // TLV records have already been compared and need to be cleared to - // properly compare the remaining fields using DeepEqual. - copyRouteNoHops := func(r *route.Route) *route.Route { - copy := *r - copy.Hops = make([]*route.Hop, len(r.Hops)) - for i, hop := range r.Hops { - hopCopy := *hop - hopCopy.TLVRecords = nil - copy.Hops[i] = &hopCopy - } - return © - } - - if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) { + if !reflect.DeepEqual(a, b) { return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v", spew.Sdump(a), spew.Sdump(b)) } @@ -170,57 +151,6 @@ func assertRouteEqual(a, b *route.Route) error { return nil } -func assertRouteHopRecordsEqual(r1, r2 *route.Route) error { - if len(r1.Hops) != len(r2.Hops) { - return errors.New("route hop count mismatch") - } - - for i := 0; i < len(r1.Hops); i++ { - records1 := r1.Hops[i].TLVRecords - records2 := r2.Hops[i].TLVRecords - if len(records1) != len(records2) { - return fmt.Errorf("route record count for hop %v "+ - "mismatch", i) - } - - for j := 0; j < len(records1); j++ { - expectedRecord := records1[j] - newRecord := records2[j] - - err := assertHopRecordsEqual(expectedRecord, newRecord) - if err != nil { - return fmt.Errorf("route record mismatch: %v", err) - } - } - } - - return nil -} - -func assertHopRecordsEqual(h1, h2 tlv.Record) error { - if h1.Type() != h2.Type() { - return fmt.Errorf("wrong type: expected %v, got %v", h1.Type(), - h2.Type()) - } - - var b bytes.Buffer - if err := h2.Encode(&b); err != nil { - return fmt.Errorf("unable to encode record: %v", err) - } - - if !bytes.Equal(b.Bytes(), tlvBytes) { - return fmt.Errorf("wrong raw record: expected %x, got %x", - tlvBytes, b.Bytes()) - } - - if h1.Size() != h2.Size() { - return fmt.Errorf("wrong size: expected %v, "+ - "got %v", h1.Size(), h2.Size()) - } - - return nil -} - func TestRouteSerialization(t *testing.T) { t.Parallel() diff --git a/lnrpc/routerrpc/router_backend.go b/lnrpc/routerrpc/router_backend.go index b9de96174..dd918f541 100644 --- a/lnrpc/routerrpc/router_backend.go +++ b/lnrpc/routerrpc/router_backend.go @@ -19,7 +19,6 @@ import ( "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/zpay32" ) @@ -45,7 +44,7 @@ type RouterBackend struct { // routes. FindRoute func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, - destTlvRecords []tlv.Record, + destCustomRecords record.CustomSet, finalExpiry ...uint16) (*route.Route, error) MissionControl MissionControl @@ -232,7 +231,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // If we have any TLV records destined for the final hop, then we'll // attempt to decode them now into a form that the router can more // easily manipulate. - destTlvRecords, err := UnmarshallCustomRecords(in.DestCustomRecords) + err = ValidateCustomRecords(in.DestCustomRecords) if err != nil { return nil, err } @@ -242,7 +241,7 @@ func (r *RouterBackend) QueryRoutes(ctx context.Context, // the route. route, err := r.FindRoute( sourcePubKey, targetPubKey, amt, restrictions, - destTlvRecords, finalCLTVDelta, + in.DestCustomRecords, finalCLTVDelta, ) if err != nil { return nil, err @@ -346,11 +345,6 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) } } - tlvMap, err := tlv.RecordsToMap(hop.TLVRecords) - if err != nil { - return nil, err - } - resp.Hops[i] = &lnrpc.Hop{ ChanId: hop.ChannelID, ChanCapacity: int64(chanCapacity), @@ -362,7 +356,7 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) PubKey: hex.EncodeToString( hop.PubKeyBytes[:], ), - CustomRecords: tlvMap, + CustomRecords: hop.CustomRecords, TlvPayload: !hop.LegacyPayload, MppRecord: mpp, } @@ -372,24 +366,16 @@ func (r *RouterBackend) MarshallRoute(route *route.Route) (*lnrpc.Route, error) return resp, nil } -// UnmarshallCustomRecords unmarshall rpc custom records to tlv records. -func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record, - error) { - - if len(rpcRecords) == 0 { - return nil, nil +// ValidateCustomRecords checks that all custom records are in the custom type +// range. +func ValidateCustomRecords(rpcRecords map[uint64][]byte) error { + for key := range rpcRecords { + if key < record.CustomTypeStart { + return fmt.Errorf("no custom records with types "+ + "below %v allowed", record.CustomTypeStart) + } } - - tlvRecords := tlv.MapToRecords(rpcRecords) - - // tlvRecords is sorted, so we only need to check that the first - // element is within the custom range. - if uint64(tlvRecords[0].Type()) < record.CustomTypeStart { - return nil, fmt.Errorf("no custom records with types "+ - "below %v allowed", record.CustomTypeStart) - } - - return tlvRecords, nil + return nil } // UnmarshallHopWithPubkey unmarshalls an rpc hop for which the pubkey has @@ -397,7 +383,7 @@ func UnmarshallCustomRecords(rpcRecords map[uint64][]byte) ([]tlv.Record, func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop, error) { - tlvRecords, err := UnmarshallCustomRecords(rpcHop.CustomRecords) + err := ValidateCustomRecords(rpcHop.CustomRecords) if err != nil { return nil, err } @@ -412,7 +398,7 @@ func UnmarshallHopWithPubkey(rpcHop *lnrpc.Hop, pubkey route.Vertex) (*route.Hop AmtToForward: lnwire.MilliSatoshi(rpcHop.AmtToForwardMsat), PubKeyBytes: pubkey, ChannelID: rpcHop.ChanId, - TLVRecords: tlvRecords, + CustomRecords: rpcHop.CustomRecords, LegacyPayload: !rpcHop.TlvPayload, MPP: mpp, }, nil @@ -540,12 +526,11 @@ func (r *RouterBackend) extractIntentFromSendRequest( return nil, errors.New("timeout_seconds must be specified") } - payIntent.FinalDestRecords, err = UnmarshallCustomRecords( - rpcPayReq.DestCustomRecords, - ) + err = ValidateCustomRecords(rpcPayReq.DestCustomRecords) if err != nil { return nil, err } + payIntent.DestCustomRecords = rpcPayReq.DestCustomRecords payIntent.PayAttemptTimeout = time.Second * time.Duration(rpcPayReq.TimeoutSeconds) diff --git a/lnrpc/routerrpc/router_backend_test.go b/lnrpc/routerrpc/router_backend_test.go index 6e55e5ad8..f6dfc2cce 100644 --- a/lnrpc/routerrpc/router_backend_test.go +++ b/lnrpc/routerrpc/router_backend_test.go @@ -8,9 +8,9 @@ import ( "github.com/btcsuite/btcutil" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/lnrpc" ) @@ -92,7 +92,7 @@ func testQueryRoutes(t *testing.T, useMissionControl bool, useMsat bool) { findRoute := func(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *routing.RestrictParams, - _ []tlv.Record, + _ record.CustomSet, finalExpiry ...uint16) (*route.Route, error) { if int64(amt) != amtSat*1000 { diff --git a/routing/pathfind.go b/routing/pathfind.go index 3f4f8f603..e8db244a4 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -11,8 +11,8 @@ import ( "github.com/coreos/bbolt" "github.com/lightningnetwork/lnd/channeldb" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" - "github.com/lightningnetwork/lnd/tlv" ) const ( @@ -100,7 +100,7 @@ type edgePolicyWithSource struct { func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex, pathEdges []*channeldb.ChannelEdgePolicy, currentHeight uint32, finalCLTVDelta uint16, - finalDestRecords []tlv.Record) (*route.Route, error) { + destCustomRecords record.CustomSet) (*route.Route, error) { var ( hops []*route.Hop @@ -198,8 +198,8 @@ func newRoute(amtToSend lnwire.MilliSatoshi, sourceVertex route.Vertex, // If this is the last hop, then we'll populate any TLV records // destined for it. - if i == len(pathEdges)-1 && len(finalDestRecords) != 0 { - currentHop.TLVRecords = finalDestRecords + if i == len(pathEdges)-1 && len(destCustomRecords) != 0 { + currentHop.CustomRecords = destCustomRecords } hops = append([]*route.Hop{currentHop}, hops...) diff --git a/routing/payment_session.go b/routing/payment_session.go index 979b7a7a4..124becbdc 100644 --- a/routing/payment_session.go +++ b/routing/payment_session.go @@ -129,7 +129,7 @@ func (p *paymentSession) RequestRoute(payment *LightningPayment, sourceVertex := route.Vertex(ss.SelfNode.PubKeyBytes) route, err := newRoute( payment.Amount, sourceVertex, path, height, finalCltvDelta, - payment.FinalDestRecords, + payment.DestCustomRecords, ) if err != nil { // TODO(roasbeef): return which edge/vertex didn't work diff --git a/routing/route/route.go b/routing/route/route.go index 6858f3b8d..1444b9df2 100644 --- a/routing/route/route.go +++ b/routing/route/route.go @@ -107,9 +107,9 @@ type Hop struct { // only be set for the final hop. MPP *record.MPP - // TLVRecords if non-nil are a set of additional TLV records that + // CustomRecords if non-nil are a set of additional TLV records that // should be included in the forwarding instructions for this node. - TLVRecords []tlv.Record + CustomRecords record.CustomSet // LegacyPayload if true, then this signals that this node doesn't // understand the new TLV payload, so we must instead use the legacy @@ -165,7 +165,8 @@ func (h *Hop) PackHopPayload(w io.Writer, nextChanID uint64) error { } // Append any custom types destined for this hop. - records = append(records, h.TLVRecords...) + tlvRecords := tlv.MapToRecords(h.CustomRecords) + records = append(records, tlvRecords...) // To ensure we produce a canonical stream, we'll sort the records // before encoding them as a stream in the hop payload. diff --git a/routing/router.go b/routing/router.go index 5020aefc4..f157b7bda 100644 --- a/routing/router.go +++ b/routing/router.go @@ -24,10 +24,10 @@ import ( "github.com/lightningnetwork/lnd/lnwallet/chanvalidate" "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/multimutex" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/chainview" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/ticker" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/zpay32" ) @@ -1401,7 +1401,7 @@ type routingMsg struct { // factoring in channel capacities and cumulative fees along the route. func (r *ChannelRouter) FindRoute(source, target route.Vertex, amt lnwire.MilliSatoshi, restrictions *RestrictParams, - destTlvRecords []tlv.Record, + destCustomRecords record.CustomSet, finalExpiry ...uint16) (*route.Route, error) { var finalCLTVDelta uint16 @@ -1455,7 +1455,7 @@ func (r *ChannelRouter) FindRoute(source, target route.Vertex, // Create the route with absolute time lock values. route, err := newRoute( amt, source, path, uint32(currentHeight), finalCLTVDelta, - destTlvRecords, + destCustomRecords, ) if err != nil { return nil, err @@ -1608,11 +1608,11 @@ type LightningPayment struct { // attempting to complete. PaymentRequest []byte - // FinalDestRecords are TLV records that are to be sent to the final + // DestCustomRecords are TLV records that are to be sent to the final // hop in the new onion payload format. If the destination does not // understand this new onion payload format, then the payment will // fail. - FinalDestRecords []tlv.Record + DestCustomRecords record.CustomSet } // SendPayment attempts to send a payment as described within the passed diff --git a/rpcserver.go b/rpcserver.go index 225fdf006..fd2ef044c 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -50,11 +50,11 @@ import ( "github.com/lightningnetwork/lnd/lnwire" "github.com/lightningnetwork/lnd/macaroons" "github.com/lightningnetwork/lnd/monitoring" + "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/signal" "github.com/lightningnetwork/lnd/sweep" - "github.com/lightningnetwork/lnd/tlv" "github.com/lightningnetwork/lnd/watchtower" "github.com/lightningnetwork/lnd/zpay32" "github.com/tv42/zbase32" @@ -3099,7 +3099,7 @@ type rpcPaymentIntent struct { lastHop *route.Vertex payReq []byte - destTLV []tlv.Record + destCustomRecords record.CustomSet route *route.Route } @@ -3158,12 +3158,13 @@ func (r *rpcServer) extractPaymentIntent(rpcPayReq *rpcPaymentRequest) (rpcPayme } payIntent.cltvLimit = cltvLimit - payIntent.destTLV, err = routerrpc.UnmarshallCustomRecords( + err = routerrpc.ValidateCustomRecords( rpcPayReq.DestCustomRecords, ) if err != nil { return payIntent, err } + payIntent.destCustomRecords = rpcPayReq.DestCustomRecords validateDest := func(dest route.Vertex) error { if rpcPayReq.AllowSelfPayment { @@ -3348,7 +3349,7 @@ func (r *rpcServer) dispatchPaymentIntent( LastHop: payIntent.lastHop, PaymentRequest: payIntent.payReq, PayAttemptTimeout: routing.DefaultPayAttemptTimeout, - FinalDestRecords: payIntent.destTLV, + DestCustomRecords: payIntent.destCustomRecords, } preImage, route, routerErr = r.server.chanRouter.SendPayment(