mirror of
https://github.com/lightningnetwork/lnd.git
synced 2025-08-28 14:40:51 +02:00
Merge pull request #10073 from ellemouton/graphMigUnitTestsRapid
graph/db: add graph SQL migration rapid unit test
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/btcec/v2"
|
"github.com/btcsuite/btcd/btcec/v2"
|
||||||
|
"github.com/btcsuite/btcd/btcutil"
|
||||||
"github.com/btcsuite/btcd/chaincfg"
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
"github.com/btcsuite/btcd/chaincfg/chainhash"
|
||||||
"github.com/btcsuite/btcd/wire"
|
"github.com/btcsuite/btcd/wire"
|
||||||
@@ -33,6 +34,7 @@ import (
|
|||||||
"github.com/lightningnetwork/lnd/routing/route"
|
"github.com/lightningnetwork/lnd/routing/route"
|
||||||
"github.com/lightningnetwork/lnd/sqldb"
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"pgregory.net/rapid"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -58,6 +60,8 @@ func TestMigrateGraphToSQL(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
|
dbFixture := NewTestDBFixture(t)
|
||||||
|
|
||||||
writeUpdate := func(t *testing.T, db *KVStore, object any) {
|
writeUpdate := func(t *testing.T, db *KVStore, object any) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
@@ -324,7 +328,8 @@ func TestMigrateGraphToSQL(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set up our destination SQL DB.
|
// Set up our destination SQL DB.
|
||||||
sql, ok := NewTestDB(t).(*SQLStore)
|
db := NewTestDBWithFixture(t, dbFixture)
|
||||||
|
sql, ok := db.(*SQLStore)
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
|
|
||||||
// Run the migration.
|
// Run the migration.
|
||||||
@@ -398,6 +403,9 @@ func fetchAllNodes(t *testing.T, store V1Store) []*models.LightningNode {
|
|||||||
_, err := node.PubKey()
|
_, err := node.PubKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Sort the addresses to ensure a consistent order.
|
||||||
|
sortAddrs(node.Addresses)
|
||||||
|
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -586,7 +594,7 @@ func setUpKVStore(t *testing.T) *KVStore {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// genPubKey generates a new public key for testing purposes.
|
// genPubKey generates a new public key for testing purposes.
|
||||||
func genPubKey(t *testing.T) route.Vertex {
|
func genPubKey(t require.TestingT) route.Vertex {
|
||||||
key, err := btcec.NewPrivateKey()
|
key, err := btcec.NewPrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -1162,3 +1170,414 @@ func assertResultState(t *testing.T, sql *SQLStore, expState dbState) {
|
|||||||
require.True(t, isZombie)
|
require.True(t, isZombie)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestMigrateGraphToSQLRapid tests the migration of graph nodes from a KV
|
||||||
|
// store to a SQL store using property-based testing to ensure that the
|
||||||
|
// migration works for a wide variety of randomly generated graph nodes.
|
||||||
|
func TestMigrateGraphToSQLRapid(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
dbFixture := NewTestDBFixture(t)
|
||||||
|
|
||||||
|
rapid.Check(t, func(rt *rapid.T) {
|
||||||
|
const (
|
||||||
|
maxNumNodes = 5
|
||||||
|
maxNumChannels = 5
|
||||||
|
)
|
||||||
|
|
||||||
|
testMigrateGraphToSQLRapidOnce(
|
||||||
|
t, rt, dbFixture, maxNumNodes, maxNumChannels,
|
||||||
|
)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// testMigrateGraphToSQLRapidOnce is a helper function that performs the actual
|
||||||
|
// migration test using property-based testing. It sets up a KV store and a
|
||||||
|
// SQL store, generates random nodes and channels, populates the KV store,
|
||||||
|
// runs the migration, and asserts that the SQL store contains the expected
|
||||||
|
// state.
|
||||||
|
func testMigrateGraphToSQLRapidOnce(t *testing.T, rt *rapid.T,
|
||||||
|
dbFixture *sqldb.TestPgFixture, maxNumNodes, maxNumChannels int) {
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
// Set up our source kvdb DB.
|
||||||
|
kvDB := setUpKVStore(t)
|
||||||
|
|
||||||
|
// Set up our destination SQL DB.
|
||||||
|
sql, ok := NewTestDBWithFixture(t, dbFixture).(*SQLStore)
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// Generate a list of random nodes.
|
||||||
|
nodes := rapid.SliceOfN(
|
||||||
|
rapid.Custom(genRandomNode), 1, maxNumNodes,
|
||||||
|
).Draw(rt, "nodes")
|
||||||
|
|
||||||
|
// Keep track of all nodes that should be in the database. We may expect
|
||||||
|
// more than just the ones we generated above if we have channels that
|
||||||
|
// point to shell nodes.
|
||||||
|
allNodes := make(map[route.Vertex]*models.LightningNode)
|
||||||
|
var nodePubs []route.Vertex
|
||||||
|
for _, node := range nodes {
|
||||||
|
allNodes[node.PubKeyBytes] = node
|
||||||
|
nodePubs = append(nodePubs, node.PubKeyBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a list of random channels and policies for those channels.
|
||||||
|
var (
|
||||||
|
channels []*models.ChannelEdgeInfo
|
||||||
|
chanIDs = make(map[uint64]struct{})
|
||||||
|
policies []*models.ChannelEdgePolicy
|
||||||
|
)
|
||||||
|
channelGen := rapid.Custom(func(rtt *rapid.T) *models.ChannelEdgeInfo {
|
||||||
|
var (
|
||||||
|
edge *models.ChannelEdgeInfo
|
||||||
|
newNodes []route.Vertex
|
||||||
|
)
|
||||||
|
// Loop to ensure that we skip channels with channel IDs
|
||||||
|
// that we have already used.
|
||||||
|
for {
|
||||||
|
edge, newNodes = genRandomChannel(rtt, nodePubs)
|
||||||
|
if _, ok := chanIDs[edge.ChannelID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
chanIDs[edge.ChannelID] = struct{}{}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the new channel points to nodes we don't yet know
|
||||||
|
// of, then update our expected node list to include
|
||||||
|
// shell node entries for these.
|
||||||
|
for _, n := range newNodes {
|
||||||
|
if _, ok := allNodes[n]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
shellNode := makeTestShellNode(
|
||||||
|
t, func(node *models.LightningNode) {
|
||||||
|
node.PubKeyBytes = n
|
||||||
|
},
|
||||||
|
)
|
||||||
|
allNodes[n] = shellNode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate either 0, 1 or two policies for this
|
||||||
|
// channel.
|
||||||
|
numPolicies := rapid.IntRange(0, 2).Draw(
|
||||||
|
rtt, "numPolicies",
|
||||||
|
)
|
||||||
|
switch numPolicies {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
// Randomly pick the direction.
|
||||||
|
policy := genRandomPolicy(
|
||||||
|
rtt, edge, rapid.Bool().Draw(rtt, "isNode1"),
|
||||||
|
)
|
||||||
|
|
||||||
|
policies = append(policies, policy)
|
||||||
|
case 2:
|
||||||
|
// Generate two policies, one for each
|
||||||
|
// direction.
|
||||||
|
policy1 := genRandomPolicy(rtt, edge, true)
|
||||||
|
policy2 := genRandomPolicy(rtt, edge, false)
|
||||||
|
|
||||||
|
policies = append(policies, policy1)
|
||||||
|
policies = append(policies, policy2)
|
||||||
|
}
|
||||||
|
|
||||||
|
return edge
|
||||||
|
})
|
||||||
|
channels = rapid.SliceOfN(
|
||||||
|
channelGen, 1, maxNumChannels,
|
||||||
|
).Draw(rt, "channels")
|
||||||
|
|
||||||
|
// Write the test objects to the kvdb store.
|
||||||
|
for _, node := range allNodes {
|
||||||
|
err := kvDB.AddLightningNode(ctx, node)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
for _, channel := range channels {
|
||||||
|
err := kvDB.AddChannelEdge(ctx, channel)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
for _, policy := range policies {
|
||||||
|
_, _, err := kvDB.UpdateEdgePolicy(ctx, policy)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run the migration.
|
||||||
|
err := MigrateGraphToSQL(ctx, kvDB.db, sql.db, testChain)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Create a slice of all nodes.
|
||||||
|
var nodesSlice []*models.LightningNode
|
||||||
|
for _, node := range allNodes {
|
||||||
|
nodesSlice = append(nodesSlice, node)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a map of channels to their policies.
|
||||||
|
chanMap := make(map[uint64]*chanInfo)
|
||||||
|
for _, channel := range channels {
|
||||||
|
chanMap[channel.ChannelID] = &chanInfo{
|
||||||
|
edgeInfo: channel,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, policy := range policies {
|
||||||
|
info, ok := chanMap[policy.ChannelID]
|
||||||
|
require.True(t, ok)
|
||||||
|
|
||||||
|
// The IsNode1 flag is encoded in the ChannelFlags.
|
||||||
|
if policy.ChannelFlags&lnwire.ChanUpdateDirection == 0 {
|
||||||
|
info.policy1 = policy
|
||||||
|
} else {
|
||||||
|
info.policy2 = policy
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var chanSetForState chanSet
|
||||||
|
for _, info := range chanMap {
|
||||||
|
chanSetForState = append(chanSetForState, *info)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that the sql database has the correct state.
|
||||||
|
assertResultState(t, sql, dbState{
|
||||||
|
nodes: nodesSlice,
|
||||||
|
chans: chanSetForState,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// genRandomChannel is a rapid generator for creating random channel edge infos.
|
||||||
|
// It takes a slice of existing node public keys to draw from. If the slice is
|
||||||
|
// empty, it will always generate new random nodes.
|
||||||
|
func genRandomChannel(rt *rapid.T,
|
||||||
|
nodes []route.Vertex) (*models.ChannelEdgeInfo, []route.Vertex) {
|
||||||
|
|
||||||
|
var newNodes []route.Vertex
|
||||||
|
|
||||||
|
// Generate a random channel ID.
|
||||||
|
chanID := lnwire.RandShortChannelID(rt).ToUint64()
|
||||||
|
|
||||||
|
// Generate a random outpoint.
|
||||||
|
var hash chainhash.Hash
|
||||||
|
_, err := rand.Read(hash[:])
|
||||||
|
require.NoError(rt, err)
|
||||||
|
outpoint := wire.OutPoint{
|
||||||
|
Hash: hash,
|
||||||
|
Index: rapid.Uint32().Draw(rt, "outpointIndex"),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate random capacity.
|
||||||
|
capacity := rapid.Int64Range(1, btcutil.MaxSatoshi).Draw(rt, "capacity")
|
||||||
|
|
||||||
|
// Generate random features.
|
||||||
|
features := lnwire.NewFeatureVector(
|
||||||
|
lnwire.RandFeatureVector(rt),
|
||||||
|
lnwire.Features,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate random keys for the channel.
|
||||||
|
bitcoinKey1Bytes := genPubKey(rt)
|
||||||
|
bitcoinKey2Bytes := genPubKey(rt)
|
||||||
|
|
||||||
|
// Decide if we should use existing nodes or generate new ones.
|
||||||
|
var nodeKey1Bytes, nodeKey2Bytes route.Vertex
|
||||||
|
// With a 50/50 chance, we'll use existing nodes.
|
||||||
|
if len(nodes) > 1 && rapid.Bool().Draw(rt, "useExistingNodes") {
|
||||||
|
// Pick two random nodes from the existing set.
|
||||||
|
idx1 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node1")
|
||||||
|
idx2 := rapid.IntRange(0, len(nodes)-1).Draw(rt, "node2")
|
||||||
|
if idx1 == idx2 {
|
||||||
|
idx2 = (idx1 + 1) % len(nodes)
|
||||||
|
}
|
||||||
|
nodeKey1Bytes = nodes[idx1]
|
||||||
|
nodeKey2Bytes = nodes[idx2]
|
||||||
|
} else {
|
||||||
|
// Generate new random nodes.
|
||||||
|
nodeKey1Bytes = genPubKey(rt)
|
||||||
|
nodeKey2Bytes = genPubKey(rt)
|
||||||
|
newNodes = append(newNodes, nodeKey1Bytes, nodeKey2Bytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
node1Sig := lnwire.RandSignature(rt)
|
||||||
|
node2Sig := lnwire.RandSignature(rt)
|
||||||
|
btc1Sig := lnwire.RandSignature(rt)
|
||||||
|
btc2Sig := lnwire.RandSignature(rt)
|
||||||
|
|
||||||
|
// Generate a random auth proof.
|
||||||
|
authProof := &models.ChannelAuthProof{
|
||||||
|
NodeSig1Bytes: node1Sig.RawBytes(),
|
||||||
|
NodeSig2Bytes: node2Sig.RawBytes(),
|
||||||
|
BitcoinSig1Bytes: btc1Sig.RawBytes(),
|
||||||
|
BitcoinSig2Bytes: btc2Sig.RawBytes(),
|
||||||
|
}
|
||||||
|
|
||||||
|
extraOpaque := lnwire.RandExtraOpaqueData(rt, nil)
|
||||||
|
if len(extraOpaque) == 0 {
|
||||||
|
extraOpaque = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
info := &models.ChannelEdgeInfo{
|
||||||
|
ChannelID: chanID,
|
||||||
|
ChainHash: testChain,
|
||||||
|
NodeKey1Bytes: nodeKey1Bytes,
|
||||||
|
NodeKey2Bytes: nodeKey2Bytes,
|
||||||
|
BitcoinKey1Bytes: bitcoinKey1Bytes,
|
||||||
|
BitcoinKey2Bytes: bitcoinKey2Bytes,
|
||||||
|
Features: features,
|
||||||
|
AuthProof: authProof,
|
||||||
|
ChannelPoint: outpoint,
|
||||||
|
Capacity: btcutil.Amount(capacity),
|
||||||
|
ExtraOpaqueData: extraOpaque,
|
||||||
|
}
|
||||||
|
|
||||||
|
return info, newNodes
|
||||||
|
}
|
||||||
|
|
||||||
|
// genRandomPolicy is a rapid generator for creating random channel edge
|
||||||
|
// policies. It takes a slice of existing channels to draw from.
|
||||||
|
func genRandomPolicy(rt *rapid.T, channel *models.ChannelEdgeInfo,
|
||||||
|
isNode1 bool) *models.ChannelEdgePolicy {
|
||||||
|
|
||||||
|
var toNode route.Vertex
|
||||||
|
if isNode1 {
|
||||||
|
toNode = channel.NodeKey2Bytes
|
||||||
|
} else {
|
||||||
|
toNode = channel.NodeKey1Bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random timestamp.
|
||||||
|
randTime := time.Unix(rapid.Int64Range(
|
||||||
|
0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(),
|
||||||
|
).Draw(rt, "policyTimestamp"), 0)
|
||||||
|
|
||||||
|
// Generate random channel update flags and then just make sure to
|
||||||
|
// unset/set the correct direction bit.
|
||||||
|
chanFlags := lnwire.ChanUpdateChanFlags(
|
||||||
|
rapid.Uint8().Draw(rt, "chanFlags"),
|
||||||
|
)
|
||||||
|
if isNode1 {
|
||||||
|
chanFlags &= ^lnwire.ChanUpdateDirection
|
||||||
|
} else {
|
||||||
|
chanFlags |= lnwire.ChanUpdateDirection
|
||||||
|
}
|
||||||
|
|
||||||
|
extraOpaque := lnwire.RandExtraOpaqueData(rt, nil)
|
||||||
|
if len(extraOpaque) == 0 {
|
||||||
|
extraOpaque = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
hasMaxHTLC := rapid.Bool().Draw(rt, "hasMaxHTLC")
|
||||||
|
var maxHTLC lnwire.MilliSatoshi
|
||||||
|
msgFlags := lnwire.ChanUpdateMsgFlags(
|
||||||
|
rapid.Uint8().Draw(rt, "msgFlags"),
|
||||||
|
)
|
||||||
|
if hasMaxHTLC {
|
||||||
|
msgFlags |= lnwire.ChanUpdateRequiredMaxHtlc
|
||||||
|
maxHTLC = lnwire.MilliSatoshi(
|
||||||
|
rapid.Uint64().Draw(rt, "maxHtlc"),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
msgFlags &= ^lnwire.ChanUpdateRequiredMaxHtlc
|
||||||
|
}
|
||||||
|
|
||||||
|
return &models.ChannelEdgePolicy{
|
||||||
|
SigBytes: testSigBytes,
|
||||||
|
ChannelID: channel.ChannelID,
|
||||||
|
LastUpdate: randTime,
|
||||||
|
MessageFlags: msgFlags,
|
||||||
|
ChannelFlags: chanFlags,
|
||||||
|
TimeLockDelta: rapid.Uint16().Draw(rt, "timeLock"),
|
||||||
|
MinHTLC: lnwire.MilliSatoshi(
|
||||||
|
rapid.Uint64().Draw(rt, "minHtlc"),
|
||||||
|
),
|
||||||
|
MaxHTLC: maxHTLC,
|
||||||
|
FeeBaseMSat: lnwire.MilliSatoshi(
|
||||||
|
rapid.Uint64().Draw(rt, "baseFee"),
|
||||||
|
),
|
||||||
|
FeeProportionalMillionths: lnwire.MilliSatoshi(
|
||||||
|
rapid.Uint64().Draw(rt, "feeRate"),
|
||||||
|
),
|
||||||
|
ToNode: toNode,
|
||||||
|
ExtraOpaqueData: extraOpaque,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortAddrs sorts a slice of net.Addr.
|
||||||
|
func sortAddrs(addrs []net.Addr) {
|
||||||
|
if addrs == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(addrs, func(i, j net.Addr) int {
|
||||||
|
return strings.Compare(i.String(), j.String())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// genRandomNode is a rapid generator for creating random lightning nodes.
|
||||||
|
func genRandomNode(t *rapid.T) *models.LightningNode {
|
||||||
|
// Generate a random alias that is valid.
|
||||||
|
alias := lnwire.RandNodeAlias(t)
|
||||||
|
|
||||||
|
// Generate a random public key.
|
||||||
|
pubKey := lnwire.RandPubKey(t)
|
||||||
|
var pubKeyBytes [33]byte
|
||||||
|
copy(pubKeyBytes[:], pubKey.SerializeCompressed())
|
||||||
|
|
||||||
|
// Generate a random signature.
|
||||||
|
sig := lnwire.RandSignature(t)
|
||||||
|
sigBytes := sig.ToSignatureBytes()
|
||||||
|
|
||||||
|
// Generate a random color.
|
||||||
|
randColor := color.RGBA{
|
||||||
|
R: uint8(rapid.IntRange(0, 255).
|
||||||
|
Draw(t, "R")),
|
||||||
|
G: uint8(rapid.IntRange(0, 255).
|
||||||
|
Draw(t, "G")),
|
||||||
|
B: uint8(rapid.IntRange(0, 255).
|
||||||
|
Draw(t, "B")),
|
||||||
|
A: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a random timestamp.
|
||||||
|
randTime := time.Unix(
|
||||||
|
rapid.Int64Range(
|
||||||
|
0, time.Date(2030, 1, 1, 0, 0, 0, 0, time.UTC).Unix(),
|
||||||
|
).Draw(t, "timestamp"), 0,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Generate random addresses.
|
||||||
|
addrs := lnwire.RandNetAddrs(t)
|
||||||
|
sortAddrs(addrs)
|
||||||
|
|
||||||
|
// Generate a random feature vector.
|
||||||
|
features := lnwire.RandFeatureVector(t)
|
||||||
|
|
||||||
|
// Generate random extra opaque data.
|
||||||
|
extraOpaqueData := lnwire.RandExtraOpaqueData(t, nil)
|
||||||
|
if len(extraOpaqueData) == 0 {
|
||||||
|
extraOpaqueData = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
node := &models.LightningNode{
|
||||||
|
HaveNodeAnnouncement: true,
|
||||||
|
AuthSigBytes: sigBytes,
|
||||||
|
LastUpdate: randTime,
|
||||||
|
Color: randColor,
|
||||||
|
Alias: alias.String(),
|
||||||
|
Features: lnwire.NewFeatureVector(
|
||||||
|
features, lnwire.Features,
|
||||||
|
),
|
||||||
|
Addresses: addrs,
|
||||||
|
ExtraOpaqueData: extraOpaqueData,
|
||||||
|
PubKeyBytes: pubKeyBytes,
|
||||||
|
}
|
||||||
|
|
||||||
|
// We call this method so that the internal pubkey field is populated
|
||||||
|
// which then lets us to proper struct comparison later on.
|
||||||
|
_, err := node.PubKey()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return node
|
||||||
|
}
|
||||||
|
@@ -6,9 +6,46 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
"github.com/lightningnetwork/lnd/sqldb"
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewTestDB is a helper function that creates a SQLStore backed by a SQL
|
||||||
|
// database for testing.
|
||||||
|
func NewTestDB(t testing.TB) V1Store {
|
||||||
|
return NewTestDBWithFixture(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestDBFixture creates a new sqldb.TestPgFixture for testing purposes.
|
||||||
|
func NewTestDBFixture(t *testing.T) *sqldb.TestPgFixture {
|
||||||
|
pgFixture := sqldb.NewTestPgFixture(t, sqldb.DefaultPostgresFixtureLifetime)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
pgFixture.TearDown(t)
|
||||||
|
})
|
||||||
|
return pgFixture
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a
|
||||||
|
// SQL database for testing.
|
||||||
|
func NewTestDBWithFixture(t testing.TB, pgFixture *sqldb.TestPgFixture) V1Store {
|
||||||
|
var querier BatchedSQLQueries
|
||||||
|
if pgFixture == nil {
|
||||||
|
querier = newBatchQuerier(t)
|
||||||
|
} else {
|
||||||
|
querier = newBatchQuerierWithFixture(t, pgFixture)
|
||||||
|
}
|
||||||
|
|
||||||
|
store, err := NewSQLStore(
|
||||||
|
&SQLStoreConfig{
|
||||||
|
ChainHash: *chaincfg.MainNetParams.GenesisHash,
|
||||||
|
}, querier,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
// newBatchQuerier creates a new BatchedSQLQueries instance for testing
|
// newBatchQuerier creates a new BatchedSQLQueries instance for testing
|
||||||
// using a PostgreSQL database fixture.
|
// using a PostgreSQL database fixture.
|
||||||
func newBatchQuerier(t testing.TB) BatchedSQLQueries {
|
func newBatchQuerier(t testing.TB) BatchedSQLQueries {
|
||||||
@@ -19,6 +56,14 @@ func newBatchQuerier(t testing.TB) BatchedSQLQueries {
|
|||||||
pgFixture.TearDown(t)
|
pgFixture.TearDown(t)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
return newBatchQuerierWithFixture(t, pgFixture)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for
|
||||||
|
// testing using a PostgreSQL database fixture.
|
||||||
|
func newBatchQuerierWithFixture(t testing.TB,
|
||||||
|
pgFixture *sqldb.TestPgFixture) BatchedSQLQueries {
|
||||||
|
|
||||||
db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
db := sqldb.NewTestPostgresDB(t, pgFixture).BaseDB
|
||||||
|
|
||||||
return sqldb.NewTransactionExecutor(
|
return sqldb.NewTransactionExecutor(
|
||||||
|
@@ -1,23 +0,0 @@
|
|||||||
//go:build test_db_postgres || test_db_sqlite
|
|
||||||
|
|
||||||
package graphdb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/btcsuite/btcd/chaincfg"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
// NewTestDB is a helper function that creates a SQLStore backed by a SQL
|
|
||||||
// database for testing.
|
|
||||||
func NewTestDB(t testing.TB) V1Store {
|
|
||||||
store, err := NewSQLStore(
|
|
||||||
&SQLStoreConfig{
|
|
||||||
ChainHash: *chaincfg.MainNetParams.GenesisHash,
|
|
||||||
}, newBatchQuerier(t),
|
|
||||||
)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
return store
|
|
||||||
}
|
|
@@ -6,12 +6,45 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/btcsuite/btcd/chaincfg"
|
||||||
"github.com/lightningnetwork/lnd/sqldb"
|
"github.com/lightningnetwork/lnd/sqldb"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// NewTestDB is a helper function that creates a SQLStore backed by a SQL
|
||||||
|
// database for testing.
|
||||||
|
func NewTestDB(t testing.TB) V1Store {
|
||||||
|
return NewTestDBWithFixture(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestDBFixture is a no-op for the sqlite build.
|
||||||
|
func NewTestDBFixture(_ *testing.T) *sqldb.TestPgFixture {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTestDBWithFixture is a helper function that creates a SQLStore backed by a
|
||||||
|
// SQL database for testing.
|
||||||
|
func NewTestDBWithFixture(t testing.TB, _ *sqldb.TestPgFixture) V1Store {
|
||||||
|
store, err := NewSQLStore(
|
||||||
|
&SQLStoreConfig{
|
||||||
|
ChainHash: *chaincfg.MainNetParams.GenesisHash,
|
||||||
|
}, newBatchQuerier(t),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return store
|
||||||
|
}
|
||||||
|
|
||||||
// newBatchQuerier creates a new BatchedSQLQueries instance for testing
|
// newBatchQuerier creates a new BatchedSQLQueries instance for testing
|
||||||
// using a SQLite database.
|
// using a SQLite database.
|
||||||
func newBatchQuerier(t testing.TB) BatchedSQLQueries {
|
func newBatchQuerier(t testing.TB) BatchedSQLQueries {
|
||||||
|
return newBatchQuerierWithFixture(t, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
// newBatchQuerierWithFixture creates a new BatchedSQLQueries instance for
|
||||||
|
// testing using a SQLite database.
|
||||||
|
func newBatchQuerierWithFixture(t testing.TB,
|
||||||
|
_ *sqldb.TestPgFixture) BatchedSQLQueries {
|
||||||
|
|
||||||
db := sqldb.NewTestSqliteDB(t).BaseDB
|
db := sqldb.NewTestSqliteDB(t).BaseDB
|
||||||
|
|
||||||
return sqldb.NewTransactionExecutor(
|
return sqldb.NewTransactionExecutor(
|
||||||
|
@@ -14,7 +14,11 @@ import (
|
|||||||
"pgregory.net/rapid"
|
"pgregory.net/rapid"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RandChannelUpdate generates a random ChannelUpdate message using rapid's
|
// charset contains valid UTF-8 characters that can be used to generate random
|
||||||
|
// strings for testing purposes.
|
||||||
|
const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||||
|
|
||||||
|
// RandPartialSig generates a random ParialSig message using rapid's
|
||||||
// generators.
|
// generators.
|
||||||
func RandPartialSig(t *rapid.T) *PartialSig {
|
func RandPartialSig(t *rapid.T) *PartialSig {
|
||||||
// Generate random private key bytes
|
// Generate random private key bytes
|
||||||
@@ -139,11 +143,11 @@ func RandNodeAlias(t *rapid.T) NodeAlias {
|
|||||||
var alias NodeAlias
|
var alias NodeAlias
|
||||||
aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength")
|
aliasLength := rapid.IntRange(0, 32).Draw(t, "aliasLength")
|
||||||
|
|
||||||
aliasBytes := rapid.StringN(
|
aliasBytes := rapid.SliceOfN(
|
||||||
0, aliasLength, aliasLength,
|
rapid.SampledFrom([]rune(charset)), aliasLength, aliasLength,
|
||||||
).Draw(t, "alias")
|
).Draw(t, "alias")
|
||||||
|
|
||||||
copy(alias[:], aliasBytes)
|
copy(alias[:], string(aliasBytes))
|
||||||
|
|
||||||
return alias
|
return alias
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user