diff --git a/routing/heap.go b/routing/heap.go new file mode 100644 index 000000000..68f14c033 --- /dev/null +++ b/routing/heap.go @@ -0,0 +1,59 @@ +package routing + +import "github.com/lightningnetwork/lnd/channeldb" + +// nodeWithDist is a helper struct that couples the distance from the current +// source to a node with a pointer to the node itself. +type nodeWithDist struct { + // dist is the distance to this node from the source node in our + // current context. + dist float64 + + // node is the vertex itself. This pointer can be used to explore all + // the outgoing edges (channels) emanating from a node. + node *channeldb.LightningNode +} + +// distanceHeap is a min-distance heap that's used within our path finding +// algorithm to keep track of the "closest" node to our source node. +type distanceHeap struct { + nodes []nodeWithDist +} + +// Len returns the number of nodes in the priority queue. +// +// NOTE: This is part of the heap.Interface implementation. +func (d *distanceHeap) Len() int { return len(d.nodes) } + +// Less returns whether the item in the priority queue with index i should sort +// before the item with index j. +// +// NOTE: This is part of the heap.Interface implementation. +func (d *distanceHeap) Less(i, j int) bool { + return d.nodes[i].dist < d.nodes[j].dist +} + +// Swap swaps the nodes at the passed indices in the priority queue. +// +// NOTE: This is part of the heap.Interface implementation. +func (d *distanceHeap) Swap(i, j int) { + d.nodes[i], d.nodes[j] = d.nodes[j], d.nodes[i] +} + +// Push pushes the passed item onto the priority queue. +// +// NOTE: This is part of the heap.Interface implementation. +func (d *distanceHeap) Push(x interface{}) { + d.nodes = append(d.nodes, x.(nodeWithDist)) +} + +// Pop removes the highest priority item (according to Less) from the priority +// queue and returns it. +// +// NOTE: This is part of the heap.Interface implementation. +func (d *distanceHeap) Pop() interface{} { + n := len(d.nodes) + x := d.nodes[n-1] + d.nodes = d.nodes[0 : n-1] + return x +} diff --git a/routing/heap_test.go b/routing/heap_test.go new file mode 100644 index 000000000..70a955e38 --- /dev/null +++ b/routing/heap_test.go @@ -0,0 +1,52 @@ +package routing + +import ( + "container/heap" + prand "math/rand" + "reflect" + "sort" + "testing" + "time" +) + +// TestHeapOrdering ensures that the items inserted into the heap are properly +// retrieved in minimum order of distance. +func TestHeapOrdering(t *testing.T) { + // First, create a blank heap, we'll use this to push on randomly + // generated items. + var nodeHeap distanceHeap + + prand.Seed(time.Now().Unix()) + + // Create 100 random entries adding them to the heap created above, but + // also a list that we'll sort with the entries. + const numEntries = 100 + sortedEntries := make([]nodeWithDist, 0, numEntries) + for i := 0; i < numEntries; i++ { + entry := nodeWithDist{ + dist: prand.Float64(), + } + + heap.Push(&nodeHeap, entry) + sortedEntries = append(sortedEntries, entry) + } + + // Sort the regular slice, we'll compare this against all the entries + // popped from the heap. + sort.Sort(&distanceHeap{sortedEntries}) + + // One by one, pop of all the entries from the heap, they should come + // out in sorted order. + var poppedEntries []nodeWithDist + for nodeHeap.Len() != 0 { + e := heap.Pop(&nodeHeap).(nodeWithDist) + poppedEntries = append(poppedEntries, e) + } + + // Finally, ensure that the items popped from the heap and the items we + // sorted are identical at this rate. + if !reflect.DeepEqual(poppedEntries, sortedEntries) { + t.Fatalf("items don't match: expected %v, got %v", sortedEntries, + poppedEntries) + } +} diff --git a/routing/pathfind.go b/routing/pathfind.go index 41f0c9494..e320d2f2b 100644 --- a/routing/pathfind.go +++ b/routing/pathfind.go @@ -208,13 +208,6 @@ func newVertex(pub *btcec.PublicKey) vertex { return v } -// nodeWithDist is a helper struct that couples the distance from the current -// source to a node with a pointer to the node itself. -type nodeWithDist struct { - dist float64 - node *channeldb.LightningNode -} - // edgeWithPrev is a helper struct used in path finding that couples an // directional edge with the node's ID in the opposite direction. type edgeWithPrev struct {