From e875183c4f01a304e59fd5d7dbcad0c01b4370af Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Wed, 11 Jun 2025 17:50:26 +0200 Subject: [PATCH] sqldb+graph/db: impl DisconnectBlockAtHeight Which lets us run `TestDisconnectBlockAtHeight` and `TestStressTestChannelGraphAPI` against our SQL backends. --- docs/release-notes/release-notes-0.20.0.md | 1 + graph/db/graph_test.go | 8 +- graph/db/sql_store.go | 94 ++++++++++++++++++++++ sqldb/sqlc/graph.sql.go | 77 ++++++++++++++++++ sqldb/sqlc/querier.go | 2 + sqldb/sqlc/queries/graph.sql | 15 ++++ 6 files changed, 195 insertions(+), 2 deletions(-) diff --git a/docs/release-notes/release-notes-0.20.0.md b/docs/release-notes/release-notes-0.20.0.md index 040290f33..19cff1f38 100644 --- a/docs/release-notes/release-notes-0.20.0.md +++ b/docs/release-notes/release-notes-0.20.0.md @@ -88,6 +88,7 @@ circuit. The indices are only available for forwarding events saved after v0.20. * [6](https://github.com/lightningnetwork/lnd/pull/9936) * [7](https://github.com/lightningnetwork/lnd/pull/9937) * [8](https://github.com/lightningnetwork/lnd/pull/9938) + * [9](https://github.com/lightningnetwork/lnd/pull/9939) ## RPC Updates diff --git a/graph/db/graph_test.go b/graph/db/graph_test.go index fb2d2b72e..419c96bc3 100644 --- a/graph/db/graph_test.go +++ b/graph/db/graph_test.go @@ -513,7 +513,7 @@ func TestDisconnectBlockAtHeight(t *testing.T) { t.Parallel() ctx := context.Background() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) sourceNode := createTestVertex(t) if err := graph.SetSourceNode(ctx, sourceNode); err != nil { @@ -2440,7 +2440,7 @@ func TestStressTestChannelGraphAPI(t *testing.T) { t.Parallel() ctx := context.Background() - graph := MakeTestGraph(t) + graph := MakeTestGraphNew(t) node1 := createTestVertex(t) require.NoError(t, graph.AddLightningNode(ctx, node1)) @@ -2448,6 +2448,10 @@ func TestStressTestChannelGraphAPI(t *testing.T) { node2 := createTestVertex(t) require.NoError(t, graph.AddLightningNode(ctx, node2)) + // We need to update the node's timestamp since this call to + // SetSourceNode will trigger an upsert which will only be allowed if + // the newest LastUpdate time is greater than the current one. + node1.LastUpdate = node1.LastUpdate.Add(time.Second) require.NoError(t, graph.SetSourceNode(ctx, node1)) type chanInfo struct { diff --git a/graph/db/sql_store.go b/graph/db/sql_store.go index e03ed273f..a6c67948c 100644 --- a/graph/db/sql_store.go +++ b/graph/db/sql_store.go @@ -19,6 +19,7 @@ import ( "github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/wire" + "github.com/lightningnetwork/lnd/aliasmgr" "github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/graph/db/models" @@ -92,6 +93,7 @@ type SQLQueries interface { CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error) + GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) @@ -133,6 +135,7 @@ type SQLQueries interface { */ GetPruneTip(ctx context.Context) (sqlc.PruneLog, error) UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error + DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error } // BatchedSQLQueries is a version of SQLQueries that's capable of batched @@ -2481,6 +2484,97 @@ func (s *SQLStore) pruneGraphNodes(ctx context.Context, return prunedNodes, nil } +// DisconnectBlockAtHeight is used to indicate that the block specified +// by the passed height has been disconnected from the main chain. This +// will "rewind" the graph back to the height below, deleting channels +// that are no longer confirmed from the graph. The prune log will be +// set to the last prune height valid for the remaining chain. +// Channels that were removed from the graph resulting from the +// disconnected block are returned. +// +// NOTE: part of the V1Store interface. +func (s *SQLStore) DisconnectBlockAtHeight(height uint32) ( + []*models.ChannelEdgeInfo, error) { + + ctx := context.TODO() + + var ( + // Every channel having a ShortChannelID starting at 'height' + // will no longer be confirmed. + startShortChanID = lnwire.ShortChannelID{ + BlockHeight: height, + } + + // Delete everything after this height from the db up until the + // SCID alias range. + endShortChanID = aliasmgr.StartingAlias + + removedChans []*models.ChannelEdgeInfo + ) + + var chanIDStart [8]byte + byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64()) + var chanIDEnd [8]byte + byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64()) + + err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error { + rows, err := db.GetChannelsBySCIDRange( + ctx, sqlc.GetChannelsBySCIDRangeParams{ + StartScid: chanIDStart[:], + EndScid: chanIDEnd[:], + }, + ) + if err != nil { + return fmt.Errorf("unable to fetch channels: %w", err) + } + + for _, row := range rows { + node1, node2, err := buildNodeVertices( + row.Node1PubKey, row.Node2PubKey, + ) + if err != nil { + return err + } + + channel, err := getAndBuildEdgeInfo( + ctx, db, s.cfg.ChainHash, row.Channel.ID, + row.Channel, node1, node2, + ) + if err != nil { + return err + } + + err = db.DeleteChannel(ctx, row.Channel.ID) + if err != nil { + return fmt.Errorf("unable to delete "+ + "channel: %w", err) + } + + removedChans = append(removedChans, channel) + } + + return db.DeletePruneLogEntriesInRange( + ctx, sqlc.DeletePruneLogEntriesInRangeParams{ + StartHeight: int64(height), + EndHeight: int64(endShortChanID.BlockHeight), + }, + ) + }, func() { + removedChans = nil + }) + if err != nil { + return nil, fmt.Errorf("unable to disconnect block at "+ + "height: %w", err) + } + + for _, channel := range removedChans { + s.rejectCache.remove(channel.ChannelID) + s.chanCache.remove(channel.ChannelID) + } + + return removedChans, nil +} + // forEachNodeDirectedChannel iterates through all channels of a given // node, executing the passed callback on the directed edge representing the // channel and its incoming policy. If the node is not found, no error is diff --git a/sqldb/sqlc/graph.sql.go b/sqldb/sqlc/graph.sql.go index d774a12f7..c74b4c050 100644 --- a/sqldb/sqlc/graph.sql.go +++ b/sqldb/sqlc/graph.sql.go @@ -200,6 +200,22 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa return err } +const deletePruneLogEntriesInRange = `-- name: DeletePruneLogEntriesInRange :exec +DELETE FROM prune_log +WHERE block_height >= $1 + AND block_height <= $2 +` + +type DeletePruneLogEntriesInRangeParams struct { + StartHeight int64 + EndHeight int64 +} + +func (q *Queries) DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error { + _, err := q.db.ExecContext(ctx, deletePruneLogEntriesInRange, arg.StartHeight, arg.EndHeight) + return err +} + const deleteZombieChannel = `-- name: DeleteZombieChannel :execresult DELETE FROM zombie_channels WHERE scid = $1 @@ -943,6 +959,67 @@ func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg Ge return items, nil } +const getChannelsBySCIDRange = `-- name: GetChannelsBySCIDRange :many +SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature, + n1.pub_key AS node1_pub_key, + n2.pub_key AS node2_pub_key +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id +WHERE scid >= $1 + AND scid < $2 +` + +type GetChannelsBySCIDRangeParams struct { + StartScid []byte + EndScid []byte +} + +type GetChannelsBySCIDRangeRow struct { + Channel Channel + Node1PubKey []byte + Node2PubKey []byte +} + +func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) { + rows, err := q.db.QueryContext(ctx, getChannelsBySCIDRange, arg.StartScid, arg.EndScid) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChannelsBySCIDRangeRow + for rows.Next() { + var i GetChannelsBySCIDRangeRow + if err := rows.Scan( + &i.Channel.ID, + &i.Channel.Version, + &i.Channel.Scid, + &i.Channel.NodeID1, + &i.Channel.NodeID2, + &i.Channel.Outpoint, + &i.Channel.Capacity, + &i.Channel.BitcoinKey1, + &i.Channel.BitcoinKey2, + &i.Channel.Node1Signature, + &i.Channel.Node2Signature, + &i.Channel.Bitcoin1Signature, + &i.Channel.Bitcoin2Signature, + &i.Node1PubKey, + &i.Node2PubKey, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many SELECT node_id, type, value FROM node_extra_types diff --git a/sqldb/sqlc/querier.go b/sqldb/sqlc/querier.go index fc9b2128f..f0e1ad3e7 100644 --- a/sqldb/sqlc/querier.go +++ b/sqldb/sqlc/querier.go @@ -25,6 +25,7 @@ type Querier interface { DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error + DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) @@ -40,6 +41,7 @@ type Querier interface { GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) + GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) GetDatabaseVersion(ctx context.Context) (int32, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) // This method may return more than one invoice if filter using multiple fields diff --git a/sqldb/sqlc/queries/graph.sql b/sqldb/sqlc/queries/graph.sql index 412c4fd15..67b5dae6b 100644 --- a/sqldb/sqlc/queries/graph.sql +++ b/sqldb/sqlc/queries/graph.sql @@ -208,6 +208,16 @@ INSERT INTO channels ( ) RETURNING id; +-- name: GetChannelsBySCIDRange :many +SELECT sqlc.embed(c), + n1.pub_key AS node1_pub_key, + n2.pub_key AS node2_pub_key +FROM channels c + JOIN nodes n1 ON c.node_id_1 = n1.id + JOIN nodes n2 ON c.node_id_2 = n2.id +WHERE scid >= @start_scid + AND scid < @end_scid; + -- name: GetChannelBySCID :one SELECT * FROM channels WHERE scid = $1 AND version = $2; @@ -685,3 +695,8 @@ SELECT block_height, block_hash FROM prune_log ORDER BY block_height DESC LIMIT 1; + +-- name: DeletePruneLogEntriesInRange :exec +DELETE FROM prune_log +WHERE block_height >= @start_height + AND block_height <= @end_height;