mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-09-01 02:02:10 +02:00
Merge pull request #3488 from joostjager/fix-channeldb-test
channeldb/test: make route comparison a pure function
This commit is contained in:
@@ -2,6 +2,7 @@ package channeldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
@@ -199,14 +200,15 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
assertRouteHopRecordsEqual(&s.Route, &newAttemptInfo.Route)
|
||||
err = assertRouteEqual(&s.Route, &newAttemptInfo.Route)
|
||||
if err != nil {
|
||||
t.Fatalf("Routes do not match after "+
|
||||
"serialization/deserialization: %v", err)
|
||||
}
|
||||
|
||||
// With the hop recrods, equal, we'll now blank them out as
|
||||
// reflect.DeepEqual can't properly compare tlv.Record instances.
|
||||
newAttemptInfo.Route.Hops[0].TLVRecords = nil
|
||||
newAttemptInfo.Route.Hops[1].TLVRecords = nil
|
||||
s.Route.Hops[0].TLVRecords = nil
|
||||
s.Route.Hops[1].TLVRecords = nil
|
||||
// Clear routes to allow DeepEqual to compare the remaining fields.
|
||||
newAttemptInfo.Route = route.Route{}
|
||||
s.Route = route.Route{}
|
||||
|
||||
if !reflect.DeepEqual(s, newAttemptInfo) {
|
||||
s.SessionKey.Curve = nil
|
||||
@@ -216,14 +218,53 @@ func TestSentPaymentSerialization(t *testing.T) {
|
||||
spew.Sdump(s), spew.Sdump(newAttemptInfo),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// assertRouteEquals compares to routes for equality and returns an error if
|
||||
// they are not equal.
|
||||
func assertRouteEqual(a, b *route.Route) error {
|
||||
err := assertRouteHopRecordsEqual(a, b)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TLV records have already been compared and need to be cleared to
|
||||
// properly compare the remaining fields using DeepEqual.
|
||||
copyRouteNoHops := func(r *route.Route) *route.Route {
|
||||
copy := *r
|
||||
copy.Hops = make([]*route.Hop, len(r.Hops))
|
||||
for i, hop := range r.Hops {
|
||||
hopCopy := *hop
|
||||
hopCopy.TLVRecords = nil
|
||||
copy.Hops[i] = &hopCopy
|
||||
}
|
||||
return ©
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(copyRouteNoHops(a), copyRouteNoHops(b)) {
|
||||
return fmt.Errorf("PaymentAttemptInfos don't match: %v vs %v",
|
||||
spew.Sdump(a), spew.Sdump(b))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertRouteHopRecordsEqual(r1, r2 *route.Route) error {
|
||||
if len(r1.Hops) != len(r2.Hops) {
|
||||
return errors.New("route hop count mismatch")
|
||||
}
|
||||
|
||||
for i := 0; i < len(r1.Hops); i++ {
|
||||
for j := 0; j < len(r1.Hops[i].TLVRecords); j++ {
|
||||
expectedRecord := r1.Hops[i].TLVRecords[j]
|
||||
newRecord := r2.Hops[i].TLVRecords[j]
|
||||
records1 := r1.Hops[i].TLVRecords
|
||||
records2 := r2.Hops[i].TLVRecords
|
||||
if len(records1) != len(records2) {
|
||||
return fmt.Errorf("route record count for hop %v "+
|
||||
"mismatch", i)
|
||||
}
|
||||
|
||||
for j := 0; j < len(records1); j++ {
|
||||
expectedRecord := records1[j]
|
||||
newRecord := records2[j]
|
||||
|
||||
err := assertHopRecordsEqual(expectedRecord, newRecord)
|
||||
if err != nil {
|
||||
@@ -275,20 +316,8 @@ func TestRouteSerialization(t *testing.T) {
|
||||
|
||||
// First we verify all the records match up porperly, as they aren't
|
||||
// able to be properly compared using reflect.DeepEqual.
|
||||
err = assertRouteHopRecordsEqual(&testRoute, &route2)
|
||||
err = assertRouteEqual(&testRoute, &route2)
|
||||
if err != nil {
|
||||
t.Fatalf("route tlv records don't match: %v", err)
|
||||
}
|
||||
|
||||
// Now that we know the records match up, we'll examine the remainder
|
||||
// of the route without the TLV records attached as reflect.DeepEqual
|
||||
// can't properly assert their equality.
|
||||
testRoute.Hops[0].TLVRecords = nil
|
||||
testRoute.Hops[1].TLVRecords = nil
|
||||
route2.Hops[0].TLVRecords = nil
|
||||
route2.Hops[1].TLVRecords = nil
|
||||
|
||||
if !reflect.DeepEqual(testRoute, route2) {
|
||||
t.Fatalf("routes not equal: \n%v vs \n%v",
|
||||
spew.Sdump(testRoute), spew.Sdump(route2))
|
||||
}
|
||||
|
Reference in New Issue
Block a user