Merge pull request #10050 from ellemouton/graphMig2-channels

[graph mig 2]: graph/db: migrate graph channels and policies from kvdb to SQL
This commit is contained in:
Oliver Gugger
2025-07-09 13:09:27 +02:00
committed by GitHub
5 changed files with 762 additions and 51 deletions

View File

@@ -97,6 +97,7 @@ circuit. The indices are only available for forwarding events saved after v0.20.
* [11](https://github.com/lightningnetwork/lnd/pull/9972) * [11](https://github.com/lightningnetwork/lnd/pull/9972)
* Add graph SQL migration logic: * Add graph SQL migration logic:
* [1](https://github.com/lightningnetwork/lnd/pull/10036) * [1](https://github.com/lightningnetwork/lnd/pull/10036)
* [2](https://github.com/lightningnetwork/lnd/pull/10050)
## RPC Updates ## RPC Updates
* Previously the `RoutingPolicy` would return the inbound fee record in its * Previously the `RoutingPolicy` would return the inbound fee record in its

View File

@@ -3734,12 +3734,12 @@ func TestDisabledChannelIDs(t *testing.T) {
} }
} }
// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in // TestEdgePolicyMissingMaxHTLC tests that if we find a ChannelEdgePolicy in
// the DB that indicates that it should support the htlc_maximum_value_msat // the DB that indicates that it should support the htlc_maximum_value_msat
// field, but it is not part of the opaque data, then we'll handle it as it is // field, but it is not part of the opaque data, then we'll handle it as it is
// unknown. It also checks that we are correctly able to overwrite it when we // unknown. It also checks that we are correctly able to overwrite it when we
// receive the proper update. // receive the proper update.
func TestEdgePolicyMissingMaxHtcl(t *testing.T) { func TestEdgePolicyMissingMaxHTLC(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
@@ -3797,45 +3797,10 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
require.ErrorIs(t, err, ErrEdgePolicyOptionalFieldNotFound) require.ErrorIs(t, err, ErrEdgePolicyOptionalFieldNotFound)
// Put the stripped bytes in the DB. // Put the stripped bytes in the DB.
err = kvdb.Update(boltStore.db, func(tx kvdb.RwTx) error { putSerializedPolicy(t, boltStore.db, from, chanID, stripped)
edges := tx.ReadWriteBucket(edgeBucket)
if edges == nil {
return ErrEdgeNotFound
}
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
if edgeIndex == nil {
return ErrEdgeNotFound
}
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID)
var scratch [8]byte
var indexKey [8 + 8]byte
copy(indexKey[:], scratch[:])
byteOrder.PutUint64(indexKey[8:], edge1.ChannelID)
updateIndex, err := edges.CreateBucketIfNotExists(
edgeUpdateIndexBucket,
)
if err != nil {
return err
}
if err := updateIndex.Put(indexKey[:], nil); err != nil {
return err
}
return edges.Put(edgeKey[:], stripped)
}, func() {})
require.NoError(t, err, "error writing db")
// And add the second, unmodified edge. // And add the second, unmodified edge.
if err := graph.UpdateEdgePolicy(ctx, edge2); err != nil { require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2))
t.Fatalf("unable to update edge: %v", err)
}
// Attempt to fetch the edge and policies from the DB. Since the policy // Attempt to fetch the edge and policies from the DB. Since the policy
// we added is invalid according to the new format, it should be as we // we added is invalid according to the new format, it should be as we
@@ -3870,6 +3835,38 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) {
assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo)
} }
// putSerializedPolicy is a helper function that writes a serialized
// ChannelEdgePolicy to the edge bucket in the database.
func putSerializedPolicy(t *testing.T, db kvdb.Backend, from []byte,
chanID uint64, b []byte) {
err := kvdb.Update(db, func(tx kvdb.RwTx) error {
edges := tx.ReadWriteBucket(edgeBucket)
require.NotNil(t, edges)
edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket)
require.NotNil(t, edgeIndex)
var edgeKey [33 + 8]byte
copy(edgeKey[:], from)
byteOrder.PutUint64(edgeKey[33:], chanID)
var scratch [8]byte
var indexKey [8 + 8]byte
copy(indexKey[:], scratch[:])
byteOrder.PutUint64(indexKey[8:], chanID)
updateIndex, err := edges.CreateBucketIfNotExists(
edgeUpdateIndexBucket,
)
require.NoError(t, err)
require.NoError(t, updateIndex.Put(indexKey[:], nil))
return edges.Put(edgeKey[:], b)
}, func() {})
require.NoError(t, err, "error writing db")
}
// assertNumZombies queries the provided ChannelGraph for NumZombies, and // assertNumZombies queries the provided ChannelGraph for NumZombies, and
// asserts that the returned number is equal to expZombies. // asserts that the returned number is equal to expZombies.
func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) {

View File

@@ -2844,7 +2844,17 @@ func (c *KVStore) UpdateEdgePolicy(ctx context.Context,
edgeNotFound = false edgeNotFound = false
}, },
Do: func(tx kvdb.RwTx) error { Do: func(tx kvdb.RwTx) error {
var err error // Validate that the ExtraOpaqueData is in fact a valid
// TLV stream. This is done here instead of within
// updateEdgePolicy so that updateEdgePolicy can be used
// by unit tests to recreate the case where we already
// have nodes persisted with invalid TLV data.
err := edge.ExtraOpaqueData.ValidateTLV()
if err != nil {
return fmt.Errorf("%w: %w",
ErrParsingExtraTLVBytes, err)
}
from, to, isUpdate1, err = updateEdgePolicy(tx, edge) from, to, isUpdate1, err = updateEdgePolicy(tx, edge)
if err != nil { if err != nil {
log.Errorf("UpdateEdgePolicy faild: %v", err) log.Errorf("UpdateEdgePolicy faild: %v", err)
@@ -4487,8 +4497,9 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy,
// //
// TODO(halseth): get rid of these invalid policies in a // TODO(halseth): get rid of these invalid policies in a
// migration. // migration.
// TODO(elle): complete the above TODO in migration from kvdb //
// to SQL. // NOTE: the above TODO was completed in the SQL migration and
// so such edge cases no longer need to be handled there.
oldEdgePolicy, err := deserializeChanEdgePolicy( oldEdgePolicy, err := deserializeChanEdgePolicy(
bytes.NewReader(edgeBytes), bytes.NewReader(edgeBytes),
) )
@@ -4703,12 +4714,6 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy,
} }
} }
// Validate that the ExtraOpaqueData is in fact a valid TLV stream.
err = edge.ExtraOpaqueData.ValidateTLV()
if err != nil {
return fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err)
}
if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes {
return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData))
} }

