From 955f4c9182873f0896549ebf4b67c8db8d4cb6d0 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 14 Jul 2025 09:39:58 +0200 Subject: [PATCH 1/3] graph/db: refactor to make postgres fixtures re-usable In preparation for tests where we will want to spin up SQL DBs many times, we do some refactoring so that it is easy to re-use postgres fixtures since those are expensive to spin up. --- graph/db/sql_migration_test.go | 5 +++- graph/db/test_postgres.go | 45 ++++++++++++++++++++++++++++++++++ graph/db/test_sql.go | 23 ----------------- graph/db/test_sqlite.go | 33 +++++++++++++++++++++++++ 4 files changed, 82 insertions(+), 24 deletions(-) delete mode 100644 graph/db/test_sql.go diff --git a/graph/db/sql_migration_test.go b/graph/db/sql_migration_test.go index 9f32603fd..a2d8db261 100644 --- a/graph/db/sql_migration_test.go +++ b/graph/db/sql_migration_test.go @@ -58,6 +58,8 @@ func TestMigrateGraphToSQL(t *testing.T) { t.Parallel() ctx := context.Background() + dbFixture := NewTestDBFixture(t) + writeUpdate := func(t *testing.T, db *KVStore, object any) { t.Helper() @@ -324,7 +326,8 @@ func TestMigrateGraphToSQL(t *testing.T) { } // Set up our destination SQL DB. - sql, ok := NewTestDB(t).(*SQLStore) + db := NewTestDBWithFixture(t, dbFixture) + sql, ok := db.(*SQLStore) require.True(t, ok) // Run the migration. diff --git a/graph/db/test_postgres.go b/graph/db/test_postgres.go index 7af420d19..5e10d94cc 100644 --- a/graph/db/test_postgres.go +++ b/graph/db/test_postgres.go @@ -6,9 +6,46 @@ import ( "database/sql" "testing" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" ) +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB) V1Store { + return NewTestDBWithFixture(t, nil) +} + +// NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes. +func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture { + pgFixture := sqldb.NewTestPgFixture(t, sqldb.DefaultPostgresFixtureLifetime) + t.Cleanup(func() { + pgFixture.TearDown(t) + }) + return pgFixture +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, pgFixture *sqldb.TestPgFixture) V1Store { + var querier BatchedSQLQueries + if pgFixture == nil { + querier = newBatchQuerier(t) + } else { + querier = newBatchQuerierWithFixture(t, pgFixture) + } + + store, err := NewSQLStore( + &SQLStoreConfig{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, + }, querier, + ) + require.NoError(t, err) + + return store +} + // newBatchQuerier creates a new BatchedSQLQueries instance for testing // using a PostgreSQL database fixture. func newBatchQuerier(t testing.TB) BatchedSQLQueries { @@ -19,6 +56,14 @@ func newBatchQuerier(t testing.TB) BatchedSQLQueries { pgFixture.TearDown(t) }) + return newBatchQuerierWithFixture(t, pgFixture) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a PostgreSQL database fixture. +func newBatchQuerierWithFixture(t testing.TB, + pgFixture *sqldb.TestPgFixture) BatchedSQLQueries { + db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB return sqldb.NewTransactionExecutor( diff --git a/graph/db/test_sql.go b/graph/db/test_sql.go deleted file mode 100644 index 9d4d507b3..000000000 --- a/graph/db/test_sql.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build test_db_postgres || test_db_sqlite - -package graphdb - -import ( - "testing" - - "github.com/btcsuite/btcd/chaincfg" - "github.com/stretchr/testify/require" -) - -// NewTestDB is a helper function that creates a SQLStore backed by a SQL -// database for testing. -func NewTestDB(t testing.TB) V1Store { - store, err := NewSQLStore( - &SQLStoreConfig{ - ChainHash: *chaincfg.MainNetParams.GenesisHash, - }, newBatchQuerier(t), - ) - require.NoError(t, err) - - return store -} diff --git a/graph/db/test_sqlite.go b/graph/db/test_sqlite.go index 35f7cb5d8..4d52b00ba 100644 --- a/graph/db/test_sqlite.go +++ b/graph/db/test_sqlite.go @@ -6,12 +6,45 @@ import ( "database/sql" "testing" + "github.com/btcsuite/btcd/chaincfg" "github.com/lightningnetwork/lnd/sqldb" + "github.com/stretchr/testify/require" ) +// NewTestDB is a helper function that creates a SQLStore backed by a SQL +// database for testing. +func NewTestDB(t testing.TB) V1Store { + return NewTestDBWithFixture(t, nil) +} + +// NewTestDBFixture is a no-op for the sqlite build. +func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture { + return nil +} + +// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a +// SQL database for testing. +func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture) V1Store { + store, err := NewSQLStore( + &SQLStoreConfig{ + ChainHash: *chaincfg.MainNetParams.GenesisHash, + }, newBatchQuerier(t), + ) + require.NoError(t, err) + return store +} + // newBatchQuerier creates a new BatchedSQLQueries instance for testing // using a SQLite database. func newBatchQuerier(t testing.TB) BatchedSQLQueries { + return newBatchQuerierWithFixture(t, nil) +} + +// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for +// testing using a SQLite database. +func newBatchQuerierWithFixture(t testing.TB, + _ *sqldb.TestPgFixture) BatchedSQLQueries { + db := sqldb.NewTestSqliteDB(t).BaseDB return sqldb.NewTransactionExecutor( From ed17574196e01ba1d024d13efeb53591e6a12b71 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 14 Jul 2025 09:47:11 +0200 Subject: [PATCH 2/3] lnwire: fix RandNodeAlias to produce valid UTF-8 To ensure that the RandNodeAlias helper can be used elsewhere to generate random aliases, we adjust it in this test to only produce valid UTF-8 strings as required by the spec. --- lnwire/test_utils.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lnwire/test_utils.go b/lnwire/test_utils.go index 1065cbacf..07c9d795b 100644 --- a/lnwire/test_utils.go +++ b/lnwire/test_utils.go @@ -14,7 +14,11 @@ import ( "pgregory.net/rapid" ) -// RandChannelUpdate generates a random ChannelUpdate message using rapid's +// charset contains valid UTF-8 characters that can be used to generate random +// strings for testing purposes. +const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + +// RandPartialSig generates a random ParialSig message using rapid's // generators. func RandPartialSig(t *rapid.T) *PartialSig { // Generate random private key bytes @@ -139,11 +143,11 @@ func RandNodeAlias(t *rapid.T) NodeAlias { var alias NodeAlias aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength") - aliasBytes := rapid.StringN( - 0, aliasLength, aliasLength, + aliasBytes := rapid.SliceOfN( + rapid.SampledFrom([]rune(charset)), aliasLength, aliasLength, ).Draw(t, "alias") - copy(alias[:], aliasBytes) + copy(alias[:], string(aliasBytes)) return alias } From d7b8259a3608b389e93dc968c0599a52a5c436f2 Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Mon, 14 Jul 2025 10:06:04 +0200 Subject: [PATCH 3/3] graph/db: add graph SQL migration rapid unit test --- graph/db/sql_migration_test.go | 418 ++++++++++++++++++++++++++++++++- 1 file changed, 417 insertions(+), 1 deletion(-) diff --git a/graph/db/sql_migration_test.go b/graph/db/sql_migration_test.go index a2d8db261..eaaaee6c5 100644 --- a/graph/db/sql_migration_test.go +++ b/graph/db/sql_migration_test.go @@ -21,6 +21,7 @@ import ( "time" "github.com/btcsuite/btcd/btcec/v2" + "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" @@ -33,6 +34,7 @@ import ( "github.com/lightningnetwork/lnd/routing/route" "github.com/lightningnetwork/lnd/sqldb" "github.com/stretchr/testify/require" + "pgregory.net/rapid" ) var ( @@ -401,6 +403,9 @@ func fetchAllNodes(t *testing.T, store V1Store) []*models.LightningNode { _, err := node.PubKey() require.NoError(t, err) + // Sort the addresses to ensure a consistent order. + sortAddrs(node.Addresses) + nodes = append(nodes, node) return nil @@ -589,7 +594,7 @@ func setUpKVStore(t *testing.T) *KVStore { } // genPubKey generates a new public key for testing purposes. -func genPubKey(t *testing.T) route.Vertex { +func genPubKey(t require.TestingT) route.Vertex { key, err := btcec.NewPrivateKey() require.NoError(t, err) @@ -1165,3 +1170,414 @@ func assertResultState(t *testing.T, sql *SQLStore, expState dbState) { require.True(t, isZombie) } } + +// TestMigrateGraphToSQLRapid tests the migration of graph nodes from a KV +// store to a SQL store using property-based testing to ensure that the +// migration works for a wide variety of randomly generated graph nodes. +func TestMigrateGraphToSQLRapid(t *testing.T) { + t.Parallel() + + dbFixture := NewTestDBFixture(t) + + rapid.Check(t, func(rt *rapid.T) { + const ( + maxNumNodes = 5 + maxNumChannels = 5 + ) + + testMigrateGraphToSQLRapidOnce( + t, rt, dbFixture, maxNumNodes, maxNumChannels, + ) + }) +} + +// testMigrateGraphToSQLRapidOnce is a helper function that performs the actual +// migration test using property-based testing. It sets up a KV store and a +// SQL store, generates random nodes and channels, populates the KV store, +// runs the migration, and asserts that the SQL store contains the expected +// state. +func testMigrateGraphToSQLRapidOnce(t *testing.T, rt *rapid.T, + dbFixture *sqldb.TestPgFixture, maxNumNodes, maxNumChannels int) { + + ctx := context.Background() + + // Set up our source kvdb DB. + kvDB := setUpKVStore(t) + + // Set up our destination SQL DB. + sql, ok := NewTestDBWithFixture(t, dbFixture).(*SQLStore) + require.True(t, ok) + + // Generate a list of random nodes. + nodes := rapid.SliceOfN( + rapid.Custom(genRandomNode), 1, maxNumNodes, + ).Draw(rt, "nodes") + + // Keep track of all nodes that should be in the database. We may expect + // more than just the ones we generated above if we have channels that + // point to shell nodes. + allNodes := make(map[route.Vertex]*models.LightningNode) + var nodePubs []route.Vertex + for _, node := range nodes { + allNodes[node.PubKeyBytes] = node + nodePubs = append(nodePubs, node.PubKeyBytes) + } + + // Generate a list of random channels and policies for those channels. + var ( + channels []*models.ChannelEdgeInfo + chanIDs = make(map[uint64]struct{}) + policies []*models.ChannelEdgePolicy + ) + channelGen := rapid.Custom(func(rtt *rapid.T) *models.ChannelEdgeInfo { + var ( + edge *models.ChannelEdgeInfo + newNodes []route.Vertex + ) + // Loop to ensure that we skip channels with channel IDs + // that we have already used. + for { + edge, newNodes = genRandomChannel(rtt, nodePubs) + if _, ok := chanIDs[edge.ChannelID]; ok { + continue + } + chanIDs[edge.ChannelID] = struct{}{} + + break + } + + // If the new channel points to nodes we don't yet know + // of, then update our expected node list to include + // shell node entries for these. + for _, n := range newNodes { + if _, ok := allNodes[n]; ok { + continue + } + + shellNode := makeTestShellNode( + t, func(node *models.LightningNode) { + node.PubKeyBytes = n + }, + ) + allNodes[n] = shellNode + } + + // Generate either 0, 1 or two policies for this + // channel. + numPolicies := rapid.IntRange(0, 2).Draw( + rtt, "numPolicies", + ) + switch numPolicies { + case 0: + case 1: + // Randomly pick the direction. + policy := genRandomPolicy( + rtt, edge, rapid.Bool().Draw(rtt, "isNode1"), + ) + + policies = append(policies, policy) + case 2: + // Generate two policies, one for each + // direction. + policy1 := genRandomPolicy(rtt, edge, true) + policy2 := genRandomPolicy(rtt, edge, false) + + policies = append(policies, policy1) + policies = append(policies, policy2) + } + + return edge + }) + channels = rapid.SliceOfN( + channelGen, 1, maxNumChannels, + ).Draw(rt, "channels") + + // Write the test objects to the kvdb store. + for _, node := range allNodes { + err := kvDB.AddLightningNode(ctx, node) + require.NoError(t, err) + } + for _, channel := range channels { + err := kvDB.AddChannelEdge(ctx, channel) + require.NoError(t, err) + } + for _, policy := range policies { + _, _, err := kvDB.UpdateEdgePolicy(ctx, policy) + require.NoError(t, err) + } + + // Run the migration. + err := MigrateGraphToSQL(ctx, kvDB.db, sql.db, testChain) + require.NoError(t, err) + + // Create a slice of all nodes. + var nodesSlice []*models.LightningNode + for _, node := range allNodes { + nodesSlice = append(nodesSlice, node) + } + + // Create a map of channels to their policies. + chanMap := make(map[uint64]*chanInfo) + for _, channel := range channels { + chanMap[channel.ChannelID] = &chanInfo{ + edgeInfo: channel, + } + } + + for _, policy := range policies { + info, ok := chanMap[policy.ChannelID] + require.True(t, ok) + + // The IsNode1 flag is encoded in the ChannelFlags. + if policy.ChannelFlags&lnwire.ChanUpdateDirection == 0 { + info.policy1 = policy + } else { + info.policy2 = policy + } + } + + var chanSetForState chanSet + for _, info := range chanMap { + chanSetForState = append(chanSetForState, *info) + } + + // Validate that the sql database has the correct state. + assertResultState(t, sql, dbState{ + nodes: nodesSlice, + chans: chanSetForState, + }) +} + +// genRandomChannel is a rapid generator for creating random channel edge infos. +// It takes a slice of existing node public keys to draw from. If the slice is +// empty, it will always generate new random nodes. +func genRandomChannel(rt *rapid.T, + nodes []route.Vertex) (*models.ChannelEdgeInfo, []route.Vertex) { + + var newNodes []route.Vertex + + // Generate a random channel ID. + chanID := lnwire.RandShortChannelID(rt).ToUint64() + + // Generate a random outpoint. + var hash chainhash.Hash + _, err := rand.Read(hash[:]) + require.NoError(rt, err) + outpoint := wire.OutPoint{ + Hash: hash, + Index: rapid.Uint32().Draw(rt, "outpointIndex"), + } + + // Generate random capacity. + capacity := rapid.Int64Range(1, btcutil.MaxSatoshi).Draw(rt, "capacity") + + // Generate random features. + features := lnwire.NewFeatureVector( + lnwire.RandFeatureVector(rt), + lnwire.Features, + ) + + // Generate random keys for the channel. + bitcoinKey1Bytes := genPubKey(rt) + bitcoinKey2Bytes := genPubKey(rt) + + // Decide if we should use existing nodes or generate new ones. + var nodeKey1Bytes, nodeKey2Bytes route.Vertex + // With a 50/50 chance, we'll use existing nodes. + if len(nodes) > 1 && rapid.Bool().Draw(rt, "useExistingNodes") { + // Pick two random nodes from the existing set. + idx1 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node1") + idx2 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node2") + if idx1 == idx2 { + idx2 = (idx1 + 1) % len(nodes) + } + nodeKey1Bytes = nodes[idx1] + nodeKey2Bytes = nodes[idx2] + } else { + // Generate new random nodes. + nodeKey1Bytes = genPubKey(rt) + nodeKey2Bytes = genPubKey(rt) + newNodes = append(newNodes, nodeKey1Bytes, nodeKey2Bytes) + } + + node1Sig := lnwire.RandSignature(rt) + node2Sig := lnwire.RandSignature(rt) + btc1Sig := lnwire.RandSignature(rt) + btc2Sig := lnwire.RandSignature(rt) + + // Generate a random auth proof. + authProof := &models.ChannelAuthProof{ + NodeSig1Bytes: node1Sig.RawBytes(), + NodeSig2Bytes: node2Sig.RawBytes(), + BitcoinSig1Bytes: btc1Sig.RawBytes(), + BitcoinSig2Bytes: btc2Sig.RawBytes(), + } + + extraOpaque := lnwire.RandExtraOpaqueData(rt, nil) + if len(extraOpaque) == 0 { + extraOpaque = nil + } + + info := &models.ChannelEdgeInfo{ + ChannelID: chanID, + ChainHash: testChain, + NodeKey1Bytes: nodeKey1Bytes, + NodeKey2Bytes: nodeKey2Bytes, + BitcoinKey1Bytes: bitcoinKey1Bytes, + BitcoinKey2Bytes: bitcoinKey2Bytes, + Features: features, + AuthProof: authProof, + ChannelPoint: outpoint, + Capacity: btcutil.Amount(capacity), + ExtraOpaqueData: extraOpaque, + } + + return info, newNodes +} + +// genRandomPolicy is a rapid generator for creating random channel edge +// policies. It takes a slice of existing channels to draw from. +func genRandomPolicy(rt *rapid.T, channel *models.ChannelEdgeInfo, + isNode1 bool) *models.ChannelEdgePolicy { + + var toNode route.Vertex + if isNode1 { + toNode = channel.NodeKey2Bytes + } else { + toNode = channel.NodeKey1Bytes + } + + // Generate a random timestamp. + randTime := time.Unix(rapid.Int64Range( + 0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(), + ).Draw(rt, "policyTimestamp"), 0) + + // Generate random channel update flags and then just make sure to + // unset/set the correct direction bit. + chanFlags := lnwire.ChanUpdateChanFlags( + rapid.Uint8().Draw(rt, "chanFlags"), + ) + if isNode1 { + chanFlags &= ^lnwire.ChanUpdateDirection + } else { + chanFlags |= lnwire.ChanUpdateDirection + } + + extraOpaque := lnwire.RandExtraOpaqueData(rt, nil) + if len(extraOpaque) == 0 { + extraOpaque = nil + } + + hasMaxHTLC := rapid.Bool().Draw(rt, "hasMaxHTLC") + var maxHTLC lnwire.MilliSatoshi + msgFlags := lnwire.ChanUpdateMsgFlags( + rapid.Uint8().Draw(rt, "msgFlags"), + ) + if hasMaxHTLC { + msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc + maxHTLC = lnwire.MilliSatoshi( + rapid.Uint64().Draw(rt, "maxHtlc"), + ) + } else { + msgFlags &= ^lnwire.ChanUpdateRequiredMaxHtlc + } + + return &models.ChannelEdgePolicy{ + SigBytes: testSigBytes, + ChannelID: channel.ChannelID, + LastUpdate: randTime, + MessageFlags: msgFlags, + ChannelFlags: chanFlags, + TimeLockDelta: rapid.Uint16().Draw(rt, "timeLock"), + MinHTLC: lnwire.MilliSatoshi( + rapid.Uint64().Draw(rt, "minHtlc"), + ), + MaxHTLC: maxHTLC, + FeeBaseMSat: lnwire.MilliSatoshi( + rapid.Uint64().Draw(rt, "baseFee"), + ), + FeeProportionalMillionths: lnwire.MilliSatoshi( + rapid.Uint64().Draw(rt, "feeRate"), + ), + ToNode: toNode, + ExtraOpaqueData: extraOpaque, + } +} + +// sortAddrs sorts a slice of net.Addr. +func sortAddrs(addrs []net.Addr) { + if addrs == nil { + return + } + + slices.SortFunc(addrs, func(i, j net.Addr) int { + return strings.Compare(i.String(), j.String()) + }) +} + +// genRandomNode is a rapid generator for creating random lightning nodes. +func genRandomNode(t *rapid.T) *models.LightningNode { + // Generate a random alias that is valid. + alias := lnwire.RandNodeAlias(t) + + // Generate a random public key. + pubKey := lnwire.RandPubKey(t) + var pubKeyBytes [33]byte + copy(pubKeyBytes[:], pubKey.SerializeCompressed()) + + // Generate a random signature. + sig := lnwire.RandSignature(t) + sigBytes := sig.ToSignatureBytes() + + // Generate a random color. + randColor := color.RGBA{ + R: uint8(rapid.IntRange(0, 255). + Draw(t, "R")), + G: uint8(rapid.IntRange(0, 255). + Draw(t, "G")), + B: uint8(rapid.IntRange(0, 255). + Draw(t, "B")), + A: 0, + } + + // Generate a random timestamp. + randTime := time.Unix( + rapid.Int64Range( + 0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(), + ).Draw(t, "timestamp"), 0, + ) + + // Generate random addresses. + addrs := lnwire.RandNetAddrs(t) + sortAddrs(addrs) + + // Generate a random feature vector. + features := lnwire.RandFeatureVector(t) + + // Generate random extra opaque data. + extraOpaqueData := lnwire.RandExtraOpaqueData(t, nil) + if len(extraOpaqueData) == 0 { + extraOpaqueData = nil + } + + node := &models.LightningNode{ + HaveNodeAnnouncement: true, + AuthSigBytes: sigBytes, + LastUpdate: randTime, + Color: randColor, + Alias: alias.String(), + Features: lnwire.NewFeatureVector( + features, lnwire.Features, + ), + Addresses: addrs, + ExtraOpaqueData: extraOpaqueData, + PubKeyBytes: pubKeyBytes, + } + + // We call this method so that the internal pubkey field is populated + // which then lets us to proper struct comparison later on. + _, err := node.PubKey() + require.NoError(t, err) + + return node +}