diff --git a/autopilot/agent.go b/autopilot/agent.go index 2f257173c..9610427c6 100644 --- a/autopilot/agent.go +++ b/autopilot/agent.go @@ -819,55 +819,3 @@ func (a *Agent) executeDirective(directive AttachmentDirective) { // directive in goroutine? a.OnChannelPendingOpen() } - -// HeuristicScores is an alias for a map that maps heuristic names to a map of -// scores for pubkeys. -type HeuristicScores map[string]map[NodeID]float64 - -// queryHeuristics gets node scores from all available simple heuristics, and -// the agent's current active heuristic. -func (a *Agent) queryHeuristics(nodes map[NodeID]struct{}) ( - HeuristicScores, error) { - - // Get the agent's current channel state. - a.chanStateMtx.Lock() - a.pendingMtx.Lock() - totalChans := mergeChanState(a.pendingOpens, a.chanState) - a.pendingMtx.Unlock() - a.chanStateMtx.Unlock() - - // As channel size we'll use the maximum size. - chanSize := a.cfg.Constraints.MaxChanSize() - - // We'll start by getting the scores from each available sub-heuristic, - // in addition the active agent heuristic. - report := make(HeuristicScores) - for _, h := range append(availableHeuristics, a.cfg.Heuristic) { - name := h.Name() - - // If the active agent heuristic is among the simple heuristics - // it might get queried more than once. As an optimization - // we'll just skip it the second time. - if _, ok := report[name]; ok { - continue - } - - s, err := h.NodeScores( - a.cfg.Graph, totalChans, chanSize, nodes, - ) - if err != nil { - return nil, fmt.Errorf("unable to get sub score: %v", err) - } - - log.Debugf("Heuristic \"%v\" scored %d nodes", name, len(s)) - - scores := make(map[NodeID]float64) - for nID, score := range s { - scores[nID] = score.Score - } - - report[name] = scores - } - - return report, nil -} diff --git a/autopilot/manager.go b/autopilot/manager.go index 283aab027..17ce145c8 100644 --- a/autopilot/manager.go +++ b/autopilot/manager.go @@ -269,23 +269,84 @@ func (m *Manager) StopAgent() error { return nil } -// QueryHeuristics queries the active autopilot agent for node scores. +// QueryHeuristics queries the available autopilot heuristics for node scores. func (m *Manager) QueryHeuristics(nodes []NodeID) (HeuristicScores, error) { m.Lock() defer m.Unlock() - // Not active, so we can return early. - if m.pilot == nil { - return nil, fmt.Errorf("autopilot not active") - } - n := make(map[NodeID]struct{}) for _, node := range nodes { n[node] = struct{}{} } log.Debugf("Querying heuristics for %d nodes", len(n)) - return m.pilot.queryHeuristics(n) + return m.queryHeuristics(n) +} + +// HeuristicScores is an alias for a map that maps heuristic names to a map of +// scores for pubkeys. +type HeuristicScores map[string]map[NodeID]float64 + +// queryHeuristics gets node scores from all available simple heuristics, and +// the agent's current active heuristic. +// +// NOTE: Must be called with the manager's lock. +func (m *Manager) queryHeuristics(nodes map[NodeID]struct{}) ( + HeuristicScores, error) { + + // Fetch the current set of channels. + totalChans, err := m.cfg.ChannelState() + if err != nil { + return nil, err + } + + // If the agent is active, we can merge the channel state with the + // channels pending open. + if m.pilot != nil { + m.pilot.chanStateMtx.Lock() + m.pilot.pendingMtx.Lock() + totalChans = mergeChanState( + m.pilot.pendingOpens, m.pilot.chanState, + ) + m.pilot.pendingMtx.Unlock() + m.pilot.chanStateMtx.Unlock() + } + + // As channel size we'll use the maximum size. + chanSize := m.cfg.PilotCfg.Constraints.MaxChanSize() + + // We'll start by getting the scores from each available sub-heuristic, + // in addition the current agent heuristic. + report := make(HeuristicScores) + for _, h := range append(availableHeuristics, m.cfg.PilotCfg.Heuristic) { + name := h.Name() + + // If the agent heuristic is among the simple heuristics it + // might get queried more than once. As an optimization we'll + // just skip it the second time. + if _, ok := report[name]; ok { + continue + } + + s, err := h.NodeScores( + m.cfg.PilotCfg.Graph, totalChans, chanSize, nodes, + ) + if err != nil { + return nil, fmt.Errorf("unable to get sub score: %v", + err) + } + + log.Debugf("Heuristic \"%v\" scored %d nodes", name, len(s)) + + scores := make(map[NodeID]float64) + for nID, score := range s { + scores[nID] = score.Score + } + + report[name] = scores + } + + return report, nil } // SetNodeScores is used to set the scores of the given heuristic, if it is