diff --git a/lntest/itest/assertions.go b/lntest/itest/assertions.go index 160d65a9c..fe881f1d1 100644 --- a/lntest/itest/assertions.go +++ b/lntest/itest/assertions.go @@ -686,7 +686,7 @@ out: continue } - err := checkChannelPolicy( + err := lntest.CheckChannelPolicy( update.RoutingPolicy, exp.expectedPolicy, ) @@ -810,44 +810,13 @@ func assertChannelPolicy(t *harnessTest, node *lntest.HarnessNode, policies := getChannelPolicies(t, node, advertisingNode, chanPoints...) for _, policy := range policies { - err := checkChannelPolicy(policy, expectedPolicy) + err := lntest.CheckChannelPolicy(policy, expectedPolicy) if err != nil { t.Fatalf(err.Error()) } } } -// checkChannelPolicy checks that the policy matches the expected one. -func checkChannelPolicy(policy, expectedPolicy *lnrpc.RoutingPolicy) error { - if policy.FeeBaseMsat != expectedPolicy.FeeBaseMsat { - return fmt.Errorf("expected base fee %v, got %v", - expectedPolicy.FeeBaseMsat, policy.FeeBaseMsat) - } - if policy.FeeRateMilliMsat != expectedPolicy.FeeRateMilliMsat { - return fmt.Errorf("expected fee rate %v, got %v", - expectedPolicy.FeeRateMilliMsat, - policy.FeeRateMilliMsat) - } - if policy.TimeLockDelta != expectedPolicy.TimeLockDelta { - return fmt.Errorf("expected time lock delta %v, got %v", - expectedPolicy.TimeLockDelta, - policy.TimeLockDelta) - } - if policy.MinHtlc != expectedPolicy.MinHtlc { - return fmt.Errorf("expected min htlc %v, got %v", - expectedPolicy.MinHtlc, policy.MinHtlc) - } - if policy.MaxHtlcMsat != expectedPolicy.MaxHtlcMsat { - return fmt.Errorf("expected max htlc %v, got %v", - expectedPolicy.MaxHtlcMsat, policy.MaxHtlcMsat) - } - if policy.Disabled != expectedPolicy.Disabled { - return errors.New("edge should be disabled but isn't") - } - - return nil -} - // assertMinerBlockHeightDelta ensures that tempMiner is 'delta' blocks ahead // of miner. func assertMinerBlockHeightDelta(t *harnessTest, diff --git a/lntest/node.go b/lntest/node.go index d4207ed45..71d2044e2 100644 --- a/lntest/node.go +++ b/lntest/node.go @@ -341,6 +341,21 @@ func (cfg NodeConfig) genArgs() []string { return args } +// policyUpdateMap defines a type to store channel policy updates. It has the +// format, +// { +// "chanPoint1": { +// "advertisingNode1": [ +// policy1, policy2, ... +// ], +// "advertisingNode2": [ +// policy1, policy2, ... +// ] +// }, +// "chanPoint2": ... +// } +type policyUpdateMap map[string]map[string][]*lnrpc.RoutingPolicy + // HarnessNode represents an instance of lnd running within our test network // harness. Each HarnessNode instance also fully embeds an RPC client in // order to pragmatically drive the node. @@ -376,6 +391,10 @@ type HarnessNode struct { closedChans map[wire.OutPoint]struct{} closeChanWatchers map[wire.OutPoint][]chan struct{} + // policyUpdates stores a slice of seen polices by each advertising + // node and the outpoint. + policyUpdates policyUpdateMap + quit chan struct{} wg sync.WaitGroup @@ -451,6 +470,8 @@ func newNode(cfg NodeConfig) (*HarnessNode, error) { closedChans: make(map[wire.OutPoint]struct{}), closeChanWatchers: make(map[wire.OutPoint][]chan struct{}), + + policyUpdates: policyUpdateMap{}, }, nil } @@ -1306,6 +1327,10 @@ const ( // watchCloseChannel specifies that this is a request to watch a close // channel event. watchCloseChannel + + // watchPolicyUpdate specifies that this is a request to watch a policy + // update event. + watchPolicyUpdate ) // closeChanWatchRequest is a request to the lightningNetworkWatcher to be @@ -1317,6 +1342,10 @@ type chanWatchRequest struct { chanWatchType chanWatchType eventChan chan struct{} + + advertisingNode string + policy *lnrpc.RoutingPolicy + includeUnannounced bool } // getChanPointFundingTxid returns the given channel point's funding txid in @@ -1407,6 +1436,9 @@ func (hn *HarnessNode) lightningNetworkWatcher() { case watchCloseChannel: hn.handleCloseChannelWatchRequest(watchRequest) + + case watchPolicyUpdate: + hn.handlePolicyUpdateWatchRequest(watchRequest) } case <-hn.quit: @@ -1475,6 +1507,47 @@ func (hn *HarnessNode) WaitForNetworkChannelClose(ctx context.Context, } } +// WaitForChannelPolicyUpdate will block until a channel policy with the target +// outpoint and advertisingNode is seen within the network. +func (hn *HarnessNode) WaitForChannelPolicyUpdate(ctx context.Context, + advertisingNode string, policy *lnrpc.RoutingPolicy, + chanPoint *lnrpc.ChannelPoint, includeUnannounced bool) error { + + eventChan := make(chan struct{}) + + op, err := MakeOutpoint(chanPoint) + if err != nil { + return fmt.Errorf("failed to create outpoint for %v"+ + "got err: %v", chanPoint, err) + } + + ticker := time.NewTicker(wait.PollInterval) + defer ticker.Stop() + + for { + select { + // Send a watch request every second. + case <-ticker.C: + hn.chanWatchRequests <- &chanWatchRequest{ + chanPoint: op, + eventChan: eventChan, + chanWatchType: watchPolicyUpdate, + policy: policy, + advertisingNode: advertisingNode, + includeUnannounced: includeUnannounced, + } + + case <-eventChan: + return nil + + case <-ctx.Done(): + return fmt.Errorf("channel:%s policy not updated "+ + "before timeout: [%s:%v] %s", op, + advertisingNode, policy, hn.String()) + } + } +} + // WaitForBlockchainSync waits for the target node to be fully synchronized with // the blockchain. If the passed context object has a set timeout, it will // continually poll until the timeout has elapsed. In the case that the chain @@ -1581,6 +1654,25 @@ func (hn *HarnessNode) handleChannelEdgeUpdates( close(eventChan) } delete(hn.openChanWatchers, op) + + // Check whether there's a routing policy update. If so, save + // it to the node state. + if newChan.RoutingPolicy == nil { + continue + } + + // Append the policy to the slice. + node := newChan.AdvertisingNode + policies := hn.policyUpdates[op.String()] + + // If the map[op] is nil, we need to initialize the map first. + if policies == nil { + policies = make(map[string][]*lnrpc.RoutingPolicy) + } + policies[node] = append( + policies[node], newChan.RoutingPolicy, + ) + hn.policyUpdates[op.String()] = policies } } @@ -1754,3 +1846,110 @@ func (hn *HarnessNode) receiveTopologyClientStream( } } } + +// CheckChannelPolicy checks that the policy matches the expected one. +func CheckChannelPolicy(policy, expectedPolicy *lnrpc.RoutingPolicy) error { + if policy.FeeBaseMsat != expectedPolicy.FeeBaseMsat { + return fmt.Errorf("expected base fee %v, got %v", + expectedPolicy.FeeBaseMsat, policy.FeeBaseMsat) + } + if policy.FeeRateMilliMsat != expectedPolicy.FeeRateMilliMsat { + return fmt.Errorf("expected fee rate %v, got %v", + expectedPolicy.FeeRateMilliMsat, + policy.FeeRateMilliMsat) + } + if policy.TimeLockDelta != expectedPolicy.TimeLockDelta { + return fmt.Errorf("expected time lock delta %v, got %v", + expectedPolicy.TimeLockDelta, + policy.TimeLockDelta) + } + if policy.MinHtlc != expectedPolicy.MinHtlc { + return fmt.Errorf("expected min htlc %v, got %v", + expectedPolicy.MinHtlc, policy.MinHtlc) + } + if policy.MaxHtlcMsat != expectedPolicy.MaxHtlcMsat { + return fmt.Errorf("expected max htlc %v, got %v", + expectedPolicy.MaxHtlcMsat, policy.MaxHtlcMsat) + } + if policy.Disabled != expectedPolicy.Disabled { + return errors.New("edge should be disabled but isn't") + } + + return nil +} + +// handlePolicyUpdateWatchRequest checks that if the expected policy can be +// found either in the node's interval state or describe graph response. If +// found, it will signal the request by closing the event channel. Otherwise it +// does nothing but returns nil. +func (hn *HarnessNode) handlePolicyUpdateWatchRequest(req *chanWatchRequest) { + op := req.chanPoint + + // Get a list of known policies for this chanPoint+advertisingNode + // combination. Start searching in the node state first. + policies, ok := hn.policyUpdates[op.String()][req.advertisingNode] + + if !ok { + // If it cannot be found in the node state, try searching it + // from the node's DescribeGraph. + policyMap := hn.getChannelPolicies(req.includeUnannounced) + policies, ok = policyMap[op.String()][req.advertisingNode] + if !ok { + return + } + } + + // Check if there's a matched policy. + for _, policy := range policies { + if CheckChannelPolicy(policy, req.policy) == nil { + close(req.eventChan) + return + } + } +} + +// getChannelPolicies queries the channel graph and formats the policies into +// the format defined in type policyUpdateMap. +func (hn *HarnessNode) getChannelPolicies(include bool) policyUpdateMap { + + ctxt, cancel := context.WithTimeout( + context.Background(), DefaultTimeout, + ) + defer cancel() + + graph, err := hn.DescribeGraph(ctxt, &lnrpc.ChannelGraphRequest{ + IncludeUnannounced: include, + }) + if err != nil { + hn.PrintErr("DescribeGraph got err: %v", err) + return nil + } + + policyUpdates := policyUpdateMap{} + + for _, e := range graph.Edges { + + policies := policyUpdates[e.ChanPoint] + + // If the map[op] is nil, we need to initialize the map first. + if policies == nil { + policies = make(map[string][]*lnrpc.RoutingPolicy) + } + + if e.Node1Policy != nil { + policies[e.Node1Pub] = append( + policies[e.Node1Pub], e.Node1Policy, + ) + } + + if e.Node2Policy != nil { + policies[e.Node2Pub] = append( + policies[e.Node2Pub], e.Node2Policy, + ) + } + + policyUpdates[e.ChanPoint] = policies + } + + return policyUpdates +}