diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index f0e753abc..814380a09 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -97,6 +97,7 @@ circuit. The indices are only available for forwarding events saved after v0.20. * [11](https://github.com/lightningnetwork/lnd/pull/9972) * Add graph SQL migration logic: * [1](https://github.com/lightningnetwork/lnd/pull/10036) + * [2](https://github.com/lightningnetwork/lnd/pull/10050) ## RPC Updates * Previously the `RoutingPolicy` would return the inbound fee record in its diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index fefc125c4..7df3eb4f7 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -3734,12 +3734,12 @@ func TestDisabledChannelIDs(t *testing.T) { } } -// TestEdgePolicyMissingMaxHtcl tests that if we find a ChannelEdgePolicy in +// TestEdgePolicyMissingMaxHTLC tests that if we find a ChannelEdgePolicy in // the DB that indicates that it should support the htlc_maximum_value_msat // field, but it is not part of the opaque data, then we'll handle it as it is // unknown. It also checks that we are correctly able to overwrite it when we // receive the proper update. -func TestEdgePolicyMissingMaxHtcl(t *testing.T) { +func TestEdgePolicyMissingMaxHTLC(t *testing.T) { t.Parallel() ctx := context.Background() @@ -3797,45 +3797,10 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { require.ErrorIs(t, err, ErrEdgePolicyOptionalFieldNotFound) // Put the stripped bytes in the DB. - err = kvdb.Update(boltStore.db, func(tx kvdb.RwTx) error { - edges := tx.ReadWriteBucket(edgeBucket) - if edges == nil { - return ErrEdgeNotFound - } - - edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket) - if edgeIndex == nil { - return ErrEdgeNotFound - } - - var edgeKey [33 + 8]byte - copy(edgeKey[:], from) - byteOrder.PutUint64(edgeKey[33:], edge1.ChannelID) - - var scratch [8]byte - var indexKey [8 + 8]byte - copy(indexKey[:], scratch[:]) - byteOrder.PutUint64(indexKey[8:], edge1.ChannelID) - - updateIndex, err := edges.CreateBucketIfNotExists( - edgeUpdateIndexBucket, - ) - if err != nil { - return err - } - - if err := updateIndex.Put(indexKey[:], nil); err != nil { - return err - } - - return edges.Put(edgeKey[:], stripped) - }, func() {}) - require.NoError(t, err, "error writing db") + putSerializedPolicy(t, boltStore.db, from, chanID, stripped) // And add the second, unmodified edge. - if err := graph.UpdateEdgePolicy(ctx, edge2); err != nil { - t.Fatalf("unable to update edge: %v", err) - } + require.NoError(t, graph.UpdateEdgePolicy(ctx, edge2)) // Attempt to fetch the edge and policies from the DB. Since the policy // we added is invalid according to the new format, it should be as we @@ -3870,6 +3835,38 @@ func TestEdgePolicyMissingMaxHtcl(t *testing.T) { assertEdgeInfoEqual(t, dbEdgeInfo, edgeInfo) } +// putSerializedPolicy is a helper function that writes a serialized +// ChannelEdgePolicy to the edge bucket in the database. +func putSerializedPolicy(t *testing.T, db kvdb.Backend, from []byte, + chanID uint64, b []byte) { + + err := kvdb.Update(db, func(tx kvdb.RwTx) error { + edges := tx.ReadWriteBucket(edgeBucket) + require.NotNil(t, edges) + + edgeIndex := edges.NestedReadWriteBucket(edgeIndexBucket) + require.NotNil(t, edgeIndex) + + var edgeKey [33 + 8]byte + copy(edgeKey[:], from) + byteOrder.PutUint64(edgeKey[33:], chanID) + + var scratch [8]byte + var indexKey [8 + 8]byte + copy(indexKey[:], scratch[:]) + byteOrder.PutUint64(indexKey[8:], chanID) + + updateIndex, err := edges.CreateBucketIfNotExists( + edgeUpdateIndexBucket, + ) + require.NoError(t, err) + require.NoError(t, updateIndex.Put(indexKey[:], nil)) + + return edges.Put(edgeKey[:], b) + }, func() {}) + require.NoError(t, err, "error writing db") +} + // assertNumZombies queries the provided ChannelGraph for NumZombies, and // asserts that the returned number is equal to expZombies. func assertNumZombies(t *testing.T, graph *ChannelGraph, expZombies uint64) { diff --git a/graph/db/kv_store.go b/graph/db/kv_store.go index 42b9f9245..a0c1be0eb 100644 --- a/graph/db/kv_store.go +++ b/graph/db/kv_store.go @@ -2844,7 +2844,17 @@ func (c *KVStore) UpdateEdgePolicy(ctx context.Context, edgeNotFound = false }, Do: func(tx kvdb.RwTx) error { - var err error + // Validate that the ExtraOpaqueData is in fact a valid + // TLV stream. This is done here instead of within + // updateEdgePolicy so that updateEdgePolicy can be used + // by unit tests to recreate the case where we already + // have nodes persisted with invalid TLV data. + err := edge.ExtraOpaqueData.ValidateTLV() + if err != nil { + return fmt.Errorf("%w: %w", + ErrParsingExtraTLVBytes, err) + } + from, to, isUpdate1, err = updateEdgePolicy(tx, edge) if err != nil { log.Errorf("UpdateEdgePolicy faild: %v", err) @@ -4487,8 +4497,9 @@ func putChanEdgePolicy(edges kvdb.RwBucket, edge *models.ChannelEdgePolicy, // // TODO(halseth): get rid of these invalid policies in a // migration. - // TODO(elle): complete the above TODO in migration from kvdb - // to SQL. + // + // NOTE: the above TODO was completed in the SQL migration and + // so such edge cases no longer need to be handled there. oldEdgePolicy, err := deserializeChanEdgePolicy( bytes.NewReader(edgeBytes), ) @@ -4703,12 +4714,6 @@ func serializeChanEdgePolicy(w io.Writer, edge *models.ChannelEdgePolicy, } } - // Validate that the ExtraOpaqueData is in fact a valid TLV stream. - err = edge.ExtraOpaqueData.ValidateTLV() - if err != nil { - return fmt.Errorf("%w: %w", ErrParsingExtraTLVBytes, err) - } - if len(edge.ExtraOpaqueData) > MaxAllowedExtraOpaqueBytes { return ErrTooManyExtraOpaqueBytes(len(edge.ExtraOpaqueData)) } diff --git a/graph/db/sql_migration.go b/graph/db/sql_migration.go index a70dd62d0..d79357de3 100644 --- a/graph/db/sql_migration.go +++ b/graph/db/sql_migration.go @@ -23,7 +23,7 @@ import ( // only for now and will be called from the main lnd binary once the // migration is fully implemented and tested. func MigrateGraphToSQL(ctx context.Context, kvBackend kvdb.Backend, - sqlDB SQLQueries, _ chainhash.Hash) error { + sqlDB SQLQueries, chain chainhash.Hash) error { log.Infof("Starting migration of the graph store from KV to SQL") t0 := time.Now() @@ -48,6 +48,13 @@ func MigrateGraphToSQL(ctx context.Context, kvBackend kvdb.Backend, return fmt.Errorf("could not migrate source node: %w", err) } + // 3) Migrate all the channels and channel policies. + err = migrateChannelsAndPolicies(ctx, kvBackend, sqlDB, chain) + if err != nil { + return fmt.Errorf("could not migrate channels and policies: %w", + err) + } + log.Infof("Finished migration of the graph store from KV to SQL in %v", time.Since(t0)) @@ -261,3 +268,277 @@ func migrateSourceNode(ctx context.Context, kvdb kvdb.Backend, return nil } + +// migrateChannelsAndPolicies migrates all channels and their policies +// from the KV backend to the SQL database. +func migrateChannelsAndPolicies(ctx context.Context, kvBackend kvdb.Backend, + sqlDB SQLQueries, chain chainhash.Hash) error { + + var ( + channelCount uint64 + skippedChanCount uint64 + policyCount uint64 + skippedPolicyCount uint64 + ) + migChanPolicy := func(policy *models.ChannelEdgePolicy) error { + // If the policy is nil, we can skip it. + if policy == nil { + return nil + } + + // Unlike the special case of invalid TLV bytes for node and + // channel announcements, we don't need to handle the case for + // channel policies here because it is already handled in the + // `forEachChannel` function. If the policy has invalid TLV + // bytes, then `nil` will be passed to this function. + + policyCount++ + + _, _, _, err := updateChanEdgePolicy(ctx, sqlDB, policy) + if err != nil { + return fmt.Errorf("could not migrate channel "+ + "policy %d: %w", policy.ChannelID, err) + } + + return nil + } + + // Iterate over each channel in the KV store and migrate it and its + // policies to the SQL database. + err := forEachChannel(kvBackend, func(channel *models.ChannelEdgeInfo, + policy1 *models.ChannelEdgePolicy, + policy2 *models.ChannelEdgePolicy) error { + + scid := channel.ChannelID + + // Here, we do a sanity check to ensure that the chain hash of + // the channel returned by the KV store matches the expected + // chain hash. This is important since in the SQL store, we will + // no longer explicitly store the chain hash in the channel + // info, but rather rely on the chain hash LND is running with. + // So this is our way of ensuring that LND is running on the + // correct network at migration time. + if channel.ChainHash != chain { + return fmt.Errorf("channel %d has chain hash %s, "+ + "expected %s", scid, channel.ChainHash, chain) + } + + // Sanity check to ensure that the channel has valid extra + // opaque data. If it does not, we'll skip it. We need to do + // this because previously we would just persist any TLV bytes + // that we received without validating them. Now, however, we + // normalise the storage of extra opaque data, so we need to + // ensure that the data is valid. We don't want to abort the + // migration if we encounter a channel with invalid extra opaque + // data, so we'll just skip it and log a warning. + _, err := marshalExtraOpaqueData(channel.ExtraOpaqueData) + if errors.Is(err, ErrParsingExtraTLVBytes) { + log.Warnf("Skipping channel %d with invalid "+ + "extra opaque data: %v", scid, + channel.ExtraOpaqueData) + + skippedChanCount++ + + // If we skip a channel, we also skip its policies. + if policy1 != nil { + skippedPolicyCount++ + } + if policy2 != nil { + skippedPolicyCount++ + } + + return nil + } else if err != nil { + return fmt.Errorf("unable to marshal extra opaque "+ + "data for channel %d (%v): %w", scid, + channel.ExtraOpaqueData, err) + } + + channelCount++ + err = migrateSingleChannel( + ctx, sqlDB, channel, policy1, policy2, migChanPolicy, + ) + if err != nil { + return fmt.Errorf("could not migrate channel %d: %w", + scid, err) + } + + return nil + }) + if err != nil { + return fmt.Errorf("could not migrate channels and policies: %w", + err) + } + + log.Infof("Migrated %d channels and %d policies from KV to SQL "+ + "(skipped %d channels and %d policies due to invalid TLV "+ + "streams)", channelCount, policyCount, skippedChanCount, + skippedPolicyCount) + + return nil +} + +func migrateSingleChannel(ctx context.Context, sqlDB SQLQueries, + channel *models.ChannelEdgeInfo, + policy1, policy2 *models.ChannelEdgePolicy, + migChanPolicy func(*models.ChannelEdgePolicy) error) error { + + scid := channel.ChannelID + + // First, migrate the channel info along with its policies. + dbChanInfo, err := insertChannel(ctx, sqlDB, channel) + if err != nil { + return fmt.Errorf("could not insert record for channel %d "+ + "in SQL store: %w", scid, err) + } + + // Now, migrate the two channel policies. + err = migChanPolicy(policy1) + if err != nil { + return fmt.Errorf("could not migrate policy1(%d): %w", scid, + err) + } + err = migChanPolicy(policy2) + if err != nil { + return fmt.Errorf("could not migrate policy2(%d): %w", scid, + err) + } + + // Now, fetch the channel and its policies from the SQL DB. + row, err := sqlDB.GetChannelBySCIDWithPolicies( + ctx, sqlc.GetChannelBySCIDWithPoliciesParams{ + Scid: channelIDToBytes(scid), + Version: int16(ProtocolV1), + }, + ) + if err != nil { + return fmt.Errorf("could not get channel by SCID(%d): %w", scid, + err) + } + + // Assert that the DB IDs for the channel and nodes are as expected + // given the inserted channel info. + err = sqldb.CompareRecords( + dbChanInfo.channelID, row.Channel.ID, "channel DB ID", + ) + if err != nil { + return err + } + err = sqldb.CompareRecords( + dbChanInfo.node1ID, row.Node.ID, "node1 DB ID", + ) + if err != nil { + return err + } + err = sqldb.CompareRecords( + dbChanInfo.node2ID, row.Node_2.ID, "node2 DB ID", + ) + if err != nil { + return err + } + + migChan, migPol1, migPol2, err := getAndBuildChanAndPolicies( + ctx, sqlDB, row, channel.ChainHash, + ) + if err != nil { + return fmt.Errorf("could not build migrated channel and "+ + "policies: %w", err) + } + + // Finally, compare the original channel info and + // policies with the migrated ones to ensure they match. + if len(channel.ExtraOpaqueData) == 0 { + channel.ExtraOpaqueData = nil + } + if len(migChan.ExtraOpaqueData) == 0 { + migChan.ExtraOpaqueData = nil + } + + err = sqldb.CompareRecords( + channel, migChan, fmt.Sprintf("channel %d", scid), + ) + if err != nil { + return err + } + + checkPolicy := func(expPolicy, + migPolicy *models.ChannelEdgePolicy) error { + + switch { + // Both policies are nil, nothing to compare. + case expPolicy == nil && migPolicy == nil: + return nil + + // One of the policies is nil, but the other is not. + case expPolicy == nil || migPolicy == nil: + return fmt.Errorf("expected both policies to be "+ + "non-nil. Got expPolicy: %v, "+ + "migPolicy: %v", expPolicy, migPolicy) + + // Both policies are non-nil, we can compare them. + default: + } + + if len(expPolicy.ExtraOpaqueData) == 0 { + expPolicy.ExtraOpaqueData = nil + } + if len(migPolicy.ExtraOpaqueData) == 0 { + migPolicy.ExtraOpaqueData = nil + } + + return sqldb.CompareRecords( + *expPolicy, *migPolicy, "channel policy", + ) + } + + err = checkPolicy(policy1, migPol1) + if err != nil { + return fmt.Errorf("policy1 mismatch for channel %d: %w", scid, + err) + } + + err = checkPolicy(policy2, migPol2) + if err != nil { + return fmt.Errorf("policy2 mismatch for channel %d: %w", scid, + err) + } + + return nil +} + +func getAndBuildChanAndPolicies(ctx context.Context, db SQLQueries, + row sqlc.GetChannelBySCIDWithPoliciesRow, + chain chainhash.Hash) (*models.ChannelEdgeInfo, + *models.ChannelEdgePolicy, *models.ChannelEdgePolicy, error) { + + node1, node2, err := buildNodeVertices( + row.Node.PubKey, row.Node_2.PubKey, + ) + if err != nil { + return nil, nil, nil, err + } + + edge, err := getAndBuildEdgeInfo( + ctx, db, chain, row.Channel.ID, row.Channel, node1, node2, + ) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to build channel "+ + "info: %w", err) + } + + dbPol1, dbPol2, err := extractChannelPolicies(row) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to extract channel "+ + "policies: %w", err) + } + + policy1, policy2, err := getAndBuildChanPolicies( + ctx, db, dbPol1, dbPol2, edge.ChannelID, node1, node2, + ) + if err != nil { + return nil, nil, nil, fmt.Errorf("unable to build channel "+ + "policies: %w", err) + } + + return edge, policy1, policy2, nil +} diff --git a/graph/db/sql_migration_test.go b/graph/db/sql_migration_test.go index 8e0df1c90..de07f81d0 100644 --- a/graph/db/sql_migration_test.go +++ b/graph/db/sql_migration_test.go @@ -4,10 +4,13 @@ package graphdb import ( "bytes" + "cmp" "context" "errors" "fmt" "image/color" + "math" + prand "math/rand" "net" "os" "path" @@ -18,6 +21,7 @@ import ( "github.com/btcsuite/btcd/btcec/v2" "github.com/btcsuite/btcd/chaincfg" + "github.com/btcsuite/btcd/wire" "github.com/btcsuite/btclog/v2" "github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/kvdb" @@ -36,6 +40,12 @@ var ( testSigBytes = testSig.Serialize() testExtraData = []byte{1, 1, 1, 2, 2, 2, 2} testEmptyFeatures = lnwire.EmptyFeatureVector() + testAuthProof = &models.ChannelAuthProof{ + NodeSig1Bytes: testSig.Serialize(), + NodeSig2Bytes: testSig.Serialize(), + BitcoinSig1Bytes: testSig.Serialize(), + BitcoinSig2Bytes: testSig.Serialize(), + } ) // TestMigrateGraphToSQL tests various deterministic cases that we want to test @@ -53,12 +63,24 @@ func TestMigrateGraphToSQL(t *testing.T) { switch obj := object.(type) { case *models.LightningNode: err = db.AddLightningNode(ctx, obj) + case *models.ChannelEdgeInfo: + err = db.AddChannelEdge(ctx, obj) + case *models.ChannelEdgePolicy: + _, _, err = db.UpdateEdgePolicy(ctx, obj) default: err = fmt.Errorf("unhandled object type: %T", obj) } require.NoError(t, err) } + var ( + chanID1 = prand.Uint64() + chanID2 = prand.Uint64() + + node1 = genPubKey(t) + node2 = genPubKey(t) + ) + tests := []struct { name string write func(t *testing.T, db *KVStore, object any) @@ -123,6 +145,78 @@ func TestMigrateGraphToSQL(t *testing.T) { srcNodeSet: true, }, }, + { + name: "channels and policies", + write: writeUpdate, + objects: []any{ + // A channel with unknown nodes. This will + // result in two shell nodes being created. + // - channel count += 1 + // - node count += 2 + makeTestChannel(t), + + // Insert some nodes. + // - node count += 1 + makeTestNode(t, func(n *models.LightningNode) { + n.PubKeyBytes = node1 + }), + // - node count += 1 + makeTestNode(t, func(n *models.LightningNode) { + n.PubKeyBytes = node2 + }), + + // A channel with known nodes. + // - channel count += 1 + makeTestChannel( + t, func(c *models.ChannelEdgeInfo) { + c.ChannelID = chanID1 + + c.NodeKey1Bytes = node1 + c.NodeKey2Bytes = node2 + }, + ), + + // Insert a channel with no auth proof, no + // extra opaque data, and empty features. + // Use known nodes. + // - channel count += 1 + makeTestChannel( + t, func(c *models.ChannelEdgeInfo) { + c.ChannelID = chanID2 + + c.NodeKey1Bytes = node1 + c.NodeKey2Bytes = node2 + + c.AuthProof = nil + c.ExtraOpaqueData = nil + c.Features = testEmptyFeatures + }, + ), + + // Now, insert a single update for the + // first channel. + // - channel policy count += 1 + makeTestPolicy(chanID1, node1, false), + + // Insert two updates for the second + // channel, one for each direction. + // - channel policy count += 1 + makeTestPolicy(chanID2, node1, false), + // This one also has no extra opaque data. + // - channel policy count += 1 + makeTestPolicy( + chanID2, node2, true, + func(p *models.ChannelEdgePolicy) { + p.ExtraOpaqueData = nil + }, + ), + }, + expGraphStats: graphStats{ + numNodes: 4, + numChannels: 3, + numPolicies: 3, + }, + }, } for _, test := range tests { @@ -155,8 +249,10 @@ func TestMigrateGraphToSQL(t *testing.T) { // graphStats holds expected statistics about the graph after migration. type graphStats struct { - numNodes int - srcNodeSet bool + numNodes int + srcNodeSet bool + numChannels int + numPolicies int } // assertInSync checks that the KVStore and SQLStore both contain the same @@ -174,6 +270,12 @@ func assertInSync(t *testing.T, kvDB *KVStore, sqlDB *SQLStore, sqlSourceNode := fetchSourceNode(t, sqlDB) require.Equal(t, stats.srcNodeSet, sqlSourceNode != nil) require.Equal(t, fetchSourceNode(t, kvDB), sqlSourceNode) + + // 3) Compare the channels and policies in the two stores. + sqlChannels := fetchAllChannelsAndPolicies(t, sqlDB) + require.Len(t, sqlChannels, stats.numChannels) + require.Equal(t, stats.numPolicies, sqlChannels.CountPolicies()) + require.Equal(t, fetchAllChannelsAndPolicies(t, kvDB), sqlChannels) } // fetchAllNodes retrieves all nodes from the given store and returns them @@ -215,6 +317,67 @@ func fetchSourceNode(t *testing.T, store V1Store) *models.LightningNode { return node } +// chanInfo holds information about a channel, including its edge info +// and the policies for both directions. +type chanInfo struct { + edgeInfo *models.ChannelEdgeInfo + policy1 *models.ChannelEdgePolicy + policy2 *models.ChannelEdgePolicy +} + +// chanSet is a slice of chanInfo +type chanSet []chanInfo + +// CountPolicies counts the total number of policies in the channel set. +func (c chanSet) CountPolicies() int { + var count int + for _, info := range c { + if info.policy1 != nil { + count++ + } + if info.policy2 != nil { + count++ + } + } + return count +} + +// fetchAllChannelsAndPolicies retrieves all channels and their policies +// from the given store and returns them sorted by their channel ID. +func fetchAllChannelsAndPolicies(t *testing.T, store V1Store) chanSet { + channels := make(chanSet, 0) + err := store.ForEachChannel(func(info *models.ChannelEdgeInfo, + p1 *models.ChannelEdgePolicy, + p2 *models.ChannelEdgePolicy) error { + + if len(info.ExtraOpaqueData) == 0 { + info.ExtraOpaqueData = nil + } + if p1 != nil && len(p1.ExtraOpaqueData) == 0 { + p1.ExtraOpaqueData = nil + } + if p2 != nil && len(p2.ExtraOpaqueData) == 0 { + p2.ExtraOpaqueData = nil + } + + channels = append(channels, chanInfo{ + edgeInfo: info, + policy1: p1, + policy2: p2, + }) + + return nil + }) + require.NoError(t, err) + + // Sort the channels by their channel ID to ensure a consistent order. + slices.SortFunc(channels, func(i, j chanInfo) int { + return cmp.Compare(i.edgeInfo.ChannelID, j.edgeInfo.ChannelID) + }) + + return channels +} + // setUpKVStore initializes a new KVStore for testing. func setUpKVStore(t *testing.T) *KVStore { kvDB, cleanup, err := kvdb.GetTestBackend(t.TempDir(), "graph") @@ -293,6 +456,73 @@ func makeTestShellNode(t *testing.T, return n } +// modify the attributes of a models.ChannelEdgeInfo created by makeTestChannel. +type testChanOpt func(info *models.ChannelEdgeInfo) + +// makeTestChannel creates a test models.ChannelEdgeInfo. The functional options +// can be used to modify the channel's attributes. +func makeTestChannel(t *testing.T, + opts ...testChanOpt) *models.ChannelEdgeInfo { + + c := &models.ChannelEdgeInfo{ + ChannelID: prand.Uint64(), + ChainHash: testChain, + NodeKey1Bytes: genPubKey(t), + NodeKey2Bytes: genPubKey(t), + BitcoinKey1Bytes: genPubKey(t), + BitcoinKey2Bytes: genPubKey(t), + Features: testFeatures, + AuthProof: testAuthProof, + ChannelPoint: wire.OutPoint{ + Hash: rev, + Index: prand.Uint32(), + }, + Capacity: 10000, + ExtraOpaqueData: testExtraData, + } + + for _, opt := range opts { + opt(c) + } + + return c +} + +// testPolicyOpt defines a functional option type that can be used to modify the +// attributes of a models.ChannelEdgePolicy created by makeTestPolicy. +type testPolicyOpt func(*models.ChannelEdgePolicy) + +// makeTestPolicy creates a test models.ChannelEdgePolicy. The functional +// options can be used to modify the policy's attributes. +func makeTestPolicy(chanID uint64, toNode route.Vertex, isNode1 bool, + opts ...testPolicyOpt) *models.ChannelEdgePolicy { + + chanFlags := lnwire.ChanUpdateChanFlags(1) + if isNode1 { + chanFlags = 0 + } + + p := &models.ChannelEdgePolicy{ + SigBytes: testSigBytes, + ChannelID: chanID, + LastUpdate: nextUpdateTime(), + MessageFlags: 1, + ChannelFlags: chanFlags, + TimeLockDelta: math.MaxUint16, + MinHTLC: math.MaxUint64, + MaxHTLC: math.MaxUint64, + FeeBaseMSat: math.MaxUint64, + FeeProportionalMillionths: math.MaxUint64, + ToNode: toNode, + } + + for _, opt := range opts { + opt(p) + } + + return p +} + // TestMigrationWithChannelDB tests the migration of the graph store from a // bolt backed channel.db or a kvdb channel.sqlite to a SQL database. Note that // this test does not attempt to be a complete migration test for all graph @@ -425,6 +655,8 @@ func TestSQLMigrationEdgeCases(t *testing.T) { // with invalid TLV data, the migration will still succeed, but the // node will not end up in the SQL store. t.Run("node with bad tlv data", func(t *testing.T) { + t.Parallel() + // Make one valid node and one node with invalid TLV data. n1 := makeTestNode(t) n2 := makeTestNode(t, func(n *models.LightningNode) { @@ -443,6 +675,197 @@ func TestSQLMigrationEdgeCases(t *testing.T) { nodes: []*models.LightningNode{n1}, }) }) + + // Here, we test that in the case where the KV store contains a channel + // with invalid TLV data, the migration will still succeed, but the + // channel and its policies will not end up in the SQL store. + t.Run("channel with bad tlv data", func(t *testing.T) { + t.Parallel() + + // Make two valid nodes to point to. + n1 := makeTestNode(t) + n2 := makeTestNode(t) + + // Create two channels between these nodes, one valid one + // and one with invalid TLV data. + c1 := makeTestChannel(t, func(c *models.ChannelEdgeInfo) { + c.NodeKey1Bytes = n1.PubKeyBytes + c.NodeKey2Bytes = n2.PubKeyBytes + }) + c2 := makeTestChannel(t, func(c *models.ChannelEdgeInfo) { + c.NodeKey1Bytes = n1.PubKeyBytes + c.NodeKey2Bytes = n2.PubKeyBytes + c.ExtraOpaqueData = invalidTLVData + }) + + // Create policies for both channels. + p1 := makeTestPolicy(c1.ChannelID, n2.PubKeyBytes, true) + p2 := makeTestPolicy(c2.ChannelID, n1.PubKeyBytes, false) + + populateKV := func(t *testing.T, db *KVStore) { + // Insert both nodes into the KV store. + require.NoError(t, db.AddLightningNode(ctx, n1)) + require.NoError(t, db.AddLightningNode(ctx, n2)) + + // Insert both channels into the KV store. + require.NoError(t, db.AddChannelEdge(ctx, c1)) + require.NoError(t, db.AddChannelEdge(ctx, c2)) + + // Insert policies for both channels. + _, _, err := db.UpdateEdgePolicy(ctx, p1) + require.NoError(t, err) + _, _, err = db.UpdateEdgePolicy(ctx, p2) + require.NoError(t, err) + } + + runTestMigration(t, populateKV, dbState{ + // Both nodes will be present. + nodes: []*models.LightningNode{n1, n2}, + // We only expect the first channel and its policy to + // be present in the SQL db. + chans: chanSet{{ + edgeInfo: c1, + policy1: p1, + }}, + }) + }) + + // Here, we test that in the case where the KV store contains a + // channel policy with invalid TLV data, the migration will still + // succeed, but the channel policy will not end up in the SQL store. + t.Run("channel policy with bad tlv data", func(t *testing.T) { + t.Parallel() + + // Make two valid nodes to point to. + n1 := makeTestNode(t) + n2 := makeTestNode(t) + + // Create one valid channels between these nodes. + c := makeTestChannel(t, func(c *models.ChannelEdgeInfo) { + c.NodeKey1Bytes = n1.PubKeyBytes + c.NodeKey2Bytes = n2.PubKeyBytes + }) + + // Now, create two policies for this channel, one valid one + // and one with invalid TLV data. + p1 := makeTestPolicy(c.ChannelID, n2.PubKeyBytes, true) + p2 := makeTestPolicy( + c.ChannelID, n1.PubKeyBytes, false, + func(p *models.ChannelEdgePolicy) { + p.ExtraOpaqueData = invalidTLVData + }, + ) + + populateKV := func(t *testing.T, db *KVStore) { + // Insert both nodes into the KV store. + require.NoError(t, db.AddLightningNode(ctx, n1)) + require.NoError(t, db.AddLightningNode(ctx, n2)) + + // Insert the channel into the KV store. + require.NoError(t, db.AddChannelEdge(ctx, c)) + + // Insert policies for the channel. + _, _, err := db.UpdateEdgePolicy(ctx, p1) + require.NoError(t, err) + + // We need to write this invalid one with the + // updateEdgePolicy helper function in order to bypass + // the newly added TLV validation in the + // UpdateEdgePolicy method of the KVStore. + err = db.db.Update(func(tx kvdb.RwTx) error { + _, _, _, err := updateEdgePolicy(tx, p2) + return err + }, func() {}) + require.NoError(t, err) + } + + runTestMigration(t, populateKV, dbState{ + // Both nodes will be present. + nodes: []*models.LightningNode{n1, n2}, + // The channel will be present, but only the + // valid policy will be included in the SQL db. + chans: chanSet{{ + edgeInfo: c, + policy1: p1, + }}, + }) + }) + + // Here, we test that in the case where the KV store contains a + // channel policy that has a bit indicating that it contains a max HTLC + // field, but the field is missing. The migration will still succeed, + // but the policy will not end up in the SQL store. + t.Run("channel policy with missing max htlc", func(t *testing.T) { + t.Parallel() + + // Make two valid nodes to point to. + n1 := makeTestNode(t) + n2 := makeTestNode(t) + + // Create one valid channels between these nodes. + c := makeTestChannel(t, func(c *models.ChannelEdgeInfo) { + c.NodeKey1Bytes = n1.PubKeyBytes + c.NodeKey2Bytes = n2.PubKeyBytes + }) + + // Now, create two policies for this channel, one valid one + // and one with an invalid max htlc field. + p1 := makeTestPolicy(c.ChannelID, n2.PubKeyBytes, true) + p2 := makeTestPolicy(c.ChannelID, n1.PubKeyBytes, false) + + // We'll remove the no max_htlc field from the first edge + // policy, and all other opaque data, and serialize it. + p2.MessageFlags = 0 + p2.ExtraOpaqueData = nil + + var b bytes.Buffer + require.NoError(t, serializeChanEdgePolicy( + &b, p2, n1.PubKeyBytes[:], + )) + + // Set the max_htlc field. The extra bytes added to the + // serialization will be the opaque data containing the + // serialized field. + p2.MessageFlags = lnwire.ChanUpdateRequiredMaxHtlc + p2.MaxHTLC = math.MaxUint64 + var b2 bytes.Buffer + require.NoError(t, serializeChanEdgePolicy( + &b2, p2, n1.PubKeyBytes[:], + )) + withMaxHtlc := b2.Bytes() + + // Remove the opaque data from the serialization. + stripped := withMaxHtlc[:len(b.Bytes())] + + populateKV := func(t *testing.T, db *KVStore) { + // Insert both nodes into the KV store. + require.NoError(t, db.AddLightningNode(ctx, n1)) + require.NoError(t, db.AddLightningNode(ctx, n2)) + + // Insert the channel into the KV store. + require.NoError(t, db.AddChannelEdge(ctx, c)) + + // Insert policies for the channel. + _, _, err := db.UpdateEdgePolicy(ctx, p1) + require.NoError(t, err) + + putSerializedPolicy( + t, db.db, n2.PubKeyBytes[:], c.ChannelID, + stripped, + ) + } + + runTestMigration(t, populateKV, dbState{ + // Both nodes will be present. + nodes: []*models.LightningNode{n1, n2}, + // The channel will be present, but only the + // valid policy will be included in the SQL db. + chans: chanSet{{ + edgeInfo: c, + policy1: p1, + }}, + }) + }) } // runTestMigration is a helper function that sets up the KVStore and SQLStore, @@ -475,6 +898,7 @@ func runTestMigration(t *testing.T, populateKV func(t *testing.T, db *KVStore), // dbState describes the expected state of the SQLStore after a migration. type dbState struct { nodes []*models.LightningNode + chans chanSet } // assertResultState asserts that the SQLStore contains the expected @@ -482,4 +906,7 @@ type dbState struct { func assertResultState(t *testing.T, sql *SQLStore, expState dbState) { // Assert that the sql store contains the expected nodes. require.ElementsMatch(t, expState.nodes, fetchAllNodes(t, sql)) + require.ElementsMatch( + t, expState.chans, fetchAllChannelsAndPolicies(t, sql), + ) }