diff --git a/autopilot/graph.go b/autopilot/graph.go index b3653a489..ffac3f352 100644 --- a/autopilot/graph.go +++ b/autopilot/graph.go @@ -91,7 +91,7 @@ func (d *dbNode) Addrs() []net.Addr { func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { return d.db.ForEachNodeChannelTx(d.tx, d.node.PubKeyBytes, func(tx kvdb.RTx, ei models.ChannelEdgeInfo, ep, - _ *models.ChannelEdgePolicy1) error { + _ models.ChannelEdgePolicy) error { // Skip channels for which no outgoing edge policy is // available. @@ -106,16 +106,14 @@ func (d *dbNode) ForEachChannel(cb func(ChannelEdge) error) error { } node, err := d.db.FetchLightningNodeTx( - tx, ep.ToNode, + tx, ep.GetToNode(), ) if err != nil { return err } edge := ChannelEdge{ - ChanID: lnwire.NewShortChanIDFromInt( - ep.ChannelID, - ), + ChanID: ep.SCID(), Capacity: ei.GetCapacity(), Peer: &dbNode{ tx: tx, diff --git a/channeldb/edge_policy.go b/channeldb/edge_policy.go index cde470202..57dbb0e77 100644 --- a/channeldb/edge_policy.go +++ b/channeldb/edge_policy.go @@ -97,7 +97,7 @@ func encodingInfoFromEdgePolicy(policy models.ChannelEdgePolicy) ( } } -func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy1, +func putChanEdgePolicy(edges kvdb.RwBucket, edge models.ChannelEdgePolicy, from, to []byte) error { encodingInfo, err := encodingInfoFromEdgePolicy(edge) diff --git a/channeldb/graph.go b/channeldb/graph.go index 5da02b1e7..cfe5484b1 100644 --- a/channeldb/graph.go +++ b/channeldb/graph.go @@ -249,7 +249,7 @@ func NewChannelGraph(db kvdb.Backend, rejectCacheSize, chanCacheSize int, } err = g.ForEachChannel(func(info models.ChannelEdgeInfo, - policy1, policy2 *models.ChannelEdgePolicy1) error { + policy1, policy2 models.ChannelEdgePolicy) error { g.graphCache.AddChannel(info, policy1, policy2) @@ -275,10 +275,10 @@ type channelMapKey struct { // getChannelMap loads all channel edge policies from the database and stores // them in a map. func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( - map[channelMapKey]*models.ChannelEdgePolicy1, error) { + map[channelMapKey]models.ChannelEdgePolicy, error) { // Create a map to store all channel edge policies. - channelMap := make(map[channelMapKey]*models.ChannelEdgePolicy1) + channelMap := make(map[channelMapKey]models.ChannelEdgePolicy) err := kvdb.ForAll(edges, func(k, edgeBytes []byte) error { // Skip embedded buckets. @@ -321,13 +321,7 @@ func (c *ChannelGraph) getChannelMap(edges kvdb.RBucket) ( return err } - e, ok := edge.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got: %T", edge) - } - - channelMap[key] = e + channelMap[key] = edge return nil }) @@ -440,7 +434,7 @@ func (c *ChannelGraph) NewPathFindTx() (kvdb.RTx, error) { // for that particular channel edge routing policy will be passed into the // callback. func (c *ChannelGraph) ForEachChannel(cb func(models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1) error) error { + models.ChannelEdgePolicy, models.ChannelEdgePolicy) error) error { return c.db.View(func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -510,7 +504,7 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, } dbCallback := func(tx kvdb.RTx, e models.ChannelEdgeInfo, p1, - p2 *models.ChannelEdgePolicy1) error { + p2 models.ChannelEdgePolicy) error { var cachedInPolicy *models.CachedEdgePolicy if p2 != nil { @@ -521,11 +515,14 @@ func (c *ChannelGraph) ForEachNodeDirectedChannel(tx kvdb.RTx, var inboundFee lnwire.Fee if p1 != nil { - // Extract inbound fee. If there is a decoding error, - // skip this edge. - _, err := p1.ExtraOpaqueData.ExtractRecords(&inboundFee) - if err != nil { - return nil + // TODO(elle): add inbound fees to new messages! + if policy, ok := p1.(*models.ChannelEdgePolicy1); ok { + // Extract inbound fee. If there is a decoding error, + // skip this edge. + _, err := policy.ExtraOpaqueData.ExtractRecords(&inboundFee) + if err != nil { + return nil + } } } @@ -597,8 +594,8 @@ func (c *ChannelGraph) ForEachNodeCached(cb func(node route.Vertex, err := c.ForEachNodeChannelTx(tx, node.PubKeyBytes, func(tx kvdb.RTx, e models.ChannelEdgeInfo, - p1 *models.ChannelEdgePolicy1, - p2 *models.ChannelEdgePolicy1) error { + p1 models.ChannelEdgePolicy, + p2 models.ChannelEdgePolicy) error { toNodeCallback := func() route.Vertex { return node.PubKeyBytes @@ -3005,7 +3002,7 @@ func makeZombiePubkeys(info models.ChannelEdgeInfo, // updated, otherwise it's the second node's information. The node ordering is // determined by the lexicographical ordering of the identity public keys of the // nodes on either side of the channel. -func (c *ChannelGraph) UpdateEdgePolicy(edge *models.ChannelEdgePolicy1, +func (c *ChannelGraph) UpdateEdgePolicy(edge models.ChannelEdgePolicy, op ...batch.SchedulerOption) error { var ( @@ -3085,7 +3082,7 @@ func (c *ChannelGraph) updateEdgeCache(e models.ChannelEdgePolicy, // buckets using an existing database transaction. The returned boolean will be // true if the updated policy belongs to node1, and false if the policy belonged // to node2. -func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy1, +func updateEdgePolicy(tx kvdb.RwTx, edge models.ChannelEdgePolicy, graphCache *GraphCache) (bool, error) { edges := tx.ReadWriteBucket(edgeBucket) @@ -3100,7 +3097,7 @@ func updateEdgePolicy(tx kvdb.RwTx, edge *models.ChannelEdgePolicy1, // Create the channelID key be converting the channel ID // integer into a byte slice. var chanID [8]byte - byteOrder.PutUint64(chanID[:], edge.ChannelID) + byteOrder.PutUint64(chanID[:], edge.SCID().ToUint64()) // With the channel ID, we then fetch the value storing the two // nodes which connect this channel edge. @@ -3289,8 +3286,8 @@ func (c *ChannelGraph) isPublic(tx kvdb.RTx, nodePub route.Vertex, nodeIsPublic := false errDone := errors.New("done") err := c.ForEachNodeChannelTx(tx, nodePub, func(tx kvdb.RTx, - info models.ChannelEdgeInfo, _ *models.ChannelEdgePolicy1, - _ *models.ChannelEdgePolicy1) error { + info models.ChannelEdgeInfo, _ models.ChannelEdgePolicy, + _ models.ChannelEdgePolicy) error { // If this edge doesn't extend to the source node, we'll // terminate our search as we can now conclude that the node is @@ -3433,8 +3430,8 @@ func (n *graphCacheNode) Features() *lnwire.FeatureVector { // // Unknown policies are passed into the callback as nil values. func (n *graphCacheNode) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error { + cb func(kvdb.RTx, models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error { return nodeTraversal(tx, n.pubKeyBytes[:], nil, cb) } @@ -3494,8 +3491,8 @@ func (c *ChannelGraph) HasLightningNode(nodePub [33]byte) (time.Time, bool, erro // nodeTraversal is used to traverse all channels of a node given by its // public key and passes channel information into the specified callback. func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, - cb func(kvdb.RTx, models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error { + cb func(kvdb.RTx, models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error { traversal := func(tx kvdb.RTx) error { edges := tx.ReadBucket(edgeBucket) @@ -3563,30 +3560,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, return err } - var ( - in, out *models.ChannelEdgePolicy1 - ok bool - ) - if outgoingPolicy != nil { - out, ok = outgoingPolicy.(*models.ChannelEdgePolicy1) //nolint:lll - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, "+ - "got %T", outgoingPolicy) - } - } - - if incomingPolicy != nil { - in, ok = incomingPolicy.(*models.ChannelEdgePolicy1) //nolint:lll - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, "+ - "got %T", incomingPolicy) - } - } - // Finally, we execute the callback. - err = cb(tx, edgeInfo, out, in) + err = cb(tx, edgeInfo, outgoingPolicy, incomingPolicy) if err != nil { return err } @@ -3615,8 +3590,8 @@ func nodeTraversal(tx kvdb.RTx, nodePub []byte, db kvdb.Backend, // // Unknown policies are passed into the callback as nil values. func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, - cb func(kvdb.RTx, models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error { + cb func(kvdb.RTx, models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error { return nodeTraversal(nil, nodePub[:], c.db, cb) } @@ -3636,8 +3611,8 @@ func (c *ChannelGraph) ForEachNodeChannel(nodePub route.Vertex, // traversal. func (c *ChannelGraph) ForEachNodeChannelTx(tx kvdb.RTx, nodePub route.Vertex, cb func(kvdb.RTx, models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error { + models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error { return nodeTraversal(tx, nodePub[:], c.db, cb) } @@ -3724,13 +3699,13 @@ func computeEdgePolicyKeys(info models.ChannelEdgeInfo) ([]byte, []byte) { // information for the channel itself is returned as well as two structs that // contain the routing policies for the channel in either direction. func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *models.ChannelEdgePolicy1 - policy2 *models.ChannelEdgePolicy1 + policy1 models.ChannelEdgePolicy + policy2 models.ChannelEdgePolicy ) err := kvdb.View(c.db, func(tx kvdb.RTx) error { @@ -3779,34 +3754,13 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( // Once we have the information about the channels' parameters, // we'll fetch the routing policies for each for the directed // edges. - edge1, edge2, err := fetchChanEdgePolicies( + e1, e2, err := fetchChanEdgePolicies( edgeIndex, edges, chanID, ) if err != nil { return fmt.Errorf("failed to find policy: %w", err) } - var ( - e1, e2 *models.ChannelEdgePolicy1 - ok bool - ) - if edge1 != nil { - e1, ok = edge1.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got %T", - edge1) - } - } - if edge2 != nil { - e2, ok = edge2.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got %T", - edge1) - } - } - policy1 = e1 policy2 = e2 return nil @@ -3832,13 +3786,13 @@ func (c *ChannelGraph) FetchChannelEdgesByOutpoint(op *wire.OutPoint) ( // within the database. In this case, the ChannelEdgePolicy1's will be nil, and // the ChannelEdgeInfo1 will only include the public keys of each node. func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { var ( edgeInfo models.ChannelEdgeInfo - policy1 *models.ChannelEdgePolicy1 - policy2 *models.ChannelEdgePolicy1 + policy1 models.ChannelEdgePolicy + policy2 models.ChannelEdgePolicy channelID [8]byte ) @@ -3904,34 +3858,13 @@ func (c *ChannelGraph) FetchChannelEdgesByID(chanID uint64) ( // Then we'll attempt to fetch the accompanying policies of this // edge. - edge1, edge2, err := fetchChanEdgePolicies( + e1, e2, err := fetchChanEdgePolicies( edgeIndex, edges, channelID[:], ) if err != nil { return err } - var ( - e1, e2 *models.ChannelEdgePolicy1 - ok bool - ) - if edge1 != nil { - e1, ok = edge1.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expecgted "+ - "*models.ChannelEdgePolicy1, got %T", - edge1) - } - } - if edge2 != nil { - e2, ok = edge2.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expecgted "+ - "*models.ChannelEdgePolicy1, got %T", - edge1) - } - } - policy1 = e1 policy2 = e2 return nil diff --git a/channeldb/graph_cache.go b/channeldb/graph_cache.go index 86e73a341..9a7cbdaa5 100644 --- a/channeldb/graph_cache.go +++ b/channeldb/graph_cache.go @@ -29,8 +29,8 @@ type GraphCacheNode interface { // to the caller. ForEachChannel(kvdb.RTx, func(kvdb.RTx, models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error + models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error } // DirectedChannel is a type that stores the channel information as seen from @@ -143,22 +143,10 @@ func (c *GraphCache) AddNode(tx kvdb.RTx, node GraphCacheNode) error { return node.ForEachChannel( tx, func(tx kvdb.RTx, info models.ChannelEdgeInfo, - outPolicy *models.ChannelEdgePolicy1, - inPolicy *models.ChannelEdgePolicy1) error { + outPolicy models.ChannelEdgePolicy, + inPolicy models.ChannelEdgePolicy) error { - // TODO(elle): remove once the ForEachChannel call back - // passes down the interface values instead of the - // pointers. This is temporarily required to prevent - // a nil pointer dereference. - var in, out models.ChannelEdgePolicy - if outPolicy != nil { - out = outPolicy - } - if inPolicy != nil { - in = inPolicy - } - - c.AddChannel(info, out, in) + c.AddChannel(info, outPolicy, inPolicy) return nil }, diff --git a/channeldb/graph_cache_test.go b/channeldb/graph_cache_test.go index 89581022e..300b3f82e 100644 --- a/channeldb/graph_cache_test.go +++ b/channeldb/graph_cache_test.go @@ -42,8 +42,8 @@ func (n *node) Features() *lnwire.FeatureVector { } func (n *node) ForEachChannel(tx kvdb.RTx, - cb func(kvdb.RTx, models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error { + cb func(kvdb.RTx, models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error { for idx := range n.edgeInfos { err := cb( diff --git a/channeldb/graph_test.go b/channeldb/graph_test.go index f2341d7ac..e5f411c88 100644 --- a/channeldb/graph_test.go +++ b/channeldb/graph_test.go @@ -1055,8 +1055,8 @@ func TestGraphTraversal(t *testing.T) { // again if the map is empty that indicates that all edges have // properly been reached. err = graph.ForEachChannel(func(ei models.ChannelEdgeInfo, - _ *models.ChannelEdgePolicy1, - _ *models.ChannelEdgePolicy1) error { + _ models.ChannelEdgePolicy, + _ models.ChannelEdgePolicy) error { delete(chanIndex, ei.GetChanID()) return nil @@ -1070,7 +1070,7 @@ func TestGraphTraversal(t *testing.T) { firstNode, secondNode := nodeList[0], nodeList[1] err = graph.ForEachNodeChannel(firstNode.PubKeyBytes, func(_ kvdb.RTx, _ models.ChannelEdgeInfo, outEdge, - inEdge *models.ChannelEdgePolicy1) error { + inEdge models.ChannelEdgePolicy) error { // All channels between first and second node should // have fully (both sides) specified policies. @@ -1080,8 +1080,9 @@ func TestGraphTraversal(t *testing.T) { // Each should indicate that it's outgoing (pointed // towards the second node). + outToNode := outEdge.GetToNode() if !bytes.Equal( - outEdge.ToNode[:], secondNode.PubKeyBytes[:], + outToNode[:], secondNode.PubKeyBytes[:], ) { return fmt.Errorf("wrong outgoing edge") @@ -1089,8 +1090,9 @@ func TestGraphTraversal(t *testing.T) { // The incoming edge should also indicate that it's // pointing to the origin node. + inToNode := inEdge.GetToNode() if !bytes.Equal( - inEdge.ToNode[:], firstNode.PubKeyBytes[:], + inToNode[:], firstNode.PubKeyBytes[:], ) { return fmt.Errorf("wrong outgoing edge") @@ -1151,8 +1153,8 @@ func TestGraphTraversalCacheable(t *testing.T) { err := node.ForEachChannel( tx, func(tx kvdb.RTx, info models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy1, - policy2 *models.ChannelEdgePolicy1) error { //nolint:lll + policy models.ChannelEdgePolicy, + policy2 models.ChannelEdgePolicy) error { //nolint:lll delete(chanIndex, info.GetChanID()) return nil @@ -1335,8 +1337,8 @@ func assertPruneTip(t *testing.T, graph *ChannelGraph, blockHash *chainhash.Hash func assertNumChans(t *testing.T, graph *ChannelGraph, n int) { numChans := 0 if err := graph.ForEachChannel(func(models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error { + models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error { numChans++ return nil @@ -2762,7 +2764,7 @@ func TestIncompleteChannelPolicies(t *testing.T) { calls := 0 err := graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, _ models.ChannelEdgeInfo, outEdge, - inEdge *models.ChannelEdgePolicy1) error { + inEdge models.ChannelEdgePolicy) error { if !expectedOut && outEdge != nil { t.Fatalf("Expected no outgoing policy") @@ -3917,8 +3919,8 @@ func BenchmarkForEachChannel(b *testing.B) { for _, n := range nodes { cb := func(tx kvdb.RTx, info models.ChannelEdgeInfo, - policy *models.ChannelEdgePolicy1, - policy2 *models.ChannelEdgePolicy1) error { //nolint:lll + policy models.ChannelEdgePolicy, + policy2 models.ChannelEdgePolicy) error { //nolint:lll // We need to do something with // the data here, otherwise the @@ -3926,8 +3928,10 @@ func BenchmarkForEachChannel(b *testing.B) { // this away, and we get bogus // results. totalCapacity += info.GetCapacity() - maxHTLCs += policy.MaxHTLC - maxHTLCs += policy2.MaxHTLC + maxHTLCs += policy.ForwardingPolicy(). + MaxHTLC + maxHTLCs += policy2.ForwardingPolicy(). + MaxHTLC return nil } diff --git a/channeldb/models/channel_edge_policy.go b/channeldb/models/channel_edge_policy.go index 1393b9477..eece025ff 100644 --- a/channeldb/models/channel_edge_policy.go +++ b/channeldb/models/channel_edge_policy.go @@ -188,6 +188,10 @@ func (c *ChannelEdgePolicy1) AfterUpdateMsg(msg lnwire.ChannelUpdate) (bool, return c.LastUpdate.After(timestamp), nil } +func (c *ChannelEdgePolicy1) ExtraData() lnwire.ExtraOpaqueData { + return c.ExtraOpaqueData +} + // Sig returns the signature of the update message. // // NOTE: This is part of the ChannelEdgePolicy interface. diff --git a/channeldb/models/interfaces.go b/channeldb/models/interfaces.go index f250bc95c..bfca995fa 100644 --- a/channeldb/models/interfaces.go +++ b/channeldb/models/interfaces.go @@ -94,4 +94,6 @@ type ChannelEdgePolicy interface { // Sig returns the signature of the update message. Sig() (input.Signature, error) + + ExtraData() lnwire.ExtraOpaqueData } diff --git a/discovery/gossiper.go b/discovery/gossiper.go index 36f1be28c..01f4adac8 100644 --- a/discovery/gossiper.go +++ b/discovery/gossiper.go @@ -556,7 +556,7 @@ type EdgeWithInfo struct { Info models.ChannelEdgeInfo // Edge describes the policy in one direction of the channel. - Edge *models.ChannelEdgePolicy1 + Edge models.ChannelEdgePolicy } // PropagateChanPolicyUpdate signals the AuthenticatedGossiper to perform the @@ -1621,7 +1621,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // within the prune interval or re-broadcast interval. type updateTuple struct { info models.ChannelEdgeInfo - edge *models.ChannelEdgePolicy1 + edge models.ChannelEdgePolicy } var ( @@ -1631,7 +1631,7 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { err := d.cfg.Graph.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy1) error { + edge models.ChannelEdgePolicy) error { // If there's no auth proof attached to this edge, it means // that it is a private channel not meant to be announced to @@ -1652,31 +1652,58 @@ func (d *AuthenticatedGossiper) retransmitStaleAnns(now time.Time) error { // If this edge has a ChannelUpdate that was created before the // introduction of the MaxHTLC field, then we'll update this // edge to propagate this information in the network. - if !edge.MessageFlags.HasMaxHtlc() { - // We'll make sure we support the new max_htlc field if - // not already present. - edge.MessageFlags |= lnwire.ChanUpdateRequiredMaxHtlc - edge.MaxHTLC = lnwire.NewMSatFromSatoshis( - info.GetCapacity(), - ) + if e, ok := edge.(*models.ChannelEdgePolicy1); ok { + if !e.MessageFlags.HasMaxHtlc() { + // We'll make sure we support the new max_htlc + // field if not already present. + e.MessageFlags |= + lnwire.ChanUpdateRequiredMaxHtlc + e.MaxHTLC = lnwire.NewMSatFromSatoshis( + info.GetCapacity(), + ) + edgesToUpdate = append( + edgesToUpdate, updateTuple{ + info: info, + edge: e, + }, + ) - edgesToUpdate = append(edgesToUpdate, updateTuple{ - info: info, - edge: edge, - }) - return nil + return nil + } } - timeElapsed := now.Sub(edge.LastUpdate) + switch e := edge.(type) { + case *models.ChannelEdgePolicy1: + timeElapsed := now.Sub(e.LastUpdate) - // If it's been longer than RebroadcastInterval since we've - // re-broadcasted the channel, add the channel to the set of - // edges we need to update. - if timeElapsed >= d.cfg.RebroadcastInterval { - edgesToUpdate = append(edgesToUpdate, updateTuple{ - info: info, - edge: edge, - }) + // If it's been longer than RebroadcastInterval since + // we've re-broadcasted the channel, add the channel to + // the set of edges we need to update. + if timeElapsed >= d.cfg.RebroadcastInterval { + edgesToUpdate = append(edgesToUpdate, + updateTuple{ + info: info, + edge: e, + }, + ) + } + + case *models.ChannelEdgePolicy2: + blocksSince := d.latestHeight() - e.BlockHeight.Val + + // If it's been longer than RebroadcastInterval since + // we've re-broadcasted the channel, add the channel to + // the set of edges we need to update. + if blocksSince >= + uint32(d.cfg.RebroadcastInterval.Hours()*6) { + + edgesToUpdate = append(edgesToUpdate, + updateTuple{ + info: info, + edge: e, + }, + ) + } } return nil @@ -2167,8 +2194,8 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // can safely delete the local proof from the database. return chanInfo.GetAuthProof() != nil - case *lnwire.ChannelUpdate1: - _, p1, p2, err := d.cfg.Graph.GetChannelByID(msg.ShortChannelID) + case lnwire.ChannelUpdate: + _, p1, p2, err := d.cfg.Graph.GetChannelByID(msg.SCID()) // If the channel cannot be found, it is most likely a leftover // message for a channel that was closed, so we can consider it @@ -2178,15 +2205,15 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { } if err != nil { log.Debugf("Unable to retrieve channel=%v from graph: "+ - "%v", msg.ShortChannelID, err) + "%v", msg.SCID(), err) return false } // Otherwise, we'll retrieve the correct policy that we // currently have stored within our graph to check if this // message is stale by comparing its timestamp. - var p *models.ChannelEdgePolicy1 - if msg.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + var p models.ChannelEdgePolicy + if msg.IsNode1() { p = p1 } else { p = p2 @@ -2198,8 +2225,16 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { return false } - timestamp := time.Unix(int64(msg.Timestamp), 0) - return p.LastUpdate.After(timestamp) + after, err := p.AfterUpdateMsg(msg) + if err != nil { + log.Errorf("Unable to check if stored policy is "+ + "after message for channel=%v: %v", + msg.SCID(), err) + + return false + } + + return after default: // We'll make sure to not mark any unsupported messages as stale @@ -2211,17 +2246,23 @@ func (d *AuthenticatedGossiper) isMsgStale(msg lnwire.Message) bool { // updateChannel creates a new fully signed update for the channel, and updates // the underlying graph with the new state. func (d *AuthenticatedGossiper) updateChannel(edgeInfo models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy1) (lnwire.ChannelAnnouncement, + edgePolicy models.ChannelEdgePolicy) (lnwire.ChannelAnnouncement, *lnwire.ChannelUpdate1, error) { // Parse the unsigned edge into a channel update. chanUpdate, err := netann.UnsignedChannelUpdateFromEdge( - edgeInfo.GetChainHash(), edge, + edgeInfo.GetChainHash(), edgePolicy, ) if err != nil { return nil, nil, err } + edge, ok := edgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + return nil, nil, fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", edgePolicy) + } + // We'll generate a new signature over a digest of the channel // announcement itself and update the timestamp to ensure it propagate. err = netann.SignChannelUpdate( @@ -3023,7 +3064,7 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, // being updated. var ( pubKey *btcec.PublicKey - edgeToUpdate *models.ChannelEdgePolicy1 + edgeToUpdate models.ChannelEdgePolicy ) direction := upd.ChannelFlags & lnwire.ChanUpdateDirection switch direction { @@ -3057,15 +3098,31 @@ func (d *AuthenticatedGossiper) handleChanUpdate(nMsg *networkMsg, return nil, false } + var edge *models.ChannelEdgePolicy1 + if edgeToUpdate != nil { + var ok bool + edge, ok = edgeToUpdate.(*models.ChannelEdgePolicy1) + if !ok { + rErr := fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", + edgeToUpdate) + + log.Error(rErr) + nMsg.err <- rErr + + return nil, false + } + } + // If we have a previous version of the edge being updated, we'll want // to rate limit its updates to prevent spam throughout the network. - if nMsg.isRemote && edgeToUpdate != nil { + if nMsg.isRemote && edge != nil { // If it's a keep-alive update, we'll only propagate one if // it's been a day since the previous. This follows our own // heuristic of sending keep-alive updates after the same // duration (see retransmitStaleAnns). - timeSinceLastUpdate := timestamp.Sub(edgeToUpdate.LastUpdate) - if IsKeepAliveUpdate(upd, edgeToUpdate) { + timeSinceLastUpdate := timestamp.Sub(edge.LastUpdate) + if IsKeepAliveUpdate(upd, edge) { if timeSinceLastUpdate < d.cfg.RebroadcastInterval { log.Debugf("Ignoring keep alive update not "+ "within %v period for channel %v", diff --git a/discovery/gossiper_test.go b/discovery/gossiper_test.go index c328de59f..3c994c203 100644 --- a/discovery/gossiper_test.go +++ b/discovery/gossiper_test.go @@ -215,7 +215,7 @@ func (r *mockGraphSource) ForEachNode(func(node *channeldb.LightningNode) error) func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, i models.ChannelEdgeInfo, - c *models.ChannelEdgePolicy1) error) error { + c models.ChannelEdgePolicy) error) error { r.mu.Lock() defer r.mu.Unlock() @@ -237,14 +237,7 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, } for _, channel := range chans { - pol, ok := channel.Policy1.(*models.ChannelEdgePolicy1) - if !ok { - return fmt.Errorf("expected "+ - "*models.ChannelEdgePolicy1, got %T", - channel.Policy1) - } - - if err := cb(nil, channel.Info, pol); err != nil { + if err := cb(nil, channel.Info, channel.Policy1); err != nil { return err } } @@ -254,8 +247,8 @@ func (r *mockGraphSource) ForAllOutgoingChannels(cb func(tx kvdb.RTx, func (r *mockGraphSource) GetChannelByID(chanID lnwire.ShortChannelID) ( models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { r.mu.Lock() defer r.mu.Unlock() @@ -3501,7 +3494,12 @@ out: err = ctx.router.ForAllOutgoingChannels(func( _ kvdb.RTx, info models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy1) error { + edgePolicy models.ChannelEdgePolicy) error { + + edge, ok := edgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + t.Fatalf("expected *models.ChannelEdgePolicy1") + } edge.TimeLockDelta = uint16(newTimeLockDelta) edgesToUpdate = append(edgesToUpdate, EdgeWithInfo{ diff --git a/funding/manager.go b/funding/manager.go index 063ba7594..8805e7823 100644 --- a/funding/manager.go +++ b/funding/manager.go @@ -534,7 +534,7 @@ type Config struct { // DeleteAliasEdge allows the Manager to delete an alias channel edge // from the graph. It also returns our local to-be-deleted policy. DeleteAliasEdge func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy1, error) + models.ChannelEdgePolicy, error) // AliasManager is an implementation of the aliasHandler interface that // abstracts away the handling of many alias functions. @@ -3427,7 +3427,7 @@ func (f *Manager) extractAnnounceParams(c *channeldb.OpenChannel) ( func (f *Manager) addToGraph(completeChan *channeldb.OpenChannel, shortChanID *lnwire.ShortChannelID, peerAlias *lnwire.ShortChannelID, - ourPolicy *models.ChannelEdgePolicy1) error { + ourPolicy models.ChannelEdgePolicy) error { chanID := lnwire.NewChanIDFromOutPoint(completeChan.FundingOutpoint) @@ -4160,9 +4160,19 @@ func (f *Manager) newChanAnnouncement(localPubKey, remotePubKey *btcec.PublicKey, localFundingKey *keychain.KeyDescriptor, remoteFundingKey *btcec.PublicKey, shortChanID lnwire.ShortChannelID, chanID lnwire.ChannelID, fwdMinHTLC, fwdMaxHTLC lnwire.MilliSatoshi, - ourPolicy *models.ChannelEdgePolicy1, + ourEdgePolicy models.ChannelEdgePolicy, chanType channeldb.ChannelType) (*chanAnnouncement, error) { + var ourPolicy *models.ChannelEdgePolicy1 + if ourEdgePolicy != nil { + var ok bool + ourPolicy, ok = ourEdgePolicy.(*models.ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected "+ + "ChannelEdgePolicy1, got: %T", ourEdgePolicy) + } + } + chainHash := *f.cfg.Wallet.Cfg.NetParams.GenesisHash // The unconditional section of the announcement is the ShortChannelID diff --git a/funding/manager_test.go b/funding/manager_test.go index 8d8b81c1d..26fd0ca3c 100644 --- a/funding/manager_test.go +++ b/funding/manager_test.go @@ -554,7 +554,7 @@ func createTestFundingManager(t *testing.T, privKey *btcec.PrivateKey, OpenChannelPredicate: chainedAcceptor, NotifyPendingOpenChannelEvent: evt.NotifyPendingOpenChannelEvent, DeleteAliasEdge: func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgePolicy, error) { return nil, nil }, diff --git a/graph/builder.go b/graph/builder.go index 409e5d6d1..d74c3b6b0 100644 --- a/graph/builder.go +++ b/graph/builder.go @@ -1411,6 +1411,90 @@ func (b *Builder) processUpdate(msg interface{}, lnutils.SpewLogClosure(msg)) b.stats.incNumChannelUpdates() + case *models.ChannelEdgePolicy2: + chanID := msg.ShortChannelID.Val.ToUint64() + + log.Debugf("Received ChannelEdgePolicy2 for channel %v", chanID) + + // We make sure to hold the mutex for this channel ID, + // such that no other goroutine is concurrently doing + // database accesses for the same channel ID. + b.channelEdgeMtx.Lock(chanID) + defer b.channelEdgeMtx.Unlock(chanID) + + edge1Height, edge2Height, exists, isZombie, err := + b.cfg.Graph.HasChannelEdge2(chanID) + if err != nil && + !errors.Is(err, channeldb.ErrGraphNoEdgesFound) { + + return errors.Errorf("unable to check for edge "+ + "existence: %v", err) + } + + // If the channel is marked as a zombie in our database, and + // we consider this a stale update, then we should not apply the + // policy. + blocksSinceMsg := b.SyncedHeight() - msg.BlockHeight.Val + isStaleUpdate := blocksSinceMsg > uint32( + b.cfg.ChannelPruneExpiry.Hours()*6, + ) + if isZombie && isStaleUpdate { + return NewErrf(ErrIgnored, "ignoring stale update "+ + "(is_node_1=%v|disable_flags=%v) for zombie "+ + "chan_id=%v", msg.IsNode1(), msg.DisabledFlags, + chanID) + } + + // If the channel doesn't exist in our database, we cannot + // apply the updated policy. + if !exists { + return NewErrf(ErrIgnored, "ignoring update "+ + "(is_node_1=%v|disable_flags=%v) for unknown "+ + "chan_id=%v", msg.IsNode1(), msg.DisabledFlags, + chanID) + } + + // As edges are directional edge node has a unique policy for + // the direction of the edge they control. Therefore we first + // check if we already have the most up to date information for + // that edge. If this message has a timestamp not strictly + // newer than what we already know of we can exit early. + switch { + case msg.IsNode1(): + // Ignore outdated message. + if edge1Height >= msg.BlockHeight.Val { + return NewErrf(ErrOutdated, "Ignoring "+ + "outdated update "+ + "(is_node_1=%v|disable_flags=%v) for "+ + "known chan_id=%v", msg.IsNode1(), + msg.DisabledFlags, chanID) + } + + case !msg.IsNode1(): + // Ignore outdated message. + if edge2Height >= msg.BlockHeight.Val { + return NewErrf(ErrOutdated, "Ignoring "+ + "outdated update "+ + "(is_node_1=%v|disable_flags=%v) for "+ + "known chan_id=%v", msg.IsNode1(), + msg.DisabledFlags, chanID) + } + } + + // Now that we know this isn't a stale update, we'll apply the + // new edge policy to the proper directional edge within the + // channel graph. + if err = b.cfg.Graph.UpdateEdgePolicy(msg, op...); err != nil { + err := errors.Errorf("unable to add channel: %v", err) + log.Error(err) + return err + } + + log.Tracef("New channel update applied: %v", + lnutils.SpewLogClosure(msg)) + + b.stats.incNumChannelUpdates() + default: return errors.Errorf("wrong routing update message type") } @@ -1579,8 +1663,8 @@ func (b *Builder) SyncedHeight() uint32 { // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) GetChannelByID(chanID lnwire.ShortChannelID) ( models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { return b.cfg.Graph.FetchChannelEdgesByID(chanID.ToUint64()) } @@ -1613,12 +1697,12 @@ func (b *Builder) ForEachNode( // // NOTE: This method is part of the ChannelGraphSource interface. func (b *Builder) ForAllOutgoingChannels(cb func(kvdb.RTx, - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1) error) error { + models.ChannelEdgeInfo, models.ChannelEdgePolicy) error) error { return b.cfg.Graph.ForEachNodeChannel(b.cfg.SelfNode, func(tx kvdb.RTx, c models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy1, - _ *models.ChannelEdgePolicy1) error { + e models.ChannelEdgePolicy, + _ models.ChannelEdgePolicy) error { if e == nil { return fmt.Errorf("channel from self node " + diff --git a/graph/interfaces.go b/graph/interfaces.go index 09ab5f6a0..df8d16932 100644 --- a/graph/interfaces.go +++ b/graph/interfaces.go @@ -71,7 +71,7 @@ type ChannelGraphSource interface { // star-graph. ForAllOutgoingChannels(cb func(tx kvdb.RTx, c models.ChannelEdgeInfo, - e *models.ChannelEdgePolicy1) error) error + e models.ChannelEdgePolicy) error) error // CurrentBlockHeight returns the block height from POV of the router // subsystem. @@ -79,8 +79,8 @@ type ChannelGraphSource interface { // GetChannelByID return the channel by the channel id. GetChannelByID(chanID lnwire.ShortChannelID) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) // FetchLightningNode attempts to look up a target node by its identity // public key. channeldb.ErrGraphNodeNotFound is returned if the node @@ -178,6 +178,7 @@ type DB interface { HasChannelEdge(chanID uint64) (bool, bool, error) HasChannelEdge1(chanID uint64) (time.Time, time.Time, bool, bool, error) + HasChannelEdge2(chanID uint64) (uint32, uint32, bool, bool, error) // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. If the channel can't be @@ -191,7 +192,7 @@ type DB interface { // will be nil, and the ChannelEdgeInfo1 will only include the public // keys of each node. FetchChannelEdgesByID(chanID uint64) (models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) + models.ChannelEdgePolicy, models.ChannelEdgePolicy, error) // AddLightningNode adds a vertex/node to the graph database. If the // node is not in the database from before, this will add a new, @@ -225,7 +226,7 @@ type DB interface { // node's information. The node ordering is determined by the // lexicographical ordering of the identity public keys of the nodes on // either side of the channel. - UpdateEdgePolicy(edge *models.ChannelEdgePolicy1, + UpdateEdgePolicy(edge models.ChannelEdgePolicy, op ...batch.SchedulerOption) error // HasLightningNode determines if the graph has a vertex identified by @@ -258,8 +259,8 @@ type DB interface { // Unknown policies are passed into the callback as nil values. ForEachNodeChannel(nodePub route.Vertex, cb func(kvdb.RTx, models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1) error) error + models.ChannelEdgePolicy, + models.ChannelEdgePolicy) error) error // UpdateChannelEdge retrieves and update edge of the graph database. // Method only reserved for updating an edge info after its already been diff --git a/graph/notifications.go b/graph/notifications.go index 8ba68dd64..040f266dc 100644 --- a/graph/notifications.go +++ b/graph/notifications.go @@ -342,11 +342,13 @@ func addToTopologyChange(graph DB, update *TopologyChange, // Any new ChannelUpdateAnnouncements will generate a corresponding // ChannelEdgeUpdate notification. - case *models.ChannelEdgePolicy1: + case models.ChannelEdgePolicy: // We'll need to fetch the edge's information from the database // in order to get the information concerning which nodes are // being connected. - edgeInfo, _, _, err := graph.FetchChannelEdgesByID(m.ChannelID) + edgeInfo, _, _, err := graph.FetchChannelEdgesByID( + m.SCID().ToUint64(), + ) if err != nil { return errors.Errorf("unable fetch channel edge: %v", err) @@ -356,7 +358,7 @@ func addToTopologyChange(graph DB, update *TopologyChange, // the second node. sourceNode := edgeInfo.NodeKey1 connectingNode := edgeInfo.NodeKey2 - if m.ChannelFlags&lnwire.ChanUpdateDirection == 1 { + if !m.IsNode1() { sourceNode = edgeInfo.NodeKey2 connectingNode = edgeInfo.NodeKey1 } @@ -370,19 +372,20 @@ func addToTopologyChange(graph DB, update *TopologyChange, return err } + policy := m.ForwardingPolicy() edgeUpdate := &ChannelEdgeUpdate{ - ChanID: m.ChannelID, + ChanID: m.SCID().ToUint64(), ChanPoint: edgeInfo.GetChanPoint(), - TimeLockDelta: m.TimeLockDelta, + TimeLockDelta: policy.TimeLockDelta, Capacity: edgeInfo.GetCapacity(), - MinHTLC: m.MinHTLC, - MaxHTLC: m.MaxHTLC, - BaseFee: m.FeeBaseMSat, - FeeRate: m.FeeProportionalMillionths, + MinHTLC: policy.MinHTLC, + MaxHTLC: policy.MaxHTLC, + BaseFee: policy.BaseFee, + FeeRate: policy.FeeRate, AdvertisingNode: aNode, ConnectingNode: cNode, - Disabled: m.ChannelFlags&lnwire.ChanUpdateDisabled != 0, - ExtraOpaqueData: m.ExtraOpaqueData, + Disabled: m.IsDisabled(), + ExtraOpaqueData: m.ExtraData(), } // TODO(roasbeef): add bit to toggle diff --git a/graph/validation_barrier.go b/graph/validation_barrier.go index fcd40d0f2..bb431522a 100644 --- a/graph/validation_barrier.go +++ b/graph/validation_barrier.go @@ -144,7 +144,7 @@ func (v *ValidationBarrier) InitJobDependencies(job interface{}) { // These other types don't have any dependants, so no further // initialization needs to be done beyond just occupying a job slot. - case *models.ChannelEdgePolicy1: + case models.ChannelEdgePolicy: return case *lnwire.ChannelUpdate1: return @@ -186,14 +186,13 @@ func (v *ValidationBarrier) WaitForDependants(job interface{}) error { v.Lock() switch msg := job.(type) { - // Any ChannelUpdate or NodeAnnouncement jobs will need to wait on the - // completion of any active ChannelAnnouncement jobs related to them. - case *models.ChannelEdgePolicy1: - shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) - signals, ok = v.chanEdgeDependencies[shortID] + // Any ChannelUpdate1 or NodeAnnouncement1 jobs will need to wait on the + // completion of any active ChannelAnnouncement1 jobs related to them. + case models.ChannelEdgePolicy: + signals, ok = v.chanEdgeDependencies[msg.SCID()] - jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy1, scid=%v", - msg.ChannelID) + jobDesc = fmt.Sprintf("job=lnwire.ChannelEdgePolicy, scid=%v", + msg.SCID().ToUint64()) case *channeldb.LightningNode: vertex := route.Vertex(msg.PubKeyBytes) @@ -299,9 +298,8 @@ func (v *ValidationBarrier) SignalDependants(job interface{}, allow bool) { delete(v.nodeAnnDependencies, route.Vertex(msg.NodeID)) case *lnwire.ChannelUpdate1: delete(v.chanEdgeDependencies, msg.ShortChannelID) - case *models.ChannelEdgePolicy1: - shortID := lnwire.NewShortChanIDFromInt(msg.ChannelID) - delete(v.chanEdgeDependencies, shortID) + case models.ChannelEdgePolicy: + delete(v.chanEdgeDependencies, msg.SCID()) case *lnwire.AnnounceSignatures1: return diff --git a/lnrpc/invoicesrpc/addinvoice.go b/lnrpc/invoicesrpc/addinvoice.go index 7cc2223ab..a5d199b8e 100644 --- a/lnrpc/invoicesrpc/addinvoice.go +++ b/lnrpc/invoicesrpc/addinvoice.go @@ -766,7 +766,7 @@ type SelectHopHintsCfg struct { // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. FetchChannelEdgesByID func(chanID uint64) (models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, + models.ChannelEdgePolicy, models.ChannelEdgePolicy, error) // GetAlias allows the peer's alias SCID to be retrieved for private diff --git a/lnrpc/invoicesrpc/addinvoice_test.go b/lnrpc/invoicesrpc/addinvoice_test.go index c7960ea9b..9be854cdd 100644 --- a/lnrpc/invoicesrpc/addinvoice_test.go +++ b/lnrpc/invoicesrpc/addinvoice_test.go @@ -67,8 +67,8 @@ func (h *hopHintsConfigMock) FetchAllChannels() ([]*channeldb.OpenChannel, // FetchChannelEdgesByID attempts to lookup the two directed edges for // the channel identified by the channel ID. func (h *hopHintsConfigMock) FetchChannelEdgesByID(chanID uint64) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { args := h.Mock.Called(chanID) diff --git a/netann/chan_status_manager_test.go b/netann/chan_status_manager_test.go index 9617c0b34..1288056a8 100644 --- a/netann/chan_status_manager_test.go +++ b/netann/chan_status_manager_test.go @@ -170,7 +170,7 @@ func (g *mockGraph) FetchAllOpenChannels() ([]*channeldb.OpenChannel, error) { func (g *mockGraph) FetchChannelEdgesByOutpoint( op *wire.OutPoint) (models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) { + models.ChannelEdgePolicy, models.ChannelEdgePolicy, error) { g.mu.Lock() defer g.mu.Unlock() diff --git a/netann/channel_update.go b/netann/channel_update.go index 413314ba7..902b8a176 100644 --- a/netann/channel_update.go +++ b/netann/channel_update.go @@ -85,8 +85,7 @@ func SignChannelUpdate(signer lnwallet.MessageSigner, keyLoc keychain.KeyLocator // // NOTE: The passed policies can be nil. func ExtractChannelUpdate(ownerPubKey []byte, - info models.ChannelEdgeInfo, - policies ...*models.ChannelEdgePolicy1) ( + info models.ChannelEdgeInfo, policies ...models.ChannelEdgePolicy) ( *lnwire.ChannelUpdate1, error) { // Helper function to extract the owner of the given policy. @@ -108,7 +107,13 @@ func ExtractChannelUpdate(ownerPubKey []byte, // Extract the channel update from the policy we own, if any. for _, edge := range policies { - if edge != nil && bytes.Equal(ownerPubKey, owner(edge)) { + e, ok := edge.(*models.ChannelEdgePolicy1) + if !ok { + return nil, fmt.Errorf("expected "+ + "*models.ChannelEdgePolicy1, got: %T", edge) + } + + if edge != nil && bytes.Equal(ownerPubKey, owner(e)) { return ChannelUpdateFromEdge(info, edge) } } diff --git a/netann/interface.go b/netann/interface.go index 15b1abfc6..96d1abafc 100644 --- a/netann/interface.go +++ b/netann/interface.go @@ -20,5 +20,5 @@ type ChannelGraph interface { // FetchChannelEdgesByOutpoint returns the channel edge info and most // recent channel edge policies for a given outpoint. FetchChannelEdgesByOutpoint(*wire.OutPoint) (models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) + models.ChannelEdgePolicy, models.ChannelEdgePolicy, error) } diff --git a/peer/brontide.go b/peer/brontide.go index 89c2d4a04..adffc26d2 100644 --- a/peer/brontide.go +++ b/peer/brontide.go @@ -1051,25 +1051,27 @@ func (p *Brontide) loadActiveChannels(chans []*channeldb.OpenChannel) ( var forwardingPolicy *models.ForwardingPolicy if selfPolicy != nil { var inboundWireFee lnwire.Fee - _, err := selfPolicy.ExtraOpaqueData.ExtractRecords( - &inboundWireFee, - ) - if err != nil { - return nil, err + if pol, ok := selfPolicy.(*models.ChannelEdgePolicy1); ok { + _, err := pol.ExtraOpaqueData.ExtractRecords( + &inboundWireFee, + ) + if err != nil { + return nil, err + } } inboundFee := models.NewInboundFeeFromWire( inboundWireFee, ) + pol := selfPolicy.ForwardingPolicy() forwardingPolicy = &models.ForwardingPolicy{ - MinHTLCOut: selfPolicy.MinHTLC, - MaxHTLC: selfPolicy.MaxHTLC, - BaseFee: selfPolicy.FeeBaseMSat, - FeeRate: selfPolicy.FeeProportionalMillionths, - TimeLockDelta: uint32(selfPolicy.TimeLockDelta), - - InboundFee: inboundFee, + MinHTLCOut: pol.MinHTLC, + MaxHTLC: pol.MaxHTLC, + BaseFee: pol.BaseFee, + FeeRate: pol.FeeRate, + TimeLockDelta: uint32(pol.TimeLockDelta), + InboundFee: inboundFee, } } else { p.log.Warnf("Unable to find our forwarding policy "+ diff --git a/routing/blindedpath/blinded_path.go b/routing/blindedpath/blinded_path.go index 9b69851ff..2edd4a194 100644 --- a/routing/blindedpath/blinded_path.go +++ b/routing/blindedpath/blinded_path.go @@ -43,7 +43,7 @@ type BuildBlindedPathCfg struct { // FetchChannelEdgesByID attempts to look up the two directed edges for // the channel identified by the channel ID. FetchChannelEdgesByID func(chanID uint64) (models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1, *models.ChannelEdgePolicy1, error) + models.ChannelEdgePolicy, models.ChannelEdgePolicy, error) // FetchOurOpenChannels fetches this node's set of open channels. FetchOurOpenChannels func() ([]*channeldb.OpenChannel, error) @@ -648,16 +648,24 @@ func getNodeChannelPolicy(cfg *BuildBlindedPathCfg, chanID uint64, return nil, err } + var update1ToNode, update2ToNode [33]byte + if update1 != nil { + update1ToNode = update1.GetToNode() + } + if update2 != nil { + update2ToNode = update2.GetToNode() + } + // Now we need to determine which of the updates was created by the // node in question. We know the update is the correct one if the // "ToNode" for the fetched policy is _not_ equal to the node ID in // question. - var policy *models.ChannelEdgePolicy1 + var policy models.ChannelEdgePolicy switch { - case update1 != nil && !bytes.Equal(update1.ToNode[:], nodeID[:]): + case update1 != nil && !bytes.Equal(update1ToNode[:], nodeID[:]): policy = update1 - case update2 != nil && !bytes.Equal(update2.ToNode[:], nodeID[:]): + case update2 != nil && !bytes.Equal(update2ToNode[:], nodeID[:]): policy = update2 default: @@ -665,12 +673,14 @@ func getNodeChannelPolicy(cfg *BuildBlindedPathCfg, chanID uint64, "%s for channel %d", nodeID, chanID) } + fwdPolicy := policy.ForwardingPolicy() + return &BlindedHopPolicy{ - CLTVExpiryDelta: policy.TimeLockDelta, - FeeRate: uint32(policy.FeeProportionalMillionths), - BaseFee: policy.FeeBaseMSat, - MinHTLCMsat: policy.MinHTLC, - MaxHTLCMsat: policy.MaxHTLC, + CLTVExpiryDelta: fwdPolicy.TimeLockDelta, + FeeRate: uint32(fwdPolicy.FeeRate), + BaseFee: fwdPolicy.BaseFee, + MinHTLCMsat: fwdPolicy.MinHTLC, + MaxHTLCMsat: fwdPolicy.MaxHTLC, }, nil } diff --git a/routing/blindedpath/blinded_path_test.go b/routing/blindedpath/blinded_path_test.go index 4ee4f444f..548103f47 100644 --- a/routing/blindedpath/blinded_path_test.go +++ b/routing/blindedpath/blinded_path_test.go @@ -580,12 +580,12 @@ func TestBuildBlindedPath(t *testing.T) { }, } - realPolicies := map[uint64]*models.ChannelEdgePolicy1{ - chanCB: { + realPolicies := map[uint64]models.ChannelEdgePolicy{ + chanCB: &models.ChannelEdgePolicy1{ ChannelID: chanCB, ToNode: bob, }, - chanBA: { + chanBA: &models.ChannelEdgePolicy1{ ChannelID: chanBA, ToNode: alice, }, @@ -598,8 +598,8 @@ func TestBuildBlindedPath(t *testing.T) { return []*route.Route{realRoute}, nil }, FetchChannelEdgesByID: func(chanID uint64) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { return nil, realPolicies[chanID], nil, nil }, @@ -766,8 +766,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { return []*route.Route{realRoute}, nil }, FetchChannelEdgesByID: func(chanID uint64) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { policy, ok := realPolicies[chanID] if !ok { @@ -937,8 +937,8 @@ func TestBuildBlindedPathWithDummyHops(t *testing.T) { nil }, FetchChannelEdgesByID: func(chanID uint64) ( - models.ChannelEdgeInfo, *models.ChannelEdgePolicy1, - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgeInfo, models.ChannelEdgePolicy, + models.ChannelEdgePolicy, error) { // Force the call to error for the first 2 channels. if errCount < 2 { diff --git a/routing/localchans/manager.go b/routing/localchans/manager.go index 072b5eb9f..8ae46f8bd 100644 --- a/routing/localchans/manager.go +++ b/routing/localchans/manager.go @@ -33,7 +33,7 @@ type Manager struct { // channels. ForAllOutgoingChannels func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1) error) error + models.ChannelEdgePolicy) error) error // FetchChannel is used to query local channel parameters. Optionally an // existing db tx can be supplied. @@ -75,7 +75,7 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, err := r.ForAllOutgoingChannels(func( tx kvdb.RTx, info models.ChannelEdgeInfo, - edge *models.ChannelEdgePolicy1) error { + edge models.ChannelEdgePolicy) error { var chanPoint = info.GetChanPoint() @@ -108,21 +108,23 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, Edge: edge, }) - // Extract inbound fees from the ExtraOpaqueData. var inboundWireFee lnwire.Fee - _, err = edge.ExtraOpaqueData.ExtractRecords(&inboundWireFee) - if err != nil { - return err + if pol, ok := edge.(*models.ChannelEdgePolicy1); ok { + _, err = pol.ExtraOpaqueData.ExtractRecords(&inboundWireFee) + if err != nil { + return err + } } inboundFee := models.NewInboundFeeFromWire(inboundWireFee) // Add updated policy to list of policies to send to switch. + fwdPol := edge.ForwardingPolicy() policiesToUpdate[chanPoint] = models.ForwardingPolicy{ - BaseFee: edge.FeeBaseMSat, - FeeRate: edge.FeeProportionalMillionths, - TimeLockDelta: uint32(edge.TimeLockDelta), - MinHTLCOut: edge.MinHTLC, - MaxHTLC: edge.MaxHTLC, + BaseFee: fwdPol.BaseFee, + FeeRate: fwdPol.FeeRate, + TimeLockDelta: uint32(fwdPol.TimeLockDelta), + MinHTLCOut: fwdPol.MinHTLC, + MaxHTLC: fwdPol.MaxHTLC, InboundFee: inboundFee, } @@ -184,86 +186,162 @@ func (r *Manager) UpdatePolicy(newSchema routing.ChannelPolicy, // updateEdge updates the given edge with the new schema. func (r *Manager) updateEdge(tx kvdb.RTx, chanPoint wire.OutPoint, - edge *models.ChannelEdgePolicy1, + edgePolicy models.ChannelEdgePolicy, newSchema routing.ChannelPolicy) error { - // Update forwarding fee scheme and required time lock delta. - edge.FeeBaseMSat = newSchema.BaseFee - edge.FeeProportionalMillionths = lnwire.MilliSatoshi( - newSchema.FeeRate, - ) + switch edge := edgePolicy.(type) { + case *models.ChannelEdgePolicy1: + // Update forwarding fee scheme and required time lock delta. + edge.FeeBaseMSat = newSchema.BaseFee + edge.FeeProportionalMillionths = lnwire.MilliSatoshi( + newSchema.FeeRate, + ) - // If inbound fees are set, we update the edge with them. - err := fn.MapOptionZ(newSchema.InboundFee, - func(f models.InboundFee) error { - inboundWireFee := f.ToWire() - return edge.ExtraOpaqueData.PackRecords( - &inboundWireFee, + // If inbound fees are set, we update the edge with them. + err := fn.MapOptionZ(newSchema.InboundFee, + func(f models.InboundFee) error { + inboundWireFee := f.ToWire() + return edge.ExtraOpaqueData.PackRecords( + &inboundWireFee, + ) + }) + if err != nil { + return err + } + + edge.TimeLockDelta = uint16(newSchema.TimeLockDelta) + + // Retrieve negotiated channel htlc amt limits. + amtMin, amtMax, err := r.getHtlcAmtLimits(tx, chanPoint) + if err != nil { + return err + } + + // We now update the edge max htlc value. + switch { + // If a non-zero max htlc was specified, use it to update the + // edge. Otherwise, keep the value unchanged. + case newSchema.MaxHTLC != 0: + edge.MaxHTLC = newSchema.MaxHTLC + + // If this edge still doesn't have a max htlc set, set it to the + // max. This is an on-the-fly migration. + case !edge.MessageFlags.HasMaxHtlc(): + edge.MaxHTLC = amtMax + + // If this edge has a max htlc that exceeds what the channel can + // actually carry, correct it now. This can happen, because we + // previously set the max htlc to the channel capacity. + case edge.MaxHTLC > amtMax: + edge.MaxHTLC = amtMax + } + + // If a new min htlc is specified, update the edge. + if newSchema.MinHTLC != nil { + edge.MinHTLC = *newSchema.MinHTLC + } + + // If the MaxHtlc flag wasn't already set, we can set it now. + edge.MessageFlags |= lnwire.ChanUpdateRequiredMaxHtlc + + // Validate htlc amount constraints. + switch { + case edge.MinHTLC < amtMin: + return fmt.Errorf("min htlc amount of %v is below "+ + "min htlc parameter of %v", edge.MinHTLC, + amtMin) + + case edge.MaxHTLC > amtMax: + return fmt.Errorf("max htlc size of %v is above max "+ + "pending amount of %v", edge.MaxHTLC, amtMax) + + case edge.MinHTLC > edge.MaxHTLC: + return fmt.Errorf( + "min_htlc %v greater than max_htlc %v", + edge.MinHTLC, edge.MaxHTLC, ) - }) - if err != nil { - return err + } + + // Clear signature to help prevent usage of the previous + // signature. + edge.SetSigBytes(nil) + + case *models.ChannelEdgePolicy2: + // Update forwarding fee scheme and required time lock delta. + edge.FeeBaseMsat.Val = uint32(newSchema.BaseFee) + edge.FeeProportionalMillionths.Val = newSchema.FeeRate + edge.CLTVExpiryDelta.Val = uint16(newSchema.TimeLockDelta) + + // If inbound fees are set, we update the edge with them. + err := fn.MapOptionZ(newSchema.InboundFee, + func(f models.InboundFee) error { + inboundWireFee := f.ToWire() + return edge.ExtraOpaqueData.PackRecords( + &inboundWireFee, + ) + }) + if err != nil { + return err + } + + // Retrieve negotiated channel htlc amt limits. + amtMin, amtMax, err := r.getHtlcAmtLimits(tx, chanPoint) + if err != nil { + return err + } + + // We now update the edge max htlc value. + switch { + // If a non-zero max htlc was specified, use it to update the + // edge. + // Otherwise keep the value unchanged. + case newSchema.MaxHTLC != 0: + edge.HTLCMaximumMsat.Val = newSchema.MaxHTLC + + // If this edge has a max htlc that exceeds what the channel can + // actually carry, correct it now. This can happen, because we + // previously set the max htlc to the channel capacity. + case edge.HTLCMaximumMsat.Val > amtMax: + edge.HTLCMaximumMsat.Val = amtMax + } + + // If a new min htlc is specified, update the edge. + if newSchema.MinHTLC != nil { + edge.HTLCMinimumMsat.Val = *newSchema.MinHTLC + } + + // Validate htlc amount constraints. + switch { + case edge.HTLCMinimumMsat.Val < amtMin: + return fmt.Errorf( + "min htlc amount of %v is below min htlc "+ + "parameter of %v", edge.HTLCMinimumMsat, + amtMin, + ) + + case edge.HTLCMaximumMsat.Val > amtMax: + return fmt.Errorf( + "max htlc size of %v is above max pending "+ + "amount of %v", edge.HTLCMaximumMsat, + amtMax, + ) + + case edge.HTLCMinimumMsat.Val > edge.HTLCMaximumMsat.Val: + return fmt.Errorf( + "min_htlc %v greater than max_htlc %v", + edge.HTLCMinimumMsat, edge.HTLCMaximumMsat, + ) + } + + // Clear signature to help prevent usage of the previous + // signature. + edge.Signature = lnwire.Sig{} + + default: + return fmt.Errorf("unhandled implementation of "+ + "models.ChannelEdgePolicy: %T", edgePolicy) } - edge.TimeLockDelta = uint16(newSchema.TimeLockDelta) - - // Retrieve negotiated channel htlc amt limits. - amtMin, amtMax, err := r.getHtlcAmtLimits(tx, chanPoint) - if err != nil { - return err - } - - // We now update the edge max htlc value. - switch { - // If a non-zero max htlc was specified, use it to update the edge. - // Otherwise keep the value unchanged. - case newSchema.MaxHTLC != 0: - edge.MaxHTLC = newSchema.MaxHTLC - - // If this edge still doesn't have a max htlc set, set it to the max. - // This is an on-the-fly migration. - case !edge.MessageFlags.HasMaxHtlc(): - edge.MaxHTLC = amtMax - - // If this edge has a max htlc that exceeds what the channel can - // actually carry, correct it now. This can happen, because we - // previously set the max htlc to the channel capacity. - case edge.MaxHTLC > amtMax: - edge.MaxHTLC = amtMax - } - - // If a new min htlc is specified, update the edge. - if newSchema.MinHTLC != nil { - edge.MinHTLC = *newSchema.MinHTLC - } - - // If the MaxHtlc flag wasn't already set, we can set it now. - edge.MessageFlags |= lnwire.ChanUpdateRequiredMaxHtlc - - // Validate htlc amount constraints. - switch { - case edge.MinHTLC < amtMin: - return fmt.Errorf( - "min htlc amount of %v is below min htlc parameter of %v", - edge.MinHTLC, amtMin, - ) - - case edge.MaxHTLC > amtMax: - return fmt.Errorf( - "max htlc size of %v is above max pending amount of %v", - edge.MaxHTLC, amtMax, - ) - - case edge.MinHTLC > edge.MaxHTLC: - return fmt.Errorf( - "min_htlc %v greater than max_htlc %v", - edge.MinHTLC, edge.MaxHTLC, - ) - } - - // Clear signature to help prevent usage of the previous signature. - edge.SetSigBytes(nil) - return nil } diff --git a/routing/localchans/manager_test.go b/routing/localchans/manager_test.go index 30b23c1b7..9e9d6a511 100644 --- a/routing/localchans/manager_test.go +++ b/routing/localchans/manager_test.go @@ -85,7 +85,9 @@ func TestManager(t *testing.T) { } for _, edge := range edgesToUpdate { - policy := edge.Edge + policy, ok := edge.Edge.(*models.ChannelEdgePolicy1) + require.True(t, ok) + if !policy.MessageFlags.HasMaxHtlc() { t.Fatal("expected max htlc flag") } @@ -108,7 +110,7 @@ func TestManager(t *testing.T) { forAllOutgoingChannels := func(cb func(kvdb.RTx, models.ChannelEdgeInfo, - *models.ChannelEdgePolicy1) error) error { + models.ChannelEdgePolicy) error) error { for _, c := range channelSet { if err := cb(nil, c.edgeInfo, ¤tPolicy); err != nil { diff --git a/routing/pathfind_test.go b/routing/pathfind_test.go index 8af7dd952..0a59104f4 100644 --- a/routing/pathfind_test.go +++ b/routing/pathfind_test.go @@ -2026,9 +2026,12 @@ func runRouteFailMaxHTLC(t *testing.T, useCache bool) { graph := ctx.testGraphInstance.graph _, midEdge, _, err := graph.FetchChannelEdgesByID(firstToSecondID) require.NoError(t, err, "unable to fetch channel edges by ID") - midEdge.MessageFlags = 1 - midEdge.MaxHTLC = payAmt - 1 - if err := graph.UpdateEdgePolicy(midEdge); err != nil { + + midEdgePol, ok := midEdge.(*models.ChannelEdgePolicy1) + require.True(t, ok) + midEdgePol.MessageFlags = 1 + midEdgePol.MaxHTLC = payAmt - 1 + if err := graph.UpdateEdgePolicy(midEdgePol); err != nil { t.Fatalf("unable to update edge: %v", err) } @@ -2067,8 +2070,16 @@ func runRouteFailDisabledEdge(t *testing.T, useCache bool) { // path finding, as we don't consider the disable flag for local // channels (and roasbeef is the source). roasToPham := uint64(999991) - _, e1, e2, err := graph.graph.FetchChannelEdgesByID(roasToPham) + + _, edge1, edge2, err := graph.graph.FetchChannelEdgesByID(roasToPham) require.NoError(t, err, "unable to fetch edge") + + e1, ok := edge1.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + e2, ok := edge2.(*models.ChannelEdgePolicy1) + require.True(t, ok) + e1.ChannelFlags |= lnwire.ChanUpdateDisabled if err := graph.graph.UpdateEdgePolicy(e1); err != nil { t.Fatalf("unable to update edge: %v", err) @@ -2088,8 +2099,12 @@ func runRouteFailDisabledEdge(t *testing.T, useCache bool) { // Now, we'll modify the edge from phamnuwen -> sophon, to read that // it's disabled. phamToSophon := uint64(99999) - _, e, _, err := graph.graph.FetchChannelEdgesByID(phamToSophon) + _, edge, _, err := graph.graph.FetchChannelEdgesByID(phamToSophon) require.NoError(t, err, "unable to fetch edge") + + e, ok := edge.(*models.ChannelEdgePolicy1) + require.True(t, ok) + e.ChannelFlags |= lnwire.ChanUpdateDisabled if err := graph.graph.UpdateEdgePolicy(e); err != nil { t.Fatalf("unable to update edge: %v", err) @@ -2169,8 +2184,15 @@ func runPathSourceEdgesBandwidth(t *testing.T, useCache bool) { // Finally, set the roasbeef->songoku bandwidth, but also set its // disable flag. bandwidths.hints[roasToSongoku] = 2 * payAmt - _, e1, e2, err := graph.graph.FetchChannelEdgesByID(roasToSongoku) + _, edge1, edge2, err := graph.graph.FetchChannelEdgesByID(roasToSongoku) require.NoError(t, err, "unable to fetch edge") + + e1, ok := edge1.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + e2, ok := edge2.(*models.ChannelEdgePolicy1) + require.True(t, ok) + e1.ChannelFlags |= lnwire.ChanUpdateDisabled if err := graph.graph.UpdateEdgePolicy(e1); err != nil { t.Fatalf("unable to update edge: %v", err) diff --git a/routing/router_test.go b/routing/router_test.go index 0f7528455..cab0bff27 100644 --- a/routing/router_test.go +++ b/routing/router_test.go @@ -31,6 +31,7 @@ import ( "github.com/lightningnetwork/lnd/lntypes" "github.com/lightningnetwork/lnd/lnwallet" "github.com/lightningnetwork/lnd/lnwire" + "github.com/lightningnetwork/lnd/netann" "github.com/lightningnetwork/lnd/record" "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/zpay32" @@ -450,13 +451,17 @@ func TestChannelUpdateValidation(t *testing.T) { ctx := createTestCtxFromGraphInstance(t, startingBlockHeight, testGraph) // Assert that the initially configured fee is retrieved correctly. - _, e1, e2, err := ctx.graph.FetchChannelEdgesByID( + _, edge1, edge2, err := ctx.graph.FetchChannelEdgesByID( lnwire.NewShortChanIDFromInt(1).ToUint64(), ) require.NoError(t, err, "cannot retrieve channel") - require.Equal(t, feeRate, e1.FeeProportionalMillionths, "invalid fee") - require.Equal(t, feeRate, e2.FeeProportionalMillionths, "invalid fee") + require.Equal( + t, feeRate, edge1.ForwardingPolicy().FeeRate, "invalid fee", + ) + require.Equal( + t, feeRate, edge2.ForwardingPolicy().FeeRate, "invalid fee", + ) // Setup a route from source a to destination c. The route will be used // in a call to SendToRoute. SendToRoute also applies channel updates, @@ -482,6 +487,12 @@ func TestChannelUpdateValidation(t *testing.T) { ) require.NoError(t, err, "unable to create route") + e1, ok := edge1.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + e2, ok := edge2.(*models.ChannelEdgePolicy1) + require.True(t, ok) + // Set up a channel update message with an invalid signature to be // returned to the sender. var invalidSignature lnwire.Sig @@ -522,11 +533,17 @@ func TestChannelUpdateValidation(t *testing.T) { _, err = ctx.router.SendToRoute(payment, rt) require.Error(t, err, "expected route to fail with channel update") - _, e1, e2, err = ctx.graph.FetchChannelEdgesByID( + _, edge1, edge2, err = ctx.graph.FetchChannelEdgesByID( lnwire.NewShortChanIDFromInt(1).ToUint64(), ) require.NoError(t, err, "cannot retrieve channel") + _, ok = edge1.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + e2, ok = edge2.(*models.ChannelEdgePolicy1) + require.True(t, ok) + require.Equal(t, feeRate, e1.FeeProportionalMillionths, "fee updated without valid signature") require.Equal(t, feeRate, e2.FeeProportionalMillionths, @@ -544,11 +561,17 @@ func TestChannelUpdateValidation(t *testing.T) { // This time a valid signature was supplied and the policy change should // have been applied to the graph. - _, e1, e2, err = ctx.graph.FetchChannelEdgesByID( + _, edge1, edge2, err = ctx.graph.FetchChannelEdgesByID( lnwire.NewShortChanIDFromInt(1).ToUint64(), ) require.NoError(t, err, "cannot retrieve channel") + e1, ok = edge1.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + e2, ok = edge2.(*models.ChannelEdgePolicy1) + require.True(t, ok) + require.Equal(t, feeRate, e1.FeeProportionalMillionths, "fee should not be updated") require.EqualValues(t, 500, int(e2.FeeProportionalMillionths), @@ -590,21 +613,15 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { ) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate1{ - ShortChannelID: lnwire.NewShortChanIDFromInt( - songokuSophonChanID, - ), - Timestamp: uint32(edgeUpdateToFail.LastUpdate.Unix()), - MessageFlags: edgeUpdateToFail.MessageFlags, - ChannelFlags: edgeUpdateToFail.ChannelFlags, - TimeLockDelta: edgeUpdateToFail.TimeLockDelta, - HtlcMinimumMsat: edgeUpdateToFail.MinHTLC, - HtlcMaximumMsat: edgeUpdateToFail.MaxHTLC, - BaseFee: uint32(edgeUpdateToFail.FeeBaseMSat), - FeeRate: uint32(edgeUpdateToFail.FeeProportionalMillionths), - } + edgeUpdToFail, ok := edgeUpdateToFail.(*models.ChannelEdgePolicy1) + require.True(t, ok) - signErrChanUpdate(t, ctx.privKeys["songoku"], &errChanUpdate) + errChanUpdate, err := netann.UnsignedChannelUpdateFromEdge( + chainhash.Hash{}, edgeUpdToFail, + ) + require.NoError(t, err) + + signErrChanUpdate(t, ctx.privKeys["songoku"], errChanUpdate) // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to Son goku. This will be a fee related error, so @@ -625,7 +642,7 @@ func TestSendPaymentErrorRepeatedFeeInsufficient(t *testing.T) { // reflect the new fee schedule for the // node/channel. &lnwire.FailFeeInsufficient{ - Update: errChanUpdate, + Update: *errChanUpdate, }, 1, ) } @@ -936,17 +953,13 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { _, _, edgeUpdateToFail, err := ctx.graph.FetchChannelEdgesByID(chanID) require.NoError(t, err, "unable to fetch chan id") - errChanUpdate := lnwire.ChannelUpdate1{ - ShortChannelID: lnwire.NewShortChanIDFromInt(chanID), - Timestamp: uint32(edgeUpdateToFail.LastUpdate.Unix()), - MessageFlags: edgeUpdateToFail.MessageFlags, - ChannelFlags: edgeUpdateToFail.ChannelFlags, - TimeLockDelta: edgeUpdateToFail.TimeLockDelta, - HtlcMinimumMsat: edgeUpdateToFail.MinHTLC, - HtlcMaximumMsat: edgeUpdateToFail.MaxHTLC, - BaseFee: uint32(edgeUpdateToFail.FeeBaseMSat), - FeeRate: uint32(edgeUpdateToFail.FeeProportionalMillionths), - } + edgeUpdToFail, ok := edgeUpdateToFail.(*models.ChannelEdgePolicy1) + require.True(t, ok) + + errChanUpdate, err := netann.UnsignedChannelUpdateFromEdge( + chainhash.Hash{}, edgeUpdToFail, + ) + require.NoError(t, err) // We'll now modify the SendToSwitch method to return an error for the // outgoing channel to son goku. Since this is a time lock related @@ -957,7 +970,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailExpiryTooSoon{ - Update: errChanUpdate, + Update: *errChanUpdate, }, 1, ) } @@ -1005,7 +1018,7 @@ func TestSendPaymentErrorNonFinalTimeLockErrors(t *testing.T) { if firstHop == roasbeefSongoku { return [32]byte{}, htlcswitch.NewForwardingError( &lnwire.FailIncorrectCltvExpiry{ - Update: errChanUpdate, + Update: *errChanUpdate, }, 1, ) } diff --git a/rpcserver.go b/rpcserver.go index 55e5a060d..2396db13e 100644 --- a/rpcserver.go +++ b/rpcserver.go @@ -6221,7 +6221,7 @@ func (r *rpcServer) DescribeGraph(ctx context.Context, // similar response which details both the edge information as well as // the routing policies of th nodes connecting the two edges. err = graph.ForEachChannel(func(edgeInfo models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy1) error { + c1, c2 models.ChannelEdgePolicy) error { // Do not include unannounced channels unless specifically // requested. Unannounced channels include both private channels as @@ -6455,7 +6455,7 @@ func (r *rpcServer) GetChanInfo(_ context.Context, var ( edgeInfo models.ChannelEdgeInfo - edge1, edge2 *models.ChannelEdgePolicy1 + edge1, edge2 models.ChannelEdgePolicy err error ) @@ -6528,7 +6528,7 @@ func (r *rpcServer) GetNodeInfo(ctx context.Context, err = graph.ForEachNodeChannel(node.PubKeyBytes, func(_ kvdb.RTx, edge models.ChannelEdgeInfo, - c1, c2 *models.ChannelEdgePolicy1) error { + c1, c2 models.ChannelEdgePolicy) error { numChannels++ totalCapacity += edge.GetCapacity() @@ -7191,7 +7191,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, var feeReports []*lnrpc.ChannelFeeReport err = channelGraph.ForEachNodeChannel(selfNode.PubKeyBytes, func(_ kvdb.RTx, chanInfo models.ChannelEdgeInfo, - edgePolicy, _ *models.ChannelEdgePolicy1) error { + edgePolicy, _ models.ChannelEdgePolicy) error { // Self node should always have policies for its // channels. @@ -7205,17 +7205,18 @@ func (r *rpcServer) FeeReport(ctx context.Context, // rate. The fee rate field in the database the amount // of mSAT charged per 1mil mSAT sent, so will divide by // this to get the proper fee rate. - feeRateFixedPoint := - edgePolicy.FeeProportionalMillionths + fwdPol := edgePolicy.ForwardingPolicy() + feeRateFixedPoint := fwdPol.FeeRate feeRate := float64(feeRateFixedPoint) / feeBase // Decode inbound fee from extra data. var inboundFee lnwire.Fee - _, err := edgePolicy.ExtraOpaqueData.ExtractRecords( - &inboundFee, - ) - if err != nil { - return err + + if pol, ok := edgePolicy.(*models.ChannelEdgePolicy1); ok { + _, err := pol.ExtraOpaqueData.ExtractRecords(&inboundFee) + if err != nil { + return err + } } // TODO(roasbeef): also add stats for revenue for each @@ -7223,7 +7224,7 @@ func (r *rpcServer) FeeReport(ctx context.Context, feeReports = append(feeReports, &lnrpc.ChannelFeeReport{ ChanId: chanInfo.GetChanID(), ChannelPoint: chanInfo.GetChanPoint().String(), - BaseFeeMsat: int64(edgePolicy.FeeBaseMSat), + BaseFeeMsat: int64(fwdPol.BaseFee), FeePerMil: int64(feeRateFixedPoint), FeeRate: feeRate, diff --git a/server.go b/server.go index 750540e60..763c1e618 100644 --- a/server.go +++ b/server.go @@ -1290,7 +1290,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, // Wrap the DeleteChannelEdges method so that the funding manager can // use it without depending on several layers of indirection. deleteAliasEdge := func(scid lnwire.ShortChannelID) ( - *models.ChannelEdgePolicy1, error) { + models.ChannelEdgePolicy, error) { info, e1, e2, err := s.graphDB.FetchChannelEdgesByID( scid.ToUint64(), @@ -1309,7 +1309,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, var ourKey [33]byte copy(ourKey[:], nodeKeyDesc.PubKey.SerializeCompressed()) - var ourPolicy *models.ChannelEdgePolicy1 + var ourPolicy models.ChannelEdgePolicy if info != nil && info.Node1Bytes() == ourKey { ourPolicy = e1 } else { @@ -1324,6 +1324,7 @@ func newServer(cfg *Config, listenAddrs []net.Addr, err = s.graphDB.DeleteChannelEdges( false, false, scid.ToUint64(), ) + return ourPolicy, err } @@ -3230,7 +3231,7 @@ func (s *server) establishPersistentConnections() error { err = s.graphDB.ForEachNodeChannel(sourceNode.PubKeyBytes, func( tx kvdb.RTx, chanInfo models.ChannelEdgeInfo, - policy, _ *models.ChannelEdgePolicy1) error { + policy, _ models.ChannelEdgePolicy) error { chanPoint := chanInfo.GetChanPoint()