View File

@@ -23,7 +23,7 @@ import (
// only for now and will be called from the main lnd binary once the // only for now and will be called from the main lnd binary once the
// migration is fully implemented and tested. // migration is fully implemented and tested.
func MigrateGraphToSQL(ctx context.Context, kvBackend kvdb.Backend, func MigrateGraphToSQL(ctx context.Context, kvBackend kvdb.Backend,
sqlDB SQLQueries, _ chainhash.Hash) error { sqlDB SQLQueries, chain chainhash.Hash) error {
log.Infof("Starting migration of the graph store from KV to SQL") log.Infof("Starting migration of the graph store from KV to SQL")
t0 := time.Now() t0 := time.Now()
@@ -48,6 +48,13 @@ func MigrateGraphToSQL(ctx context.Context, kvBackend kvdb.Backend,
return fmt.Errorf("could not migrate source node: %w", err) return fmt.Errorf("could not migrate source node: %w", err)
} }
// 3) Migrate all the channels and channel policies.
err = migrateChannelsAndPolicies(ctx, kvBackend, sqlDB, chain)
if err != nil {
return fmt.Errorf("could not migrate channels and policies: %w",
err)
}
log.Infof("Finished migration of the graph store from KV to SQL in %v", log.Infof("Finished migration of the graph store from KV to SQL in %v",
time.Since(t0)) time.Since(t0))
@@ -261,3 +268,277 @@ func migrateSourceNode(ctx context.Context, kvdb kvdb.Backend,
return nil return nil
} }
// migrateChannelsAndPolicies migrates all channels and their policies
// from the KV backend to the SQL database.
func migrateChannelsAndPolicies(ctx context.Context, kvBackend kvdb.Backend,
sqlDB SQLQueries, chain chainhash.Hash) error {
var (
channelCount uint64
skippedChanCount uint64
policyCount uint64
skippedPolicyCount uint64
)
migChanPolicy := func(policy *models.ChannelEdgePolicy) error {
// If the policy is nil, we can skip it.
if policy == nil {
return nil
}
// Unlike the special case of invalid TLV bytes for node and
// channel announcements, we don't need to handle the case for
// channel policies here because it is already handled in the
// `forEachChannel` function. If the policy has invalid TLV
// bytes, then `nil` will be passed to this function.
policyCount++
_, _, _, err := updateChanEdgePolicy(ctx, sqlDB, policy)
if err != nil {
return fmt.Errorf("could not migrate channel "+
"policy %d: %w", policy.ChannelID, err)
}
return nil
}
// Iterate over each channel in the KV store and migrate it and its
// policies to the SQL database.
err := forEachChannel(kvBackend, func(channel *models.ChannelEdgeInfo,
policy1 *models.ChannelEdgePolicy,
policy2 *models.ChannelEdgePolicy) error {
scid := channel.ChannelID
// Here, we do a sanity check to ensure that the chain hash of
// the channel returned by the KV store matches the expected
// chain hash. This is important since in the SQL store, we will
// no longer explicitly store the chain hash in the channel
// info, but rather rely on the chain hash LND is running with.
// So this is our way of ensuring that LND is running on the
// correct network at migration time.
if channel.ChainHash != chain {
return fmt.Errorf("channel %d has chain hash %s, "+
"expected %s", scid, channel.ChainHash, chain)
}
// Sanity check to ensure that the channel has valid extra
// opaque data. If it does not, we'll skip it. We need to do
// this because previously we would just persist any TLV bytes
// that we received without validating them. Now, however, we
// normalise the storage of extra opaque data, so we need to
// ensure that the data is valid. We don't want to abort the
// migration if we encounter a channel with invalid extra opaque
// data, so we'll just skip it and log a warning.
_, err := marshalExtraOpaqueData(channel.ExtraOpaqueData)
if errors.Is(err, ErrParsingExtraTLVBytes) {
log.Warnf("Skipping channel %d with invalid "+
"extra opaque data: %v", scid,
channel.ExtraOpaqueData)
skippedChanCount++
// If we skip a channel, we also skip its policies.
if policy1 != nil {
skippedPolicyCount++
}
if policy2 != nil {
skippedPolicyCount++
}
return nil
} else if err != nil {
return fmt.Errorf("unable to marshal extra opaque "+
"data for channel %d (%v): %w", scid,
channel.ExtraOpaqueData, err)
}
channelCount++
err = migrateSingleChannel(
ctx, sqlDB, channel, policy1, policy2, migChanPolicy,
)
if err != nil {
return fmt.Errorf("could not migrate channel %d: %w",
scid, err)
}
return nil
})
if err != nil {
return fmt.Errorf("could not migrate channels and policies: %w",
err)
}
log.Infof("Migrated %d channels and %d policies from KV to SQL "+
"(skipped %d channels and %d policies due to invalid TLV "+
"streams)", channelCount, policyCount, skippedChanCount,
skippedPolicyCount)
return nil
}
func migrateSingleChannel(ctx context.Context, sqlDB SQLQueries,
channel *models.ChannelEdgeInfo,
policy1, policy2 *models.ChannelEdgePolicy,
migChanPolicy func(*models.ChannelEdgePolicy) error) error {
scid := channel.ChannelID
// First, migrate the channel info along with its policies.
dbChanInfo, err := insertChannel(ctx, sqlDB, channel)
if err != nil {
return fmt.Errorf("could not insert record for channel %d "+
"in SQL store: %w", scid, err)
}
// Now, migrate the two channel policies.
err = migChanPolicy(policy1)
if err != nil {
return fmt.Errorf("could not migrate policy1(%d): %w", scid,
err)
}
err = migChanPolicy(policy2)
if err != nil {
return fmt.Errorf("could not migrate policy2(%d): %w", scid,
err)
}
// Now, fetch the channel and its policies from the SQL DB.
row, err := sqlDB.GetChannelBySCIDWithPolicies(
ctx, sqlc.GetChannelBySCIDWithPoliciesParams{
Scid: channelIDToBytes(scid),
Version: int16(ProtocolV1),
},
)
if err != nil {
return fmt.Errorf("could not get channel by SCID(%d): %w", scid,
err)
}
// Assert that the DB IDs for the channel and nodes are as expected
// given the inserted channel info.
err = sqldb.CompareRecords(
dbChanInfo.channelID, row.Channel.ID, "channel DB ID",
)
if err != nil {
return err
}
err = sqldb.CompareRecords(
dbChanInfo.node1ID, row.Node.ID, "node1 DB ID",
)
if err != nil {
return err
}
err = sqldb.CompareRecords(
dbChanInfo.node2ID, row.Node_2.ID, "node2 DB ID",
)
if err != nil {
return err
}
migChan, migPol1, migPol2, err := getAndBuildChanAndPolicies(
ctx, sqlDB, row, channel.ChainHash,
)
if err != nil {
return fmt.Errorf("could not build migrated channel and "+
"policies: %w", err)
}
// Finally, compare the original channel info and
// policies with the migrated ones to ensure they match.
if len(channel.ExtraOpaqueData) == 0 {
channel.ExtraOpaqueData = nil
}
if len(migChan.ExtraOpaqueData) == 0 {
migChan.ExtraOpaqueData = nil
}
err = sqldb.CompareRecords(
channel, migChan, fmt.Sprintf("channel %d", scid),
)
if err != nil {
return err
}
checkPolicy := func(expPolicy,
migPolicy *models.ChannelEdgePolicy) error {
switch {
// Both policies are nil, nothing to compare.
case expPolicy == nil && migPolicy == nil:
return nil
// One of the policies is nil, but the other is not.
case expPolicy == nil || migPolicy == nil:
return fmt.Errorf("expected both policies to be "+
"non-nil. Got expPolicy: %v, "+
"migPolicy: %v", expPolicy, migPolicy)
// Both policies are non-nil, we can compare them.
default:
}
if len(expPolicy.ExtraOpaqueData) == 0 {
expPolicy.ExtraOpaqueData = nil
}
if len(migPolicy.ExtraOpaqueData) == 0 {
migPolicy.ExtraOpaqueData = nil
}
return sqldb.CompareRecords(
*expPolicy, *migPolicy, "channel policy",
)
}
err = checkPolicy(policy1, migPol1)
if err != nil {
return fmt.Errorf("policy1 mismatch for channel %d: %w", scid,
err)
}
err = checkPolicy(policy2, migPol2)
if err != nil {
return fmt.Errorf("policy2 mismatch for channel %d: %w", scid,
err)
}
return nil
}
func getAndBuildChanAndPolicies(ctx context.Context, db SQLQueries,
row sqlc.GetChannelBySCIDWithPoliciesRow,
chain chainhash.Hash) (*models.ChannelEdgeInfo,
*models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) {
node1, node2, err := buildNodeVertices(
row.Node.PubKey, row.Node_2.PubKey,
)
if err != nil {
return nil, nil, nil, err
}
edge, err := getAndBuildEdgeInfo(
ctx, db, chain, row.Channel.ID, row.Channel, node1, node2,
)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to build channel "+
"info: %w", err)
}
dbPol1, dbPol2, err := extractChannelPolicies(row)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to extract channel "+
"policies: %w", err)
}
policy1, policy2, err := getAndBuildChanPolicies(
ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2,
)
if err != nil {
return nil, nil, nil, fmt.Errorf("unable to build channel "+
"policies: %w", err)
}
return edge, policy1, policy2, nil
}

