graph/db: add graph SQL migration rapid unit test

This commit is contained in:
Elle Mouton
2025-07-14 10:06:04 +02:00
parent ed17574196
commit d7b8259a36

View File

@@ -21,6 +21,7 @@ import (
"time"
"github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire"
@@ -33,6 +34,7 @@ import (
"github.com/lightningnetwork/lnd/routing/route"
"github.com/lightningnetwork/lnd/sqldb"
"github.com/stretchr/testify/require"
"pgregory.net/rapid"
)
var (
@@ -401,6 +403,9 @@ func fetchAllNodes(t *testing.T, store V1Store) []*models.LightningNode {
_, err := node.PubKey()
require.NoError(t, err)
// Sort the addresses to ensure a consistent order.
sortAddrs(node.Addresses)
nodes = append(nodes, node)
return nil
@@ -589,7 +594,7 @@ func setUpKVStore(t *testing.T) *KVStore {
}
// genPubKey generates a new public key for testing purposes.
func genPubKey(t *testing.T) route.Vertex {
func genPubKey(t require.TestingT) route.Vertex {
key, err := btcec.NewPrivateKey()
require.NoError(t, err)
@@ -1165,3 +1170,414 @@ func assertResultState(t *testing.T, sql *SQLStore, expState dbState) {
require.True(t, isZombie)
}
}
// TestMigrateGraphToSQLRapid tests the migration of graph nodes from a KV
// store to a SQL store using property-based testing to ensure that the
// migration works for a wide variety of randomly generated graph nodes.
func TestMigrateGraphToSQLRapid(t *testing.T) {
t.Parallel()
dbFixture := NewTestDBFixture(t)
rapid.Check(t, func(rt *rapid.T) {
const (
maxNumNodes = 5
maxNumChannels = 5
)
testMigrateGraphToSQLRapidOnce(
t, rt, dbFixture, maxNumNodes, maxNumChannels,
)
})
}
// testMigrateGraphToSQLRapidOnce is a helper function that performs the actual
// migration test using property-based testing. It sets up a KV store and a
// SQL store, generates random nodes and channels, populates the KV store,
// runs the migration, and asserts that the SQL store contains the expected
// state.
func testMigrateGraphToSQLRapidOnce(t *testing.T, rt *rapid.T,
dbFixture *sqldb.TestPgFixture, maxNumNodes, maxNumChannels int) {
ctx := context.Background()
// Set up our source kvdb DB.
kvDB := setUpKVStore(t)
// Set up our destination SQL DB.
sql, ok := NewTestDBWithFixture(t, dbFixture).(*SQLStore)
require.True(t, ok)
// Generate a list of random nodes.
nodes := rapid.SliceOfN(
rapid.Custom(genRandomNode), 1, maxNumNodes,
).Draw(rt, "nodes")
// Keep track of all nodes that should be in the database. We may expect
// more than just the ones we generated above if we have channels that
// point to shell nodes.
allNodes := make(map[route.Vertex]*models.LightningNode)
var nodePubs []route.Vertex
for _, node := range nodes {
allNodes[node.PubKeyBytes] = node
nodePubs = append(nodePubs, node.PubKeyBytes)
}
// Generate a list of random channels and policies for those channels.
var (
channels []*models.ChannelEdgeInfo
chanIDs = make(map[uint64]struct{})
policies []*models.ChannelEdgePolicy
)
channelGen := rapid.Custom(func(rtt *rapid.T) *models.ChannelEdgeInfo {
var (
edge *models.ChannelEdgeInfo
newNodes []route.Vertex
)
// Loop to ensure that we skip channels with channel IDs
// that we have already used.
for {
edge, newNodes = genRandomChannel(rtt, nodePubs)
if _, ok := chanIDs[edge.ChannelID]; ok {
continue
}
chanIDs[edge.ChannelID] = struct{}{}
break
}
// If the new channel points to nodes we don't yet know
// of, then update our expected node list to include
// shell node entries for these.
for _, n := range newNodes {
if _, ok := allNodes[n]; ok {
continue
}
shellNode := makeTestShellNode(
t, func(node *models.LightningNode) {
node.PubKeyBytes = n
},
)
allNodes[n] = shellNode
}
// Generate either 0, 1 or two policies for this
// channel.
numPolicies := rapid.IntRange(0, 2).Draw(
rtt, "numPolicies",
)
switch numPolicies {
case 0:
case 1:
// Randomly pick the direction.
policy := genRandomPolicy(
rtt, edge, rapid.Bool().Draw(rtt, "isNode1"),
)
policies = append(policies, policy)
case 2:
// Generate two policies, one for each
// direction.
policy1 := genRandomPolicy(rtt, edge, true)
policy2 := genRandomPolicy(rtt, edge, false)
policies = append(policies, policy1)
policies = append(policies, policy2)
}
return edge
})
channels = rapid.SliceOfN(
channelGen, 1, maxNumChannels,
).Draw(rt, "channels")
// Write the test objects to the kvdb store.
for _, node := range allNodes {
err := kvDB.AddLightningNode(ctx, node)
require.NoError(t, err)
}
for _, channel := range channels {
err := kvDB.AddChannelEdge(ctx, channel)
require.NoError(t, err)
}
for _, policy := range policies {
_, _, err := kvDB.UpdateEdgePolicy(ctx, policy)
require.NoError(t, err)
}
// Run the migration.
err := MigrateGraphToSQL(ctx, kvDB.db, sql.db, testChain)
require.NoError(t, err)
// Create a slice of all nodes.
var nodesSlice []*models.LightningNode
for _, node := range allNodes {
nodesSlice = append(nodesSlice, node)
}
// Create a map of channels to their policies.
chanMap := make(map[uint64]*chanInfo)
for _, channel := range channels {
chanMap[channel.ChannelID] = &chanInfo{
edgeInfo: channel,
}
}
for _, policy := range policies {
info, ok := chanMap[policy.ChannelID]
require.True(t, ok)
// The IsNode1 flag is encoded in the ChannelFlags.
if policy.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
info.policy1 = policy
} else {
info.policy2 = policy
}
}
var chanSetForState chanSet
for _, info := range chanMap {
chanSetForState = append(chanSetForState, *info)
}
// Validate that the sql database has the correct state.
assertResultState(t, sql, dbState{
nodes: nodesSlice,
chans: chanSetForState,
})
}
// genRandomChannel is a rapid generator for creating random channel edge infos.
// It takes a slice of existing node public keys to draw from. If the slice is
// empty, it will always generate new random nodes.
func genRandomChannel(rt *rapid.T,
nodes []route.Vertex) (*models.ChannelEdgeInfo, []route.Vertex) {
var newNodes []route.Vertex
// Generate a random channel ID.
chanID := lnwire.RandShortChannelID(rt).ToUint64()
// Generate a random outpoint.
var hash chainhash.Hash
_, err := rand.Read(hash[:])
require.NoError(rt, err)
outpoint := wire.OutPoint{
Hash: hash,
Index: rapid.Uint32().Draw(rt, "outpointIndex"),
}
// Generate random capacity.
capacity := rapid.Int64Range(1, btcutil.MaxSatoshi).Draw(rt, "capacity")
// Generate random features.
features := lnwire.NewFeatureVector(
lnwire.RandFeatureVector(rt),
lnwire.Features,
)
// Generate random keys for the channel.
bitcoinKey1Bytes := genPubKey(rt)
bitcoinKey2Bytes := genPubKey(rt)
// Decide if we should use existing nodes or generate new ones.
var nodeKey1Bytes, nodeKey2Bytes route.Vertex
// With a 50/50 chance, we'll use existing nodes.
if len(nodes) > 1 && rapid.Bool().Draw(rt, "useExistingNodes") {
// Pick two random nodes from the existing set.
idx1 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node1")
idx2 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node2")
if idx1 == idx2 {
idx2 = (idx1 + 1) % len(nodes)
}
nodeKey1Bytes = nodes[idx1]
nodeKey2Bytes = nodes[idx2]
} else {
// Generate new random nodes.
nodeKey1Bytes = genPubKey(rt)
nodeKey2Bytes = genPubKey(rt)
newNodes = append(newNodes, nodeKey1Bytes, nodeKey2Bytes)
}
node1Sig := lnwire.RandSignature(rt)
node2Sig := lnwire.RandSignature(rt)
btc1Sig := lnwire.RandSignature(rt)
btc2Sig := lnwire.RandSignature(rt)
// Generate a random auth proof.
authProof := &models.ChannelAuthProof{
NodeSig1Bytes: node1Sig.RawBytes(),
NodeSig2Bytes: node2Sig.RawBytes(),
BitcoinSig1Bytes: btc1Sig.RawBytes(),
BitcoinSig2Bytes: btc2Sig.RawBytes(),
}
extraOpaque := lnwire.RandExtraOpaqueData(rt, nil)
if len(extraOpaque) == 0 {
extraOpaque = nil
}
info := &models.ChannelEdgeInfo{
ChannelID: chanID,
ChainHash: testChain,
NodeKey1Bytes: nodeKey1Bytes,
NodeKey2Bytes: nodeKey2Bytes,
BitcoinKey1Bytes: bitcoinKey1Bytes,
BitcoinKey2Bytes: bitcoinKey2Bytes,
Features: features,
AuthProof: authProof,
ChannelPoint: outpoint,
Capacity: btcutil.Amount(capacity),
ExtraOpaqueData: extraOpaque,
}
return info, newNodes
}
// genRandomPolicy is a rapid generator for creating random channel edge
// policies. It takes a slice of existing channels to draw from.
func genRandomPolicy(rt *rapid.T, channel *models.ChannelEdgeInfo,
isNode1 bool) *models.ChannelEdgePolicy {
var toNode route.Vertex
if isNode1 {
toNode = channel.NodeKey2Bytes
} else {
toNode = channel.NodeKey1Bytes
}
// Generate a random timestamp.
randTime := time.Unix(rapid.Int64Range(
0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(),
).Draw(rt, "policyTimestamp"), 0)
// Generate random channel update flags and then just make sure to
// unset/set the correct direction bit.
chanFlags := lnwire.ChanUpdateChanFlags(
rapid.Uint8().Draw(rt, "chanFlags"),
)
if isNode1 {
chanFlags &= ^lnwire.ChanUpdateDirection
} else {
chanFlags |= lnwire.ChanUpdateDirection
}
extraOpaque := lnwire.RandExtraOpaqueData(rt, nil)
if len(extraOpaque) == 0 {
extraOpaque = nil
}
hasMaxHTLC := rapid.Bool().Draw(rt, "hasMaxHTLC")
var maxHTLC lnwire.MilliSatoshi
msgFlags := lnwire.ChanUpdateMsgFlags(
rapid.Uint8().Draw(rt, "msgFlags"),
)
if hasMaxHTLC {
msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
maxHTLC = lnwire.MilliSatoshi(
rapid.Uint64().Draw(rt, "maxHtlc"),
)
} else {
msgFlags &= ^lnwire.ChanUpdateRequiredMaxHtlc
}
return &models.ChannelEdgePolicy{
SigBytes: testSigBytes,
ChannelID: channel.ChannelID,
LastUpdate: randTime,
MessageFlags: msgFlags,
ChannelFlags: chanFlags,
TimeLockDelta: rapid.Uint16().Draw(rt, "timeLock"),
MinHTLC: lnwire.MilliSatoshi(
rapid.Uint64().Draw(rt, "minHtlc"),
),
MaxHTLC: maxHTLC,
FeeBaseMSat: lnwire.MilliSatoshi(
rapid.Uint64().Draw(rt, "baseFee"),
),
FeeProportionalMillionths: lnwire.MilliSatoshi(
rapid.Uint64().Draw(rt, "feeRate"),
),
ToNode: toNode,
ExtraOpaqueData: extraOpaque,
}
}
// sortAddrs sorts a slice of net.Addr.
func sortAddrs(addrs []net.Addr) {
if addrs == nil {
return
}
slices.SortFunc(addrs, func(i, j net.Addr) int {
return strings.Compare(i.String(), j.String())
})
}
// genRandomNode is a rapid generator for creating random lightning nodes.
func genRandomNode(t *rapid.T) *models.LightningNode {
// Generate a random alias that is valid.
alias := lnwire.RandNodeAlias(t)
// Generate a random public key.
pubKey := lnwire.RandPubKey(t)
var pubKeyBytes [33]byte
copy(pubKeyBytes[:], pubKey.SerializeCompressed())
// Generate a random signature.
sig := lnwire.RandSignature(t)
sigBytes := sig.ToSignatureBytes()
// Generate a random color.
randColor := color.RGBA{
R: uint8(rapid.IntRange(0, 255).
Draw(t, "R")),
G: uint8(rapid.IntRange(0, 255).
Draw(t, "G")),
B: uint8(rapid.IntRange(0, 255).
Draw(t, "B")),
A: 0,
}
// Generate a random timestamp.
randTime := time.Unix(
rapid.Int64Range(
0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(),
).Draw(t, "timestamp"), 0,
)
// Generate random addresses.
addrs := lnwire.RandNetAddrs(t)
sortAddrs(addrs)
// Generate a random feature vector.
features := lnwire.RandFeatureVector(t)
// Generate random extra opaque data.
extraOpaqueData := lnwire.RandExtraOpaqueData(t, nil)
if len(extraOpaqueData) == 0 {
extraOpaqueData = nil
}
node := &models.LightningNode{
HaveNodeAnnouncement: true,
AuthSigBytes: sigBytes,
LastUpdate: randTime,
Color: randColor,
Alias: alias.String(),
Features: lnwire.NewFeatureVector(
features, lnwire.Features,
),
Addresses: addrs,
ExtraOpaqueData: extraOpaqueData,
PubKeyBytes: pubKeyBytes,
}
// We call this method so that the internal pubkey field is populated
// which then lets us to proper struct comparison later on.
_, err := node.PubKey()
require.NoError(t, err)
return node
}