sqldb+graph/db: impl DisconnectBlockAtHeight

Which lets us run `TestDisconnectBlockAtHeight` and
`TestStressTestChannelGraphAPI` against our SQL backends.
This commit is contained in:
Elle Mouton
2025-06-11 17:50:26 +02:00
parent 9dd0361ed0
commit e875183c4f
6 changed files with 195 additions and 2 deletions

View File

@@ -88,6 +88,7 @@ circuit. The indices are only available for forwarding events saved after v0.20.
* [6](https://github.com/lightningnetwork/lnd/pull/9936) * [6](https://github.com/lightningnetwork/lnd/pull/9936)
* [7](https://github.com/lightningnetwork/lnd/pull/9937) * [7](https://github.com/lightningnetwork/lnd/pull/9937)
* [8](https://github.com/lightningnetwork/lnd/pull/9938) * [8](https://github.com/lightningnetwork/lnd/pull/9938)
* [9](https://github.com/lightningnetwork/lnd/pull/9939)
## RPC Updates ## RPC Updates

View File

@@ -513,7 +513,7 @@ func TestDisconnectBlockAtHeight(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
graph := MakeTestGraph(t) graph := MakeTestGraphNew(t)
sourceNode := createTestVertex(t) sourceNode := createTestVertex(t)
if err := graph.SetSourceNode(ctx, sourceNode); err != nil { if err := graph.SetSourceNode(ctx, sourceNode); err != nil {
@@ -2440,7 +2440,7 @@ func TestStressTestChannelGraphAPI(t *testing.T) {
t.Parallel() t.Parallel()
ctx := context.Background() ctx := context.Background()
graph := MakeTestGraph(t) graph := MakeTestGraphNew(t)
node1 := createTestVertex(t) node1 := createTestVertex(t)
require.NoError(t, graph.AddLightningNode(ctx, node1)) require.NoError(t, graph.AddLightningNode(ctx, node1))
@@ -2448,6 +2448,10 @@ func TestStressTestChannelGraphAPI(t *testing.T) {
node2 := createTestVertex(t) node2 := createTestVertex(t)
require.NoError(t, graph.AddLightningNode(ctx, node2)) require.NoError(t, graph.AddLightningNode(ctx, node2))
// We need to update the node's timestamp since this call to
// SetSourceNode will trigger an upsert which will only be allowed if
// the newest LastUpdate time is greater than the current one.
node1.LastUpdate = node1.LastUpdate.Add(time.Second)
require.NoError(t, graph.SetSourceNode(ctx, node1)) require.NoError(t, graph.SetSourceNode(ctx, node1))
type chanInfo struct { type chanInfo struct {

View File

@@ -19,6 +19,7 @@ import (
"github.com/btcsuite/btcd/btcutil" "github.com/btcsuite/btcd/btcutil"
"github.com/btcsuite/btcd/chaincfg/chainhash" "github.com/btcsuite/btcd/chaincfg/chainhash"
"github.com/btcsuite/btcd/wire" "github.com/btcsuite/btcd/wire"
"github.com/lightningnetwork/lnd/aliasmgr"
"github.com/lightningnetwork/lnd/batch" "github.com/lightningnetwork/lnd/batch"
"github.com/lightningnetwork/lnd/fn/v2" "github.com/lightningnetwork/lnd/fn/v2"
"github.com/lightningnetwork/lnd/graph/db/models" "github.com/lightningnetwork/lnd/graph/db/models"
@@ -92,6 +93,7 @@ type SQLQueries interface {
CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error) CreateChannel(ctx context.Context, arg sqlc.CreateChannelParams) (int64, error)
GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error) GetChannelBySCID(ctx context.Context, arg sqlc.GetChannelBySCIDParams) (sqlc.Channel, error)
GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error) GetChannelByOutpoint(ctx context.Context, outpoint string) (sqlc.GetChannelByOutpointRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg sqlc.GetChannelsBySCIDRangeParams) ([]sqlc.GetChannelsBySCIDRangeRow, error)
GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error) GetChannelBySCIDWithPolicies(ctx context.Context, arg sqlc.GetChannelBySCIDWithPoliciesParams) (sqlc.GetChannelBySCIDWithPoliciesRow, error)
GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error) GetChannelAndNodesBySCID(ctx context.Context, arg sqlc.GetChannelAndNodesBySCIDParams) (sqlc.GetChannelAndNodesBySCIDRow, error)
GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error) GetChannelFeaturesAndExtras(ctx context.Context, channelID int64) ([]sqlc.GetChannelFeaturesAndExtrasRow, error)
@@ -133,6 +135,7 @@ type SQLQueries interface {
*/ */
GetPruneTip(ctx context.Context) (sqlc.PruneLog, error) GetPruneTip(ctx context.Context) (sqlc.PruneLog, error)
UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error UpsertPruneLogEntry(ctx context.Context, arg sqlc.UpsertPruneLogEntryParams) error
DeletePruneLogEntriesInRange(ctx context.Context, arg sqlc.DeletePruneLogEntriesInRangeParams) error
} }
// BatchedSQLQueries is a version of SQLQueries that's capable of batched // BatchedSQLQueries is a version of SQLQueries that's capable of batched
@@ -2481,6 +2484,97 @@ func (s *SQLStore) pruneGraphNodes(ctx context.Context,
return prunedNodes, nil return prunedNodes, nil
} }
// DisconnectBlockAtHeight is used to indicate that the block specified
// by the passed height has been disconnected from the main chain. This
// will "rewind" the graph back to the height below, deleting channels
// that are no longer confirmed from the graph. The prune log will be
// set to the last prune height valid for the remaining chain.
// Channels that were removed from the graph resulting from the
// disconnected block are returned.
//
// NOTE: part of the V1Store interface.
func (s *SQLStore) DisconnectBlockAtHeight(height uint32) (
[]*models.ChannelEdgeInfo, error) {
ctx := context.TODO()
var (
// Every channel having a ShortChannelID starting at 'height'
// will no longer be confirmed.
startShortChanID = lnwire.ShortChannelID{
BlockHeight: height,
}
// Delete everything after this height from the db up until the
// SCID alias range.
endShortChanID = aliasmgr.StartingAlias
removedChans []*models.ChannelEdgeInfo
)
var chanIDStart [8]byte
byteOrder.PutUint64(chanIDStart[:], startShortChanID.ToUint64())
var chanIDEnd [8]byte
byteOrder.PutUint64(chanIDEnd[:], endShortChanID.ToUint64())
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
rows, err := db.GetChannelsBySCIDRange(
ctx, sqlc.GetChannelsBySCIDRangeParams{
StartScid: chanIDStart[:],
EndScid: chanIDEnd[:],
},
)
if err != nil {
return fmt.Errorf("unable to fetch channels: %w", err)
}
for _, row := range rows {
node1, node2, err := buildNodeVertices(
row.Node1PubKey, row.Node2PubKey,
)
if err != nil {
return err
}
channel, err := getAndBuildEdgeInfo(
ctx, db, s.cfg.ChainHash, row.Channel.ID,
row.Channel, node1, node2,
)
if err != nil {
return err
}
err = db.DeleteChannel(ctx, row.Channel.ID)
if err != nil {
return fmt.Errorf("unable to delete "+
"channel: %w", err)
}
removedChans = append(removedChans, channel)
}
return db.DeletePruneLogEntriesInRange(
ctx, sqlc.DeletePruneLogEntriesInRangeParams{
StartHeight: int64(height),
EndHeight: int64(endShortChanID.BlockHeight),
},
)
}, func() {
removedChans = nil
})
if err != nil {
return nil, fmt.Errorf("unable to disconnect block at "+
"height: %w", err)
}
for _, channel := range removedChans {
s.rejectCache.remove(channel.ChannelID)
s.chanCache.remove(channel.ChannelID)
}
return removedChans, nil
}
// forEachNodeDirectedChannel iterates through all channels of a given // forEachNodeDirectedChannel iterates through all channels of a given
// node, executing the passed callback on the directed edge representing the // node, executing the passed callback on the directed edge representing the
// channel and its incoming policy. If the node is not found, no error is // channel and its incoming policy. If the node is not found, no error is

View File

@@ -200,6 +200,22 @@ func (q *Queries) DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeaturePa
return err return err
} }
const deletePruneLogEntriesInRange = `-- name: DeletePruneLogEntriesInRange :exec
DELETE FROM prune_log
WHERE block_height >= $1
AND block_height <= $2
`
type DeletePruneLogEntriesInRangeParams struct {
StartHeight int64
EndHeight int64
}
func (q *Queries) DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error {
_, err := q.db.ExecContext(ctx, deletePruneLogEntriesInRange, arg.StartHeight, arg.EndHeight)
return err
}
const deleteZombieChannel = `-- name: DeleteZombieChannel :execresult const deleteZombieChannel = `-- name: DeleteZombieChannel :execresult
DELETE FROM zombie_channels DELETE FROM zombie_channels
WHERE scid = $1 WHERE scid = $1
@@ -943,6 +959,67 @@ func (q *Queries) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg Ge
return items, nil return items, nil
} }
const getChannelsBySCIDRange = `-- name: GetChannelsBySCIDRange :many
SELECT c.id, c.version, c.scid, c.node_id_1, c.node_id_2, c.outpoint, c.capacity, c.bitcoin_key_1, c.bitcoin_key_2, c.node_1_signature, c.node_2_signature, c.bitcoin_1_signature, c.bitcoin_2_signature,
n1.pub_key AS node1_pub_key,
n2.pub_key AS node2_pub_key
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE scid >= $1
AND scid < $2
`
type GetChannelsBySCIDRangeParams struct {
StartScid []byte
EndScid []byte
}
type GetChannelsBySCIDRangeRow struct {
Channel Channel
Node1PubKey []byte
Node2PubKey []byte
}
func (q *Queries) GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error) {
rows, err := q.db.QueryContext(ctx, getChannelsBySCIDRange, arg.StartScid, arg.EndScid)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChannelsBySCIDRangeRow
for rows.Next() {
var i GetChannelsBySCIDRangeRow
if err := rows.Scan(
&i.Channel.ID,
&i.Channel.Version,
&i.Channel.Scid,
&i.Channel.NodeID1,
&i.Channel.NodeID2,
&i.Channel.Outpoint,
&i.Channel.Capacity,
&i.Channel.BitcoinKey1,
&i.Channel.BitcoinKey2,
&i.Channel.Node1Signature,
&i.Channel.Node2Signature,
&i.Channel.Bitcoin1Signature,
&i.Channel.Bitcoin2Signature,
&i.Node1PubKey,
&i.Node2PubKey,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many const getExtraNodeTypes = `-- name: GetExtraNodeTypes :many
SELECT node_id, type, value SELECT node_id, type, value
FROM node_extra_types FROM node_extra_types

View File

@@ -25,6 +25,7 @@ type Querier interface {
DeleteNodeAddresses(ctx context.Context, nodeID int64) error DeleteNodeAddresses(ctx context.Context, nodeID int64) error
DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error) DeleteNodeByPubKey(ctx context.Context, arg DeleteNodeByPubKeyParams) (sql.Result, error)
DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error DeleteNodeFeature(ctx context.Context, arg DeleteNodeFeatureParams) error
DeletePruneLogEntriesInRange(ctx context.Context, arg DeletePruneLogEntriesInRangeParams) error
DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error) DeleteZombieChannel(ctx context.Context, arg DeleteZombieChannelParams) (sql.Result, error)
FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error) FetchAMPSubInvoiceHTLCs(ctx context.Context, arg FetchAMPSubInvoiceHTLCsParams) ([]FetchAMPSubInvoiceHTLCsRow, error)
FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error) FetchAMPSubInvoices(ctx context.Context, arg FetchAMPSubInvoicesParams) ([]AmpSubInvoice, error)
@@ -40,6 +41,7 @@ type Querier interface {
GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error) GetChannelPolicyByChannelAndNode(ctx context.Context, arg GetChannelPolicyByChannelAndNodeParams) (ChannelPolicy, error)
GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error) GetChannelPolicyExtraTypes(ctx context.Context, arg GetChannelPolicyExtraTypesParams) ([]GetChannelPolicyExtraTypesRow, error)
GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error) GetChannelsByPolicyLastUpdateRange(ctx context.Context, arg GetChannelsByPolicyLastUpdateRangeParams) ([]GetChannelsByPolicyLastUpdateRangeRow, error)
GetChannelsBySCIDRange(ctx context.Context, arg GetChannelsBySCIDRangeParams) ([]GetChannelsBySCIDRangeRow, error)
GetDatabaseVersion(ctx context.Context) (int32, error) GetDatabaseVersion(ctx context.Context) (int32, error)
GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error) GetExtraNodeTypes(ctx context.Context, nodeID int64) ([]NodeExtraType, error)
// This method may return more than one invoice if filter using multiple fields // This method may return more than one invoice if filter using multiple fields

View File

@@ -208,6 +208,16 @@ INSERT INTO channels (
) )
RETURNING id; RETURNING id;
-- name: GetChannelsBySCIDRange :many
SELECT sqlc.embed(c),
n1.pub_key AS node1_pub_key,
n2.pub_key AS node2_pub_key
FROM channels c
JOIN nodes n1 ON c.node_id_1 = n1.id
JOIN nodes n2 ON c.node_id_2 = n2.id
WHERE scid >= @start_scid
AND scid < @end_scid;
-- name: GetChannelBySCID :one -- name: GetChannelBySCID :one
SELECT * FROM channels SELECT * FROM channels
WHERE scid = $1 AND version = $2; WHERE scid = $1 AND version = $2;
@@ -685,3 +695,8 @@ SELECT block_height, block_hash
FROM prune_log FROM prune_log
ORDER BY block_height DESC ORDER BY block_height DESC
LIMIT 1; LIMIT 1;
-- name: DeletePruneLogEntriesInRange :exec
DELETE FROM prune_log
WHERE block_height >= @start_height
AND block_height <= @end_height;