View File

@@ -4,10 +4,13 @@ package graphdb
import ( import (
"bytes" "bytes"
"cmp"
"context" "context"
"errors" "errors"
"fmt" "fmt"
"image/color" "image/color"
"math"
prand "math/rand"
"net" "net"
"os" "os"
"path" "path"
@@ -18,6 +21,7 @@ import (
"github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/btcec/v2"
"github.com/btcsuite/btcd/chaincfg" "github.com/btcsuite/btcd/chaincfg"
"github.com/btcsuite/btcd/wire"
"github.com/btcsuite/btclog/v2" "github.com/btcsuite/btclog/v2"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
"github.com/lightningnetwork/lnd/kvdb" "github.com/lightningnetwork/lnd/kvdb"
@@ -36,6 +40,12 @@ var (
testSigBytes = testSig.Serialize() testSigBytes = testSig.Serialize()
testExtraData = []byte{1, 1, 1, 2, 2, 2, 2} testExtraData = []byte{1, 1, 1, 2, 2, 2, 2}
testEmptyFeatures = lnwire.EmptyFeatureVector() testEmptyFeatures = lnwire.EmptyFeatureVector()
testAuthProof = &models.ChannelAuthProof{
NodeSig1Bytes: testSig.Serialize(),
NodeSig2Bytes: testSig.Serialize(),
BitcoinSig1Bytes: testSig.Serialize(),
BitcoinSig2Bytes: testSig.Serialize(),
}
) )
// TestMigrateGraphToSQL tests various deterministic cases that we want to test // TestMigrateGraphToSQL tests various deterministic cases that we want to test
@@ -53,12 +63,24 @@ func TestMigrateGraphToSQL(t *testing.T) {
switch obj := object.(type) { switch obj := object.(type) {
case *models.LightningNode: case *models.LightningNode:
err = db.AddLightningNode(ctx, obj) err = db.AddLightningNode(ctx, obj)
case *models.ChannelEdgeInfo:
err = db.AddChannelEdge(ctx, obj)
case *models.ChannelEdgePolicy:
_, _, err = db.UpdateEdgePolicy(ctx, obj)
default: default:
err = fmt.Errorf("unhandled object type: %T", obj) err = fmt.Errorf("unhandled object type: %T", obj)
} }
require.NoError(t, err) require.NoError(t, err)
} }
var (
chanID1 = prand.Uint64()
chanID2 = prand.Uint64()
node1 = genPubKey(t)
node2 = genPubKey(t)
)
tests := []struct { tests := []struct {
name string name string
write func(t *testing.T, db *KVStore, object any) write func(t *testing.T, db *KVStore, object any)
@@ -123,6 +145,78 @@ func TestMigrateGraphToSQL(t *testing.T) {
srcNodeSet: true, srcNodeSet: true,
}, },
}, },
{
name: "channels and policies",
write: writeUpdate,
objects: []any{
// A channel with unknown nodes. This will
// result in two shell nodes being created.
// - channel count += 1
// - node count += 2
makeTestChannel(t),
// Insert some nodes.
// - node count += 1
makeTestNode(t, func(n *models.LightningNode) {
n.PubKeyBytes = node1
}),
// - node count += 1
makeTestNode(t, func(n *models.LightningNode) {
n.PubKeyBytes = node2
}),
// A channel with known nodes.
// - channel count += 1
makeTestChannel(
t, func(c *models.ChannelEdgeInfo) {
c.ChannelID = chanID1
c.NodeKey1Bytes = node1
c.NodeKey2Bytes = node2
},
),
// Insert a channel with no auth proof, no
// extra opaque data, and empty features.
// Use known nodes.
// - channel count += 1
makeTestChannel(
t, func(c *models.ChannelEdgeInfo) {
c.ChannelID = chanID2
c.NodeKey1Bytes = node1
c.NodeKey2Bytes = node2
c.AuthProof = nil
c.ExtraOpaqueData = nil
c.Features = testEmptyFeatures
},
),
// Now, insert a single update for the
// first channel.
// - channel policy count += 1
makeTestPolicy(chanID1, node1, false),
// Insert two updates for the second
// channel, one for each direction.
// - channel policy count += 1
makeTestPolicy(chanID2, node1, false),
// This one also has no extra opaque data.
// - channel policy count += 1
makeTestPolicy(
chanID2, node2, true,
func(p *models.ChannelEdgePolicy) {
p.ExtraOpaqueData = nil
},
),
},
expGraphStats: graphStats{
numNodes: 4,
numChannels: 3,
numPolicies: 3,
},
},
} }
for _, test := range tests { for _, test := range tests {
@@ -155,8 +249,10 @@ func TestMigrateGraphToSQL(t *testing.T) {
// graphStats holds expected statistics about the graph after migration. // graphStats holds expected statistics about the graph after migration.
type graphStats struct { type graphStats struct {
numNodes int numNodes int
srcNodeSet bool srcNodeSet bool
numChannels int
numPolicies int
} }
// assertInSync checks that the KVStore and SQLStore both contain the same // assertInSync checks that the KVStore and SQLStore both contain the same
@@ -174,6 +270,12 @@ func assertInSync(t *testing.T, kvDB *KVStore, sqlDB *SQLStore,
sqlSourceNode := fetchSourceNode(t, sqlDB) sqlSourceNode := fetchSourceNode(t, sqlDB)
require.Equal(t, stats.srcNodeSet, sqlSourceNode != nil) require.Equal(t, stats.srcNodeSet, sqlSourceNode != nil)
require.Equal(t, fetchSourceNode(t, kvDB), sqlSourceNode) require.Equal(t, fetchSourceNode(t, kvDB), sqlSourceNode)
// 3) Compare the channels and policies in the two stores.
sqlChannels := fetchAllChannelsAndPolicies(t, sqlDB)
require.Len(t, sqlChannels, stats.numChannels)
require.Equal(t, stats.numPolicies, sqlChannels.CountPolicies())
require.Equal(t, fetchAllChannelsAndPolicies(t, kvDB), sqlChannels)
} }
// fetchAllNodes retrieves all nodes from the given store and returns them // fetchAllNodes retrieves all nodes from the given store and returns them
@@ -215,6 +317,67 @@ func fetchSourceNode(t *testing.T, store V1Store) *models.LightningNode {
return node return node
} }
// chanInfo holds information about a channel, including its edge info
// and the policies for both directions.
type chanInfo struct {
edgeInfo *models.ChannelEdgeInfo
policy1 *models.ChannelEdgePolicy
policy2 *models.ChannelEdgePolicy
}
// chanSet is a slice of chanInfo
type chanSet []chanInfo
// CountPolicies counts the total number of policies in the channel set.
func (c chanSet) CountPolicies() int {
var count int
for _, info := range c {
if info.policy1 != nil {
count++
}
if info.policy2 != nil {
count++
}
}
return count
}
// fetchAllChannelsAndPolicies retrieves all channels and their policies
// from the given store and returns them sorted by their channel ID.
func fetchAllChannelsAndPolicies(t *testing.T, store V1Store) chanSet {
channels := make(chanSet, 0)
err := store.ForEachChannel(func(info *models.ChannelEdgeInfo,
p1 *models.ChannelEdgePolicy,
p2 *models.ChannelEdgePolicy) error {
if len(info.ExtraOpaqueData) == 0 {
info.ExtraOpaqueData = nil
}
if p1 != nil && len(p1.ExtraOpaqueData) == 0 {
p1.ExtraOpaqueData = nil
}
if p2 != nil && len(p2.ExtraOpaqueData) == 0 {
p2.ExtraOpaqueData = nil
}
channels = append(channels, chanInfo{
edgeInfo: info,
policy1: p1,
policy2: p2,
})
return nil
})
require.NoError(t, err)
// Sort the channels by their channel ID to ensure a consistent order.
slices.SortFunc(channels, func(i, j chanInfo) int {
return cmp.Compare(i.edgeInfo.ChannelID, j.edgeInfo.ChannelID)
})
return channels
}
// setUpKVStore initializes a new KVStore for testing. // setUpKVStore initializes a new KVStore for testing.
func setUpKVStore(t *testing.T) *KVStore { func setUpKVStore(t *testing.T) *KVStore {
kvDB, cleanup, err := kvdb.GetTestBackend(t.TempDir(), "graph") kvDB, cleanup, err := kvdb.GetTestBackend(t.TempDir(), "graph")
@@ -293,6 +456,73 @@ func makeTestShellNode(t *testing.T,
return n return n
} }
// modify the attributes of a models.ChannelEdgeInfo created by makeTestChannel.
type testChanOpt func(info *models.ChannelEdgeInfo)
// makeTestChannel creates a test models.ChannelEdgeInfo. The functional options
// can be used to modify the channel's attributes.
func makeTestChannel(t *testing.T,
opts ...testChanOpt) *models.ChannelEdgeInfo {
c := &models.ChannelEdgeInfo{
ChannelID: prand.Uint64(),
ChainHash: testChain,
NodeKey1Bytes: genPubKey(t),
NodeKey2Bytes: genPubKey(t),
BitcoinKey1Bytes: genPubKey(t),
BitcoinKey2Bytes: genPubKey(t),
Features: testFeatures,
AuthProof: testAuthProof,
ChannelPoint: wire.OutPoint{
Hash: rev,
Index: prand.Uint32(),
},
Capacity: 10000,
ExtraOpaqueData: testExtraData,
}
for _, opt := range opts {
opt(c)
}
return c
}
// testPolicyOpt defines a functional option type that can be used to modify the
// attributes of a models.ChannelEdgePolicy created by makeTestPolicy.
type testPolicyOpt func(*models.ChannelEdgePolicy)
// makeTestPolicy creates a test models.ChannelEdgePolicy. The functional
// options can be used to modify the policy's attributes.
func makeTestPolicy(chanID uint64, toNode route.Vertex, isNode1 bool,
opts ...testPolicyOpt) *models.ChannelEdgePolicy {
chanFlags := lnwire.ChanUpdateChanFlags(1)
if isNode1 {
chanFlags = 0
}
p := &models.ChannelEdgePolicy{
SigBytes: testSigBytes,
ChannelID: chanID,
LastUpdate: nextUpdateTime(),
MessageFlags: 1,
ChannelFlags: chanFlags,
TimeLockDelta: math.MaxUint16,
MinHTLC: math.MaxUint64,
MaxHTLC: math.MaxUint64,
FeeBaseMSat: math.MaxUint64,
FeeProportionalMillionths: math.MaxUint64,
ToNode: toNode,
}
for _, opt := range opts {
opt(p)
}
return p
}
// TestMigrationWithChannelDB tests the migration of the graph store from a // TestMigrationWithChannelDB tests the migration of the graph store from a
// bolt backed channel.db or a kvdb channel.sqlite to a SQL database. Note that // bolt backed channel.db or a kvdb channel.sqlite to a SQL database. Note that
// this test does not attempt to be a complete migration test for all graph // this test does not attempt to be a complete migration test for all graph
@@ -425,6 +655,8 @@ func TestSQLMigrationEdgeCases(t *testing.T) {
// with invalid TLV data, the migration will still succeed, but the // with invalid TLV data, the migration will still succeed, but the
// node will not end up in the SQL store. // node will not end up in the SQL store.
t.Run("node with bad tlv data", func(t *testing.T) { t.Run("node with bad tlv data", func(t *testing.T) {
t.Parallel()
// Make one valid node and one node with invalid TLV data. // Make one valid node and one node with invalid TLV data.
n1 := makeTestNode(t) n1 := makeTestNode(t)
n2 := makeTestNode(t, func(n *models.LightningNode) { n2 := makeTestNode(t, func(n *models.LightningNode) {
@@ -443,6 +675,197 @@ func TestSQLMigrationEdgeCases(t *testing.T) {
nodes: []*models.LightningNode{n1}, nodes: []*models.LightningNode{n1},
}) })
}) })
// Here, we test that in the case where the KV store contains a channel
// with invalid TLV data, the migration will still succeed, but the
// channel and its policies will not end up in the SQL store.
t.Run("channel with bad tlv data", func(t *testing.T) {
t.Parallel()
// Make two valid nodes to point to.
n1 := makeTestNode(t)
n2 := makeTestNode(t)
// Create two channels between these nodes, one valid one
// and one with invalid TLV data.
c1 := makeTestChannel(t, func(c *models.ChannelEdgeInfo) {
c.NodeKey1Bytes = n1.PubKeyBytes
c.NodeKey2Bytes = n2.PubKeyBytes
})
c2 := makeTestChannel(t, func(c *models.ChannelEdgeInfo) {
c.NodeKey1Bytes = n1.PubKeyBytes
c.NodeKey2Bytes = n2.PubKeyBytes
c.ExtraOpaqueData = invalidTLVData
})
// Create policies for both channels.
p1 := makeTestPolicy(c1.ChannelID, n2.PubKeyBytes, true)
p2 := makeTestPolicy(c2.ChannelID, n1.PubKeyBytes, false)
populateKV := func(t *testing.T, db *KVStore) {
// Insert both nodes into the KV store.
require.NoError(t, db.AddLightningNode(ctx, n1))
require.NoError(t, db.AddLightningNode(ctx, n2))
// Insert both channels into the KV store.
require.NoError(t, db.AddChannelEdge(ctx, c1))
require.NoError(t, db.AddChannelEdge(ctx, c2))
// Insert policies for both channels.
_, _, err := db.UpdateEdgePolicy(ctx, p1)
require.NoError(t, err)
_, _, err = db.UpdateEdgePolicy(ctx, p2)
require.NoError(t, err)
}
runTestMigration(t, populateKV, dbState{
// Both nodes will be present.
nodes: []*models.LightningNode{n1, n2},
// We only expect the first channel and its policy to
// be present in the SQL db.
chans: chanSet{{
edgeInfo: c1,
policy1: p1,
}},
})
})
// Here, we test that in the case where the KV store contains a
// channel policy with invalid TLV data, the migration will still
// succeed, but the channel policy will not end up in the SQL store.
t.Run("channel policy with bad tlv data", func(t *testing.T) {
t.Parallel()
// Make two valid nodes to point to.
n1 := makeTestNode(t)
n2 := makeTestNode(t)
// Create one valid channels between these nodes.
c := makeTestChannel(t, func(c *models.ChannelEdgeInfo) {
c.NodeKey1Bytes = n1.PubKeyBytes
c.NodeKey2Bytes = n2.PubKeyBytes
})
// Now, create two policies for this channel, one valid one
// and one with invalid TLV data.
p1 := makeTestPolicy(c.ChannelID, n2.PubKeyBytes, true)
p2 := makeTestPolicy(
c.ChannelID, n1.PubKeyBytes, false,
func(p *models.ChannelEdgePolicy) {
p.ExtraOpaqueData = invalidTLVData
},
)
populateKV := func(t *testing.T, db *KVStore) {
// Insert both nodes into the KV store.
require.NoError(t, db.AddLightningNode(ctx, n1))
require.NoError(t, db.AddLightningNode(ctx, n2))
// Insert the channel into the KV store.
require.NoError(t, db.AddChannelEdge(ctx, c))
// Insert policies for the channel.
_, _, err := db.UpdateEdgePolicy(ctx, p1)
require.NoError(t, err)
// We need to write this invalid one with the
// updateEdgePolicy helper function in order to bypass
// the newly added TLV validation in the
// UpdateEdgePolicy method of the KVStore.
err = db.db.Update(func(tx kvdb.RwTx) error {
_, _, _, err := updateEdgePolicy(tx, p2)
return err
}, func() {})
require.NoError(t, err)
}
runTestMigration(t, populateKV, dbState{
// Both nodes will be present.
nodes: []*models.LightningNode{n1, n2},
// The channel will be present, but only the
// valid policy will be included in the SQL db.
chans: chanSet{{
edgeInfo: c,
policy1: p1,
}},
})
})
// Here, we test that in the case where the KV store contains a
// channel policy that has a bit indicating that it contains a max HTLC
// field, but the field is missing. The migration will still succeed,
// but the policy will not end up in the SQL store.
t.Run("channel policy with missing max htlc", func(t *testing.T) {
t.Parallel()
// Make two valid nodes to point to.
n1 := makeTestNode(t)
n2 := makeTestNode(t)
// Create one valid channels between these nodes.
c := makeTestChannel(t, func(c *models.ChannelEdgeInfo) {
c.NodeKey1Bytes = n1.PubKeyBytes
c.NodeKey2Bytes = n2.PubKeyBytes
})
// Now, create two policies for this channel, one valid one
// and one with an invalid max htlc field.
p1 := makeTestPolicy(c.ChannelID, n2.PubKeyBytes, true)
p2 := makeTestPolicy(c.ChannelID, n1.PubKeyBytes, false)
// We'll remove the no max_htlc field from the first edge
// policy, and all other opaque data, and serialize it.
p2.MessageFlags = 0
p2.ExtraOpaqueData = nil
var b bytes.Buffer
require.NoError(t, serializeChanEdgePolicy(
&b, p2, n1.PubKeyBytes[:],
))
// Set the max_htlc field. The extra bytes added to the
// serialization will be the opaque data containing the
// serialized field.
p2.MessageFlags = lnwire.ChanUpdateRequiredMaxHtlc
p2.MaxHTLC = math.MaxUint64
var b2 bytes.Buffer
require.NoError(t, serializeChanEdgePolicy(
&b2, p2, n1.PubKeyBytes[:],
))
withMaxHtlc := b2.Bytes()
// Remove the opaque data from the serialization.
stripped := withMaxHtlc[:len(b.Bytes())]
populateKV := func(t *testing.T, db *KVStore) {
// Insert both nodes into the KV store.
require.NoError(t, db.AddLightningNode(ctx, n1))
require.NoError(t, db.AddLightningNode(ctx, n2))
// Insert the channel into the KV store.
require.NoError(t, db.AddChannelEdge(ctx, c))
// Insert policies for the channel.
_, _, err := db.UpdateEdgePolicy(ctx, p1)
require.NoError(t, err)
putSerializedPolicy(
t, db.db, n2.PubKeyBytes[:], c.ChannelID,
stripped,
)
}
runTestMigration(t, populateKV, dbState{
// Both nodes will be present.
nodes: []*models.LightningNode{n1, n2},
// The channel will be present, but only the
// valid policy will be included in the SQL db.
chans: chanSet{{
edgeInfo: c,
policy1: p1,
}},
})
})
} }
// runTestMigration is a helper function that sets up the KVStore and SQLStore, // runTestMigration is a helper function that sets up the KVStore and SQLStore,
@@ -475,6 +898,7 @@ func runTestMigration(t *testing.T, populateKV func(t *testing.T, db *KVStore),
// dbState describes the expected state of the SQLStore after a migration. // dbState describes the expected state of the SQLStore after a migration.
type dbState struct { type dbState struct {
nodes []*models.LightningNode nodes []*models.LightningNode
chans chanSet
} }
// assertResultState asserts that the SQLStore contains the expected // assertResultState asserts that the SQLStore contains the expected
@@ -482,4 +906,7 @@ type dbState struct {
func assertResultState(t *testing.T, sql *SQLStore, expState dbState) { func assertResultState(t *testing.T, sql *SQLStore, expState dbState) {
// Assert that the sql store contains the expected nodes. // Assert that the sql store contains the expected nodes.
require.ElementsMatch(t, expState.nodes, fetchAllNodes(t, sql)) require.ElementsMatch(t, expState.nodes, fetchAllNodes(t, sql))
require.ElementsMatch(
t, expState.chans, fetchAllChannelsAndPolicies(t, sql),
)
} }